diff --git a/.appveyor.yml b/.appveyor.yml index 87aee9c974..d90d4ba724 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -43,6 +43,7 @@ build_script: - ps: Push-AppveyorArtifact C:\blis.zip test_script: +# "make checkblas" does not work with shared linking Windows due to inability to override xerbla_ - if [%LIB_TYPE%]==[shared] set "TEST_TARGET=checkblis-fast" - if [%LIB_TYPE%]==[static] set "TEST_TARGET=check" - bash -lc "cd /c/projects/blis && mingw32-make %TEST_TARGET% -j4 V=1" diff --git a/.travis.yml b/.travis.yml index bbae9a7d9f..a61a879fa1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,80 +1,76 @@ language: c sudo: required -dist: trusty +dist: focal +branches: + only: + - master + - dev + - amd matrix: include: - # full testsuite (all tests except for mixed datatype) + # full testsuite (all tests + mixed datatype (gemm_nn only) + salt + SDE + OOT) - os: linux compiler: gcc - env: OOT=0 TEST=1 SDE=0 THR="none" CONF="auto" - # mixed-datatype testsuite (gemm_nn only) - - os: linux - compiler: gcc - env: OOT=0 TEST=MD SDE=0 THR="none" CONF="auto" - # salt testsuite (fast set of operations+parameters) - - os: linux - compiler: gcc - env: OOT=0 TEST=SALT SDE=0 THR="none" CONF="auto" - # test x86_64 ukrs with SDE - - os: linux - compiler: gcc - env: OOT=0 TEST=0 SDE=1 THR="none" CONF="x86_64" + env: OOT=1 TEST=ALL SDE=1 THR="none" CONF="x86_64" \ + PACKAGES="gcc-8 binutils" # openmp build - os: linux compiler: gcc - env: OOT=0 TEST=0 SDE=0 THR="openmp" CONF="auto" + env: OOT=0 TEST=FAST SDE=0 THR="openmp" CONF="auto" \ + PACKAGES="gcc-8 binutils" # pthreads build - os: linux compiler: gcc - env: OOT=0 TEST=0 SDE=0 THR="pthreads" CONF="auto" - # out-of-tree build - - os: linux - compiler: gcc - env: OOT=1 TEST=0 SDE=0 THR="none" CONF="auto" + env: OOT=0 TEST=FAST SDE=0 THR="pthreads" CONF="auto" \ + PACKAGES="gcc-8 binutils" # clang build - os: linux compiler: clang - env: OOT=0 TEST=0 SDE=0 THR="none" CONF="auto" + env: OOT=0 TEST=FAST SDE=0 THR="none" CONF="auto" + # There seems to be some difficulty installing 2 Clang toolchains of different versions. + # Use the TravisCI default. 
+ # PACKAGES="clang-8 binutils" # macOS with system compiler (clang) - os: osx compiler: clang - env: OOT=0 TEST=1 SDE=0 THR="none" CONF="auto" + env: OOT=0 TEST=FAST SDE=0 THR="none" CONF="auto" # cortexa15 build and fast testsuite (qemu) - os: linux compiler: arm-linux-gnueabihf-gcc env: OOT=0 TEST=FAST SDE=0 THR="none" CONF="cortexa15" \ - PACKAGES="gcc-arm-linux-gnueabihf qemu-system-arm qemu-user" \ + CC=arm-linux-gnueabihf-gcc CXX=arm-linux-gnueabihf-g++ \ + PACKAGES="gcc-arm-linux-gnueabihf g++-arm-linux-gnueabihf libc6-dev-armhf-cross qemu-system-arm qemu-user" \ TESTSUITE_WRAPPER="qemu-arm -cpu cortex-a15 -L /usr/arm-linux-gnueabihf/" # cortexa57 build and fast testsuite (qemu) - os: linux compiler: aarch64-linux-gnu-gcc env: OOT=0 TEST=FAST SDE=0 THR="none" CONF="cortexa57" \ - PACKAGES="gcc-aarch64-linux-gnu qemu-system-arm qemu-user" \ + CC=aarch64-linux-gnu-gcc CXX=aarch64-linux-gnu-g++ \ + PACKAGES="gcc-aarch64-linux-gnu g++-aarch64-linux-gnu libc6-dev-arm64-cross qemu-system-arm qemu-user" \ TESTSUITE_WRAPPER="qemu-aarch64 -L /usr/aarch64-linux-gnu/" + # armsve build and fast testsuite (qemu) + - os: linux + compiler: aarch64-linux-gnu-gcc-10 + env: OOT=0 TEST=FAST SDE=0 THR="none" CONF="armsve" \ + CC=aarch64-linux-gnu-gcc-10 CXX=aarch64-linux-gnu-g++-10 \ + PACKAGES="gcc-10-aarch64-linux-gnu g++-10-aarch64-linux-gnu libc6-dev-arm64-cross qemu-system-arm qemu-user" \ + TESTSUITE_WRAPPER="qemu-aarch64 -cpu max,sve=true,sve512=true -L /usr/aarch64-linux-gnu/" install: -- if [ "$TRAVIS_OS_NAME" = "linux" ]; then sudo rm -f /usr/bin/as; fi -- if [ "$TRAVIS_OS_NAME" = "linux" ]; then sudo ln -s /usr/lib/binutils-2.26/bin/as /usr/bin/as; fi -- if [ "$TRAVIS_OS_NAME" = "linux" ]; then sudo rm -f /usr/bin/ld; fi -- if [ "$TRAVIS_OS_NAME" = "linux" ]; then sudo ln -s /usr/lib/binutils-2.26/bin/ld /usr/bin/ld; fi -- if [ "$CC" = "gcc" ] && [ "$TRAVIS_OS_NAME" = "linux" ]; then export CC="gcc-6"; fi -- if [ -n "$PACKAGES" ]; then sudo apt-get install -y $PACKAGES; fi -addons: - apt: - sources: - - ubuntu-toolchain-r-test - packages: - - gcc-6 - - binutils-2.26 - - clang +- if [ "$CC" = "gcc" ] && [ "$TRAVIS_OS_NAME" = "linux" ]; then export CC="gcc-8"; fi +- if [ -n "$PACKAGES" ] && [ "$TRAVIS_OS_NAME" = "linux" ]; then sudo apt-get install -y $PACKAGES; fi script: - export DIST_PATH=. - pwd - if [ $OOT -eq 1 ]; then export DIST_PATH=`pwd`; mkdir ../oot; cd ../oot; chmod -R a-w $DIST_PATH; fi - pwd -- $DIST_PATH/configure -t $THR CC=$CC $CONF +- $DIST_PATH/configure -p `pwd`/../install -t $THR CC=$CC $CONF - pwd - ls -l - $CC --version - make -j 2 +- make install +- $DIST_PATH/travis/cxx/cxx-test.sh $DIST_PATH $(ls -1 include) +# Qemu SVE is failing sgemmt in some cases. Skip as this issue is not observed on real chip (A64fx). +- if [ "$CONF" = "armsve" ]; then sed -i 's/.*\.*/0/' $DIST_PATH/testsuite/input.operations.fast; fi - if [ "$TEST" != "0" ]; then travis_wait 30 $DIST_PATH/travis/do_testsuite.sh; fi - if [ "$SDE" = "1" ]; then travis_wait 30 $DIST_PATH/travis/do_sde.sh; fi diff --git a/CMakeLists.txt b/CMakeLists.txt index a1f89a6682..787f831452 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,683 +1,1110 @@ -##Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. 
All rights reserved.## -cmake_minimum_required(VERSION 3.0.0) - -project(AOCL-LibBlis-Win C CXX) - -set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/bin") -set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/bin") -set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/bin") +cmake_minimum_required(VERSION 3.15.0) +if(WIN32) + project(AOCL-LibBlis LANGUAGES C CXX) +else() + project(AOCL-LibBlis LANGUAGES C CXX Fortran) +endif() +# Set the C standard to C99. +set(CMAKE_C_STANDARD 99) +set(CMAKE_C_STANDARD_REQUIRED TRUE) +# Set the C++ standard to C++11. +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD_REQUIRED TRUE) -SET(AOCL_BLIS_FAMILY "zen" CACHE STRING "AOCL BLIS family name") -SET(OpenMP_libomp_LIBRARY "C:/Program Files/LLVM/lib/libomp.lib" CACHE STRING "openmp library -path") -set(TARGET_ARCH ${AOCL_BLIS_FAMILY}) -set(AOCL_BLIS_ZEN TRUE) -set (PYTHON_EXE "python") +# Enable IDE folders for targets. +set_property(GLOBAL PROPERTY USE_FOLDERS ON) -if ("${AOCL_BLIS_FAMILY}" STREQUAL "") - message(FATAL_ERROR "Machine configuration missing! Select one of zen, zen2, zen3, zen4 or amdzen") +# Find a python interpreter. +find_package(Python COMPONENTS Interpreter REQUIRED) +if(NOT Python_FOUND) + message(SEND_ERROR "Could not find working python interperter! Cannot continue.") +endif() +# Functionality that prints configuration usage. +option(PRINT_CONFIGURE_HELP "Print CMake Configuration Usage" OFF) +if(PRINT_CONFIGURE_HELP) + execute_process(COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/build/cmake/config_print.py) + return() endif () -if (${AOCL_BLIS_FAMILY} STREQUAL "auto") - set(AUTO_CONFIG_PY "${CMAKE_SOURCE_DIR}/build/auto_config.py") - # Run python script to find the architecture family name - execute_process( - COMMAND ${PYTHON_EXE} ${AUTO_CONFIG_PY} - RESULT_VARIABLE CMD_RESULT - OUTPUT_VARIABLE CMD_OUTPUT - OUTPUT_STRIP_TRAILING_WHITESPACE) - message( STATUS "Auto configuring the family :" ${CMD_OUTPUT}) - set(AOCL_BLIS_FAMILY ${CMD_OUTPUT}) -endif () +if(WIN32) + set(BLIS_CONFIG_FAMILY "auto" CACHE STRING "Set the configuration family for which the BLIS library will be built.") +else() + set(BLIS_CONFIG_FAMILY "" CACHE STRING "Set the configuration family for which the BLIS library will be built.") +endif() +set_property(CACHE BLIS_CONFIG_FAMILY PROPERTY STRINGS "auto" "generic" "zen" "zen2" "zen3" "zen4" "amdzen") +# Throw an error if CMake was configured with a configuration which is not enabled yet. +if(NOT ((BLIS_CONFIG_FAMILY STREQUAL auto) OR + (BLIS_CONFIG_FAMILY STREQUAL generic) OR + (BLIS_CONFIG_FAMILY STREQUAL zen) OR + (BLIS_CONFIG_FAMILY STREQUAL zen2) OR + (BLIS_CONFIG_FAMILY STREQUAL zen3) OR + (BLIS_CONFIG_FAMILY STREQUAL zen4) OR + (BLIS_CONFIG_FAMILY STREQUAL amdzen))) + message(FATAL_ERROR "Configuration for ${BLIS_CONFIG_FAMILY} is not supported. 
\ + Please re-run cmake and specify one of the following configurations for BLIS_CONFIG_FAMILY: \ + auto, zen, zen2, zen3, zen4, amdzen, generic.") +endif() -if(${AOCL_BLIS_FAMILY} STREQUAL "zen") - add_definitions(-DBLIS_FAMILY_ZEN) - add_definitions(-DBLIS_CONFIG_ZEN) - add_definitions(-DBLIS_KERNELS_ZEN) - add_definitions(-DBLIS_KERNELS_HASWELL) -elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen2") - add_definitions(-DBLIS_FAMILY_ZEN2) - add_definitions(-DBLIS_CONFIG_ZEN2) - add_definitions(-DBLIS_KERNELS_ZEN2) - add_definitions(-DBLIS_KERNELS_ZEN) - add_definitions(-DBLIS_KERNELS_HASWELL) -elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen3") - add_definitions(-DBLIS_FAMILY_ZEN3) - add_definitions(-DBLIS_CONFIG_ZEN3) - add_definitions(-DBLIS_KERNELS_ZEN3) - add_definitions(-DBLIS_KERNELS_ZEN2) - add_definitions(-DBLIS_KERNELS_ZEN) - add_definitions(-DBLIS_KERNELS_HASWELL) -elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen4") - add_definitions(-DBLIS_FAMILY_ZEN4) - add_definitions(-DBLIS_CONFIG_ZEN4) - add_definitions(-DBLIS_KERNELS_SKX) - add_definitions(-DBLIS_KERNELS_ZEN4) - add_definitions(-DBLIS_KERNELS_ZEN3) - add_definitions(-DBLIS_KERNELS_ZEN2) - add_definitions(-DBLIS_KERNELS_ZEN) - add_definitions(-DBLIS_KERNELS_HASWELL) -elseif (${AOCL_BLIS_FAMILY} STREQUAL "amdzen") - set(AOCL_BLIS_ZEN FALSE) - add_definitions(-DBLIS_FAMILY_AMDZEN) - add_definitions(-DBLIS_CONFIG_ZEN4) - add_definitions(-DBLIS_CONFIG_ZEN3) - add_definitions(-DBLIS_CONFIG_ZEN2) - add_definitions(-DBLIS_CONFIG_ZEN) - add_definitions(-DBLIS_CONFIG_GENERIC) - add_definitions(-DBLIS_KERNELS_SKX) - add_definitions(-DBLIS_KERNELS_ZEN4) - add_definitions(-DBLIS_KERNELS_ZEN3) - add_definitions(-DBLIS_KERNELS_ZEN2) - add_definitions(-DBLIS_KERNELS_HASWELL) - add_definitions(-DBLIS_KERNELS_ZEN) - add_definitions(-DBLIS_KERNELS_GENERIC) -else () - message(FATAL_ERROR "Wrong machine configuration. Select one of zen, zen2, zen3, zen4 or amdzen") -endif () +# automatic hardware detection +if(BLIS_CONFIG_FAMILY STREQUAL "auto") + message(STATUS "automatic configuration requested") + set(auto_detect_source_files + "${CMAKE_SOURCE_DIR}/build/detect/config/config_detect.c" + "${CMAKE_SOURCE_DIR}/frame/base/bli_arch.c" + "${CMAKE_SOURCE_DIR}/frame/base/bli_cpuid.c" + "${CMAKE_SOURCE_DIR}/frame/base/bli_env.c" + ) + set(frame_include " ${CMAKE_SOURCE_DIR}/frame/include") + set(base_include " ${CMAKE_SOURCE_DIR}/frame/base") + set(thread_include " ${CMAKE_SOURCE_DIR}/frame/thread") + # Try building an executable from one or more source files. + # Build success returns TRUE and build failure returns FALSE in COMPILERESULT. + # If the build succeeds, this runs the executable and stores the exit code in RUNRESULT. 
+ # If the executable was built, but failed to run, then RUNRESULT will be set to FAILED_TO_RUN + # RUN_OUTPUT_VARIABLE Report the output from running the executable in a given variable + try_run(RUNRESULT COMPILERESULT "${CMAKE_BINARY_DIR}/temp" SOURCES ${auto_detect_source_files} + COMPILE_DEFINITIONS -I${frame_include} -I${base_include} -I${thread_include} + -DBLIS_CONFIGURETIME_CPUID -DBLIS_CONFIG_SKX -DBLIS_CONFIG_KNL + -DBLIS_CONFIG_HASWELL -DBLIS_CONFIG_SANDYBRIDGE -DBLIS_CONFIG_PENRYN + -DBLIS_CONFIG_ZEN4 -DBLIS_CONFIG_ZEN3 -DBLIS_CONFIG_ZEN2 -DBLIS_CONFIG_ZEN + -DBLIS_CONFIG_EXCAVATOR -DBLIS_CONFIG_STEAMROLLER -DBLIS_CONFIG_PILEDRIVER + -DBLIS_CONFIG_BULLDOZER -DBLIS_CONFIG_THUNDERX2 -DBLIS_CONFIG_CORTEXA57 + -DBLIS_CONFIG_CORTEXA15 -DBLIS_CONFIG_CORTEXA9 + -D__blis_arch_type_name="BLIS_ARCH_TYPE" -D__blis_model_type_name="BLIS_MODEL_TYPE" + RUN_OUTPUT_VARIABLE HARDWARE_ARCH + ) + string(STRIP "${HARDWARE_ARCH}" HARDWARE_ARCH) + message(STATUS "automatic hardware detection: " ${HARDWARE_ARCH}) + if( NOT(${HARDWARE_ARCH} STREQUAL zen OR + ${HARDWARE_ARCH} STREQUAL zen2 OR + ${HARDWARE_ARCH} STREQUAL zen3 OR + ${HARDWARE_ARCH} STREQUAL zen4) ) + set(BLIS_CONFIG_FAMILY "generic") + message(WARNING "Only AMD zen architectures are supported. \ + Detected ${HARDWARE_ARCH} hardware. Defaulting to generic configuration.") + else() + set(BLIS_CONFIG_FAMILY ${HARDWARE_ARCH}) + endif() + message(STATUS "automatic configuration registered: " ${BLIS_CONFIG_FAMILY}) +endif() -set(TARGET_ARCH ${AOCL_BLIS_FAMILY}) -message("AOCL_BLIS_FAMILY selected:${AOCL_BLIS_FAMILY}") +# Read the registered configuration names and lists into associative arrays. +execute_process( + COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/build/cmake/read_registry.py "${BLIS_CONFIG_FAMILY}" "${CMAKE_SOURCE_DIR}" + RESULT_VARIABLE CMD_RESULT + OUTPUT_VARIABLE CONFIGURATION_STRING + OUTPUT_STRIP_TRAILING_WHITESPACE ) +# Returns the list of elements specified by indices from the list. +message(STATUS "configuration '${BLIS_CONFIG_FAMILY}' is registered.") +list(GET CONFIGURATION_STRING 0 CONFIG_LIST) +list(GET CONFIGURATION_STRING 1 KERNEL_LIST) +list(GET CONFIGURATION_STRING 2 KCONFIG_MAP) +# Removing leading and trailing spaces in the string. +string(STRIP "${CONFIG_LIST}" CONFIG_LIST) +string(STRIP "${KERNEL_LIST}" KERNEL_LIST) +string(STRIP "${KCONFIG_MAP}" KCONFIG_MAP) +# Convert from string to list(list is a ";"-separated string) +message(STATUS "${BLIS_CONFIG_FAMILY} is defined as having the following sub-configurations:") +message(" ${CONFIG_LIST} ") +string(REPLACE " " ";" CONFIG_LIST ${CONFIG_LIST}) +message(STATUS "which collectively require the following kernels:") +message(" ${KERNEL_LIST} ") +string(REPLACE " " ";" KERNEL_LIST ${KERNEL_LIST}) +message(STATUS "that has kernel:config pairs:") +message(" ${KCONFIG_MAP} ") +string(REPLACE " " ";" KCONFIG_MAP ${KCONFIG_MAP}) +# Create a #define for the configuration family (config_name). +string(TOUPPER ${BLIS_CONFIG_FAMILY} UCONF) +set(CONFIG_NAME_DEFINE "#define BLIS_FAMILY_${UCONF}\n") +#create a AOCL specific #define +#This macro is enabled only for zen family configurations. +#This enables us to use different cache block sizes for TRSM instead of common level-3 block sizes. +if(BLIS_CONFIG_FAMILY MATCHES "zen|amd64|x86_64") + set(ENABLE_AOCL_ZEN ON) + set(ENABLE_AOCL_ZEN_01 1) +else() + set(ENABLE_AOCL_ZEN OFF) + set(ENABLE_AOCL_ZEN_01 0) +endif() +# Create a list of #defines, one for each configuration in config_list. 
+set(CONFIG_LIST_DEFINES "") +foreach(CONF ${CONFIG_LIST}) + string(TOUPPER ${CONF} UCONF) + set(CONFIG_LIST_DEFINES "${CONFIG_LIST_DEFINES}#define BLIS_CONFIG_${UCONF}\n") +endforeach() +# Create a list of #defines, one for each kernel set in kernel_list. +set(KERNEL_LIST_DEFINES "") +foreach(KERN ${KERNEL_LIST}) + string(TOUPPER ${KERN} UCONF) + set(KERNEL_LIST_DEFINES "${KERNEL_LIST_DEFINES}#define BLIS_KERNELS_${UCONF}\n") +endforeach() -option(BUILD_SHARED_LIBS "Build shared library" ON) -option(ENABLE_VERBOSE "Enable VERBOSE mode for build" OFF) -option(ENABLE_MULTITHREADING "Enable Multi threading" OFF) -option(ENABLE_OPENMP "Enable Openmp mode" OFF) -option(ENABLE_JRIR_SLAB "Request slab thread in jr and ir loops" ON) -option(ENABLE_JRIR_RR "Request round robin thread in jr and ir loops" OFF) +#------------------------------------ +# Option Setting +#------------------------------------ +# Options that are specific to Windows. +if(WIN32) + option(ENABLE_NO_UNDERSCORE_API "Export APIs without underscore." OFF) + option(ENABLE_UPPERCASE_API "Export APIs with uppercase." OFF) + # Setting path to OpenMP runtime. + set(OpenMP_libomp_LIBRARY "C:/Program Files/LLVM/lib/libomp.lib" CACHE STRING "openmp library path") +endif() +# Debug & Release flags option setting is only available for Linux. On Windows the default flags are used. +if(NOT WIN32) + set(ENABLE_DEBUG "off" CACHE STRING "Enable debugging symbols in the library.") + set_property(CACHE ENABLE_DEBUG PROPERTY STRINGS "off" "noopt" "opt") + if( NOT ((ENABLE_DEBUG STREQUAL "off") OR (ENABLE_DEBUG STREQUAL "noopt") OR (ENABLE_DEBUG STREQUAL "opt")) ) + message(FATAL_ERROR "ENABLE_DEBUG option '${ENABLE_DEBUG}' is not supported. Please use one of the following options \ + during CMake invokation: off, noopt, opt") + endif() + # Check if user provided CMAKE_BUILD_TYPE. If that's the case, map it to the internal ENABLE_DEBUG type + # and clean cache from CMAKE_BUILD_TYPE. We do this because CMake will add some flags depending on the + # the build type and on Linux we want to have more control over what flags are being used. + if(CMAKE_BUILD_TYPE) + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(ENABLE_DEBUG "noopt") + elseif(CMAKE_BUILD_TYPE STREQUAL "Release") + set(ENABLE_DEBUG "off") + elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo") + set(ENABLE_DEBUG "opt") + else() + message(FATAL_ERROR "Configured CMake with incompatible CMAKE_BUILD_TYPE. Only Debug, RelWithDebInfo and Release are supported. \ + This is due to matching this flag to BLIS internal options corresponding to ENABLE_DEBUG: off, noopt, opt.") + endif() + message(WARNING "When CMAKE_BUILD_TYPE is used, BLIS-specific variable ENABLE_DEBUG gets overwritten accordingly.") + set(CMAKE_BUILD_TYPE "") + endif() +endif() +# Build shared libraries by default +option(BUILD_SHARED_LIBS "Build shared libraries (.dll/.so) instead of static ones (.lib/.a)" ON) +option(ENABLE_SYSTEM "Check if we are building with or without operating system support" ON) +set(ENABLE_THREADING "no" CACHE STRING "the threading flag") +if(WIN32) + set_property(CACHE ENABLE_THREADING PROPERTY STRINGS "openmp" "no") + if( NOT ((ENABLE_THREADING STREQUAL "openmp") OR (ENABLE_THREADING STREQUAL "no")) ) + message(FATAL_ERROR "ENABLE_THREADING option '${ENABLE_THREADING}' is not supported. 
Please use one of the following options \ + during CMake invokation: openmp, no") + endif() +else() + set_property(CACHE ENABLE_THREADING PROPERTY STRINGS "openmp" "pthreads" "no") + if( NOT ((ENABLE_THREADING STREQUAL "openmp") OR (ENABLE_THREADING STREQUAL "pthreads") OR (ENABLE_THREADING STREQUAL "no")) ) + message(FATAL_ERROR "ENABLE_THREADING option '${ENABLE_THREADING}' is not supported. Please use one of the following options \ + during CMake invokation: openmp, pthreads, no") + endif() +endif() +set(THREAD_PART_JRIR "slab" CACHE STRING "The method of assigning micropanels to threads in the JR and JR loops.") +set_property(CACHE THREAD_PART_JRIR PROPERTY STRINGS "slab" "rr") +if( NOT ((THREAD_PART_JRIR STREQUAL "slab") OR (THREAD_PART_JRIR STREQUAL "rr")) ) + message(FATAL_ERROR "THREAD_PART_JRIR option '${THREAD_PART_JRIR}' is not supported. Please use one of the following options \ + during CMake invokation: slab, rr") +endif() +# Export symbols only for Linux. +if(NOT WIN32) + set(EXPORT_SHARED "public" CACHE STRING "Specify the subset of library symbols that are exported within a shared library.") + set_property(CACHE EXPORT_SHARED PROPERTY STRINGS "public" "all") + if( NOT ((EXPORT_SHARED STREQUAL "public") OR (EXPORT_SHARED STREQUAL "all")) ) + message(FATAL_ERROR "EXPORT_SHARED option '${EXPORT_SHARED}' is not supported. Please use one of the following options \ + during CMake invokation: public, all") + endif() +endif() option(ENABLE_PBA_POOLS "Internal memory pools for packing blocks" ON) option(ENABLE_SBA_POOLS "Internal memory pools for small blocks" ON) option(ENABLE_MEM_TRACING "Memory tracing output" OFF) +set(INT_SIZE "auto" CACHE STRING "BLIS API integer size") +set_property(CACHE INT_SIZE PROPERTY STRINGS "auto" "32" "64") +if( NOT ((INT_SIZE STREQUAL "auto") OR (INT_SIZE STREQUAL "32") OR (INT_SIZE STREQUAL "64")) ) + message(FATAL_ERROR "INT_SIZE option '${INT_SIZE}' is not supported. Please use one of the following options \ + during CMake invokation: auto, 32, 64") +endif() +set(BLAS_INT_SIZE "32" CACHE STRING "BLAS/CBLAS API integer size") +set_property(CACHE BLAS_INT_SIZE PROPERTY STRINGS "auto" "32" "64") +if( NOT ((BLAS_INT_SIZE STREQUAL "auto") OR (BLAS_INT_SIZE STREQUAL "32") OR (BLAS_INT_SIZE STREQUAL "64")) ) + message(FATAL_ERROR "BLAS_INT_SIZE option '${BLAS_INT_SIZE}' is not supported. 
Please use one of the following options \ + during CMake invokation: auto, 32, 64") +endif() option(ENABLE_BLAS "BLAS compatiblity layer" ON) -option(ENABLE_CBLAS "CBLAS compatiblity layer" ON) -option(ENABLE_MIXED_DT "Mixed datatype" ON) +option(ENABLE_CBLAS "CBLAS compatiblity layer" OFF) +option(ENABLE_MIXED_DT "Mixed datatype support" ON) option(ENABLE_MIXED_DT_EXTRA_MEM "Mixed datatype optimization requiring extra memory" ON) option(ENABLE_SUP_HANDLING "Small matrix handling" ON) -option(ENABLE_MEMKIND "libmemkind for manage memory pools" OFF) -option(ENABLE_PRAGMA_OMP_SIMD "pragma openmp simd" ON) -option(ENABLE_SANDBOX "Sandbox implementation for gemm" OFF) -option(BLIS_ENABLE_ILP64 "ENABLE BLIS ILP64" OFF) -option(ENABLE_INT_TYPE_SIZE " Internal BLIS integers ,used in native BLIS interfaces based on architecture dependent " ON) -option(ENABLE_BLASTEST "Enable the blastest" OFF) -option(ENABLE_TESTCPP_TESTING "Enabling testcpp" OFF) -option (ENABLE_NO_UNDERSCORE_API "export APIs without underscore" OFF) -option (ENABLE_UPPERCASE_API "export APIs with uppercase" OFF) -option (ENABLE_COMPLEX_RETURN_INTEL "Enable complex_return_intel" OFF) -option (ENABLE_TRSM_PREINVERSION "Enable TRSM preinversion" ON) -option (ENABLE_AOCL_DYNAMIC "Enable Dynamic Multi-threading" OFF) -option(DISABLE_BLIS_ARCH_TYPE "Disable BLIS_ARCH_TYPE and BLIS_MODEL_TYPE functionality" OFF) -option(RENAME_BLIS_ARCH_TYPE "Rename BLIS_ARCH_TYPE env var renamed to supplied value" BLIS_ARCH_TYPE) -option(RENAME_BLIS_MODEL_TYPE "Rename BLIS_MODEL_TYPE env var renamed to supplied value" BLIS_MODEL_TYPE) - -if (${AOCL_BLIS_FAMILY} STREQUAL "amdzen") - set(REF_KERNEL_MIRRORING_PY "${CMAKE_SOURCE_DIR}/build/blis_ref_kernel_mirror.py") - message("ref_kernel mirroring for fat binary") - # Run python script to find the architecture family name - execute_process( - COMMAND ${PYTHON_EXE} ${REF_KERNEL_MIRRORING_PY} ${CMAKE_BINARY_DIR} - RESULT_VARIABLE CMD_RESULT - OUTPUT_VARIABLE CMD_OUTPUT - OUTPUT_STRIP_TRAILING_WHITESPACE) - message( STATUS "Ref Kernel Mirroring :" ${CMD_OUTPUT}) -endif() -if(ENABLE_NO_UNDERSCORE_API) - add_definitions(-DBLIS_ENABLE_NO_UNDERSCORE_API) -endif() - -if(ENABLE_COMPLEX_RETURN_INTEL) - set(BLIS_ENABLE_COMPLEX_RETURN_INTEL TRUE) +if(WIN32) + set(ENABLE_MEMKIND "no" CACHE STRING "libmemkind for manage memory pools") + set_property(CACHE ENABLE_MEMKIND PROPERTY STRINGS "no") + if( NOT (ENABLE_MEMKIND STREQUAL "no")) + message(FATAL_ERROR "ENABLE_MEMKIND option is not supported on Windows platforms.") + endif() else() - set(BLIS_DISABLE_COMPLEX_RETURN_INTEL TRUE) + set(ENABLE_MEMKIND "auto" CACHE STRING "libmemkind for manage memory pools") + set_property(CACHE ENABLE_MEMKIND PROPERTY STRINGS "auto" "yes" "no") + if( NOT ((ENABLE_MEMKIND STREQUAL "auto") OR (ENABLE_MEMKIND STREQUAL "yes") OR (ENABLE_MEMKIND STREQUAL "no")) ) + message(FATAL_ERROR "ENABLE_MEMKIND option '${ENABLE_MEMKIND}' is not supported. 
Please use one of the following options \ + during CMake invokation: auto, yes, no") + endif() +endif() +option(ENABLE_TRSM_PREINVERSION "Enable TRSM preinversion" ON) +option(ENABLE_AOCL_DYNAMIC "Dynamic selection of number of threads" ON) +set(FORCE_VERSION "no" CACHE STRING "Force configure to use an arbitrary version string") +if(WIN32) + set(COMPLEX_RETURN "gnu" CACHE STRING "The method used for returning complex numbers") + set_property(CACHE COMPLEX_RETURN PROPERTY STRINGS "gnu" "intel") + if( NOT ((COMPLEX_RETURN STREQUAL "gnu") OR (COMPLEX_RETURN STREQUAL "intel")) ) + message(FATAL_ERROR "COMPLEX_RETURN option '${COMPLEX_RETURN}' is not supported. Please use one of the following options \ + during CMake invokation: gnu, intel") + endif() +else() + set(COMPLEX_RETURN "default" CACHE STRING "The method used for returning complex numbers") + set_property(CACHE COMPLEX_RETURN PROPERTY STRINGS "default" "gnu" "intel") + if( NOT ((COMPLEX_RETURN STREQUAL "default") OR (COMPLEX_RETURN STREQUAL "gnu") OR (COMPLEX_RETURN STREQUAL "intel")) ) + message(FATAL_ERROR "COMPLEX_RETURN option '${COMPLEX_RETURN}' is not supported. Please use one of the following options \ + during CMake invokation: default, gnu, intel") + endif() +endif() +# If the CONFIG_LIST does not already contain the CONFIG_NAME (i.e., +# if CONFIG_NAME is an umbrella family), default is to enable BLIS_ARCH_TYPE functionality, +# otherwise default is to disable BLIS_ARCH_TYPE functionality. +list(FIND CONFIG_LIST ${BLIS_CONFIG_FAMILY} IS_UMBRELLA) +if(${IS_UMBRELLA} STREQUAL "-1") + option(DISABLE_BLIS_ARCH_TYPE "Disable AOCL_ENABLE_INSTRUCTIONS, BLIS_ARCH_TYPE and BLIS_MODEL_TYPE functionality" OFF) +else() + option(DISABLE_BLIS_ARCH_TYPE "Disable AOCL_ENABLE_INSTRUCTIONS, BLIS_ARCH_TYPE and BLIS_MODEL_TYPE functionality" ON) +endif() +set(RENAME_BLIS_ARCH_TYPE "BLIS_ARCH_TYPE" CACHE STRING "BLIS_ARCH_TYPE env var renamed to supplied value") +set(RENAME_BLIS_MODEL_TYPE "BLIS_MODEL_TYPE" CACHE STRING "BLIS_MODEL_TYPE env var renamed to supplied value") +if(NOT WIN32) + set(ENABLE_ADDON "" CACHE STRING "Configure with specific addons using a ';'-separated list") endif() +set(ENABLE_SANDBOX "" CACHE STRING "Enable a separate sandbox implementation of gemm.") +# Do not let ENABLE_SANDBOX appear on cmake-gui since the functionality is not yet implemented. +mark_as_advanced(ENABLE_SANDBOX) -if(ENABLE_UPPERCASE_API) - add_definitions(-DBLIS_ENABLE_UPPERCASE_API) +#------------------------------------ +# Check memkind +#------------------------------------ +# Using libmemkind is not a valid option on Windows. Check only on Linux platforms. +if(NOT WIN32) + # In order to determine the default behavior of the --with[out]-memkind + # option, we try to detect whether libmemkind is available. If it is, + # the default implied option will be --with-memkind; otherwise, will be + # --without-memkind. + try_compile(HAS_MEMKIND "${CMAKE_BINARY_DIR}/temp" SOURCES "${CMAKE_SOURCE_DIR}/build/detect/memkind/libmemkind_detect.c" + LINK_OPTIONS + "-lmemkind" + ) endif() +#------------------------------------ +# Check #pragma omp simd +#------------------------------------ +if(ENABLE_THREADING STREQUAL "openmp") + # Try to determine whether the chosen compiler supports #pragma omp simd. 
+ try_compile(PRAGMA_OMP_SIMD "${CMAKE_BINARY_DIR}/temp" SOURCES "${CMAKE_SOURCE_DIR}/build/detect/omp_simd/omp_simd_detect.c" + CMAKE_FLAGS + "-O3 -march=native -fopenmp-simd" + C_STANDARD 99 + ) +endif() +#------------------------------------ +# Acquire the BLIS version +#------------------------------------ +# Set the VERSION variable to the default value in the 'version' file. +file(STRINGS ${CMAKE_SOURCE_DIR}/version VERSION) +# Get timestamp. +string(TIMESTAMP BUILD_DATE "%Y%m%d") +# Update using the timestamp. +set(VERSION_STRING "AOCL-BLIS ${VERSION} Build ${BUILD_DATE}") +# Initial message. +message(STATUS "Starting configuration of BLIS ${VERSION_STRING}.") +# Check if the user requested a custom version string. +if(FORCE_VERSION STREQUAL "no") + message(" Configuring with official version string.") +else() + set(VERSION_STRING "${FORCE_VERSION}") + message(" Configuring with custom version string: ${VERSION_STRING}") +endif() +# Set the shared library (.so) version file. +file(STRINGS ${CMAKE_SOURCE_DIR}/so_version SO_VERSION) +# The first line of the 'so_version' file contains the .so major version. +list(GET SO_VERSION 0 SO_VERSION_MAJOR) +# The second line contains the minor and build .so version numbers +# (separated by a '.'). +list(GET SO_VERSION 1 SO_VERSION_MINOR) + +#------------------------------------ +# Printing Options +#------------------------------------ +include(CMakePrintHelpers) +message(STATUS "Printing CMake Configuration Options...") +cmake_print_variables(ENABLE_DEBUG) +# Initialize debug type, using the corresponding cache variable. +set(DEBUG_TYPE ${ENABLE_DEBUG}) +if(ENABLE_DEBUG STREQUAL "off") + message(" Debug symbols disabled.") +elseif(ENABLE_DEBUG STREQUAL "opt") + message(" Enabling debug symbols with optimizations.") +else() #ENABLE_DEBUG=noopt + message(" Enabling debug symbols; optimizations disabled.") +endif() +cmake_print_variables(BUILD_SHARED_LIBS) +if(BUILD_SHARED_LIBS) + message(" Building BLIS as a shared library.") + set(ENABLE_SHARED_01 1) +else() + message(" Building BLIS as a static library.") + set(ENABLE_SHARED_01 0) +endif() +if(NOT WIN32) +cmake_print_variables(EXPORT_SHARED) + if(EXPORT_SHARED STREQUAL "all") + if(BUILD_SHARED_LIBS) + message(" Exporting all symbols within shared library.") + else() + message(" Ignoring request to export all symbols within shared library.") + endif() + else() + if(BUILD_SHARED_LIBS) + message(" Exporting only public symbols within shared library.") + endif() + endif() +endif() +cmake_print_variables(ENABLE_SYSTEM) +if(ENABLE_SYSTEM) + message(" Enabling operating system support.") + set(ENABLE_SYSTEM_01 1) + if(NOT WIN32) + set(LIBPTHREAD "-lpthread") + endif() +else() + message(" Disabling operating system support.") + message(" WARNING: all threading will be disabled!") + set(ENABLE_THREADING "off") + set(ENABLE_SYSTEM_01 0) +endif() +# Check the threading model flag and standardize its value, if needed. 
+cmake_print_variables(ENABLE_THREADING) +set(ENABLE_OPENMP "no") +set(ENABLE_OPENMP_01 0) +set(ENABLE_PTHREADS "no") +set(ENABLE_PTHREADS_01 0) +if(ENABLE_THREADING STREQUAL "openmp") + message(" Using OpenMP for threading.") + set(ENABLE_OPENMP "yes") + set(ENABLE_OPENMP_01 1) + find_package(OpenMP) + if(NOT OPENMP_FOUND) + message(FATAL_ERROR "Openmp Not Found") + endif() +elseif(ENABLE_THREADING STREQUAL "pthreads") + message(" Using POSIX threads for threading.") + set(ENABLE_PTHREADS "yes") + set(ENABLE_PTHREADS_01 1) +else() + message(" Threading is disabled.") +endif() +# Check the method of assigning micropanels to threads in the JR and IR +# loops. +cmake_print_variables(THREAD_PART_JRIR) +if(THREAD_PART_JRIR STREQUAL "slab") + message(" Requesting slab threading in jr and ir loops.") + set(ENABLE_JRIR_SLAB_01 1) + set(ENABLE_JRIR_RR_01 0) +else() + message(" Requesting round-robin threading in jr and ir loops.") + set(ENABLE_JRIR_SLAB_01 0) + set(ENABLE_JRIR_RR_01 1) +endif() +# Convert 'yes' and 'no' flags to booleans. +cmake_print_variables(ENABLE_PBA_POOLS) +if(ENABLE_PBA_POOLS) + message(" Internal memory pools for packing blocks are enabled.") + set(ENABLE_PBA_POOLS_01 1) +else() + message(" Internal memory pools for packing blocks are disabled.") + set(ENABLE_PBA_POOLS_01 0) +endif() +cmake_print_variables(ENABLE_SBA_POOLS) +if(ENABLE_SBA_POOLS) + message(" Internal memory pools for small blocks are enabled.") + set(ENABLE_SBA_POOLS_01 1) +else() + message(" Internal memory pools for small blocks are disabled.") + set(ENABLE_SBA_POOLS_01 0) +endif() +cmake_print_variables(ENABLE_MEM_TRACING) +if(ENABLE_MEM_TRACING) + message(" Memory tracing output is enabled.") + set(ENABLE_MEM_TRACING_01 1) +else() + message(" Memory tracing output is disabled.") + set(ENABLE_MEM_TRACING_01 0) +endif() +cmake_print_variables(ENABLE_MEMKIND) +if(HAS_MEMKIND) + if(ENABLE_MEMKIND STREQUAL "auto") + # If no explicit option was given for libmemkind one way or the other, + # we use the value returned previously by has_libmemkind(), in this + # case "yes", to determine the default. 
+ message(" libmemkind found; default is to enable use.") + set(ENABLE_MEMKIND "yes") + set(ENABLE_MEMKIND_01 1) + else() + if(ENABLE_MEMKIND STREQUAL "yes") + message(" Received explicit request to enable libmemkind.") + set(ENABLE_MEMKIND_01 1) + else() + message(" Received explicit request to disable libmemkind.") + set(ENABLE_MEMKIND "no") + set(ENABLE_MEMKIND_01 0) + endif() + endif() +else() + if(WIN32) + message(" libmemkind option is not supported on Windows.") + else() + message(" libmemkind not found; disabling.") + if(ENABLE_MEMKIND STREQUAL "yes") + message(WARNING " Cannot honor explicit request to enable libmemkind.") + endif() + endif() + set(ENABLE_MEMKIND "no") + set(ENABLE_MEMKIND_01 0) +endif() +cmake_print_variables(PRAGMA_OMP_SIMD) +if(PRAGMA_OMP_SIMD) + message(" Compiler appears to support #pragma omp simd.") + set(ENABLE_PRAGMA_OMP_SIMD_01 1) +else() + message(" Compiler appears to not support #pragma omp simd.") + set(ENABLE_PRAGMA_OMP_SIMD_01 0) +endif() +cmake_print_variables(ENABLE_CBLAS) +if(ENABLE_CBLAS) + message(" The CBLAS compatibility layer is enabled.") + set(ENABLE_CBLAS_01 1) + # Force BLAS layer when CBLAS is enabled + set(ENABLE_BLAS ON) +else() + message(" The CBLAS compatibility layer is disabled.") + set(ENABLE_CBLAS_01 0) +endif() +cmake_print_variables(ENABLE_BLAS) +if(ENABLE_BLAS) + message(" The BLAS compatibility layer is enabled.") + set(ENABLE_BLAS_01 1) +else() + message(" The BLAS compatibility layer is disabled.") + set(ENABLE_BLAS_01 0) +endif() +cmake_print_variables(ENABLE_MIXED_DT) +if(ENABLE_MIXED_DT) + message(" Mixed datatype support is enabled.") + cmake_print_variables(ENABLE_MIXED_DT_EXTRA_MEM) + if(ENABLE_MIXED_DT_EXTRA_MEM) + message(" Mixed datatype optimizations requiring extra memory are enabled.") + set(ENABLE_MIXED_DT_EXTRA_MEM_01 1) + else() + message(" Mixed datatype optimizations requiring extra memory are disabled.") + set(ENABLE_MIXED_DT_EXTRA_MEM_01 0) + endif() + set(ENABLE_MIXED_DT_01 1) +else() + message(" Mixed datatype support is disabled.") + set(ENABLE_MIXED_DT_EXTRA_MEM_01 0) + set(ENABLE_MIXED_DT_01 0) +endif() +cmake_print_variables(ENABLE_SUP_HANDLING) +if(ENABLE_SUP_HANDLING) + message(" Small matrix handling is enabled.") + set(ENABLE_SUP_HANDLING_01 1) +else() + message(" Small matrix handling is disabled.") + set(ENABLE_SUP_HANDLING_01 0) +endif() +cmake_print_variables(ENABLE_TRSM_PREINVERSION) +if(ENABLE_TRSM_PREINVERSION) + message(" trsm diagonal element pre-inversion is enabled.") + set(ENABLE_TRSM_PREINVERSION_01 1) +else() + message(" trsm diagonal element pre-inversion is disabled.") + set(ENABLE_TRSM_PREINVERSION_01 0) +endif() +# Check aocl dynamic threading configuration and enable it only if +# multi-threading is enabled +cmake_print_variables(ENABLE_AOCL_DYNAMIC) if(ENABLE_AOCL_DYNAMIC) - set(AOCL_DYNAMIC TRUE) + if( NOT(ENABLE_THREADING STREQUAL "no")) + message(" Dynamic selection of number of threads is enabled.") + set(ENABLE_AOCL_DYNAMIC_01 1) + else() + message(" Dynamic threading is disabled as multithreading is disabled.") + set(ENABLE_AOCL_DYNAMIC OFF) + set(ENABLE_AOCL_DYNAMIC_01 0) + endif() +else() + message(" Dynamic selection of number of threads is disabled.") + set(ENABLE_AOCL_DYNAMIC_01 0) +endif() +# Report integer sizes. 
+cmake_print_variables(INT_SIZE) +set(INT_TYPE_SIZE ${INT_SIZE}) +if(INT_TYPE_SIZE STREQUAL "32") + message(" The BLIS API integer size is 32-bit.") +elseif(INT_TYPE_SIZE STREQUAL "64") + message(" The BLIS API integer size is 64-bit.") +else() + set(INT_TYPE_SIZE "0") + message(" The BLIS API integer size is automatically determined.") +endif() +cmake_print_variables(BLAS_INT_SIZE) +set(BLAS_INT_TYPE_SIZE ${BLAS_INT_SIZE}) +if(BLAS_INT_TYPE_SIZE STREQUAL "32") + message(" The BLAS/CBLAS API integer size is 32-bit.") +elseif(BLAS_INT_TYPE_SIZE STREQUAL "64") + message(" The BLAS/CBLAS API integer size is 64-bit.") +else() + set(BLAS_INT_TYPE_SIZE "0") + message(" The BLAS/CBLAS API integer size is automatically determined.") +endif() +# Disallow the simultaneous use of 64-bit integers in the BLAS and +# 32-bit integers in BLIS. +if((INT_TYPE_SIZE STREQUAL "32") AND (BLAS_INT_TYPE_SIZE STREQUAL "64")) + message(FATAL_ERROR "INT_TYPE_SIZE=${INT_TYPE_SIZE} and BLAS_INT_TYPE_SIZE=${BLAS_INT_TYPE_SIZE}. \ + To avoid the possibility of truncation, we do not allow use of 64-bit integers in the BLAS API with 32-bit integers in BLIS. \ + Please use a different configuration of integers.") +endif() +if(NOT WIN32) + cmake_print_variables(ENABLE_ADDON) + if(ENABLE_ADDON STREQUAL "") + message(" Configuring with no addons.") + set(ENABLE_ADDONS_01 0) + else() + # Remove duplicates in the addon list, if they exist. + list(REMOVE_DUPLICATES ENABLE_ADDON) + message(" Configuring with addons:") + foreach(ADDON ${ENABLE_ADDON}) + message(" ${ADDON}") + if(NOT (EXISTS ${CMAKE_SOURCE_DIR}/addon/${ADDON})) + message(FATAL_ERROR "Requested addon sub-directory does not exist! Cannot continue. \ + *** Please verify addon existence and name.") + endif() + endforeach() + set(ENABLE_ADDONS_01 1) + endif() +endif() +cmake_print_variables(ENABLE_SANDBOX) +if(ENABLE_SANDBOX STREQUAL "") + message(" Configuring for conventional gemm implementation.") + set(ENABLE_SANDBOX_01 0) +else() + message(" Configuring with alternate gemm implementation: ${ENABLE_SANDBOX}.") + message(FATAL_ERROR "Sandbox functionality is not yet integrated in CMake build system.") + set(ENABLE_SANDBOX_01 1) +endif() +# Check the method used for returning complex numbers. Only for Linux. 
+if(NOT WIN32) + if(COMPLEX_RETURN STREQUAL "default") + if("${CMAKE_Fortran_COMPILER_ID}" MATCHES "Intel") + set(COMPLEX_RETURN "intel") + else() + set(COMPLEX_RETURN "gnu") + endif() + endif() endif() +cmake_print_variables(COMPLEX_RETURN) +if(COMPLEX_RETURN STREQUAL "gnu") + message(" Configuring with gnu complex return type.") + set(COMPLEX_RETURN_INTEL_01 0) +else() + message(" Configuring with intel complex return type.") + set(COMPLEX_RETURN_INTEL_01 1) +endif() +cmake_print_variables(DISABLE_BLIS_ARCH_TYPE) +if(DISABLE_BLIS_ARCH_TYPE) + message(" User selection of code path using AOCL_ENABLE_INSTRUCTIONS, BLIS_ARCH_TYPE and") + message(" BLIS_MODEL_TYPE env vars is disabled.") + set(DISABLE_BLIS_ARCH_TYPE_01 1) +else() + set(DISABLE_BLIS_ARCH_TYPE_01 0) +endif() +cmake_print_variables(RENAME_BLIS_ARCH_TYPE) +if(NOT(RENAME_BLIS_ARCH_TYPE STREQUAL "BLIS_ARCH_TYPE")) + message(" configuring with BLIS_ARCH_TYPE env var renamed to ${RENAME_BLIS_ARCH_TYPE}") +endif() +cmake_print_variables(RENAME_BLIS_MODEL_TYPE) +if(NOT(RENAME_BLIS_MODEL_TYPE STREQUAL "BLIS_MODEL_TYPE")) + message(" configuring with BLIS_MODEL_TYPE env var renamed to ${RENAME_BLIS_MODEL_TYPE}") +endif() +if(WIN32) + cmake_print_variables(ENABLE_NO_UNDERSCORE_API) + if(ENABLE_NO_UNDERSCORE_API) + message(" Export APIs without underscore.") + else() + message(" Export APIs with underscore.") + endif() + cmake_print_variables(ENABLE_UPPERCASE_API) + if(ENABLE_UPPERCASE_API) + message(" Export APIs with uppercase.") + else() + message(" Export APIs with lowercase.") + endif() +endif() + +# Initialize threading model, using the corresponding cache variable. +set(THREADING_MODEL ${ENABLE_THREADING}) -if (BUILD_SHARED_LIBS) - set(BLIS_ENABLE_SHARED TRUE) - if(ENABLE_BLASTEST) - add_definitions(-DAOCL_SUPPORT_BLASTEST_FOR_SHARED) + +#-------------------------------------------- +# Instantiate bli_config.h file from template +#-------------------------------------------- +# Begin substituting information into the build/cmake/bli_config.h.in file, outputting +# to bli_config.h and store it in build directory of the current project. +configure_file(build/cmake/bli_config.h.in ${PROJECT_BINARY_DIR}/bli_config.h) + +#-------------------------------------------- +# Instantiate bli_addon.h file from template +#-------------------------------------------- +# Create a list of #includes, one for each addon in addon_list. +set(ADDON_LIST_INCLUDES "") +foreach(ADDON ${ENABLE_ADDON}) + if(ADDON STREQUAL "aocl_gemm") + if(("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") AND (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 11.0.0)) + message(FATAL_ERROR "aocl_gemm addon requires a gcc version 11.0.0 or higher.") + endif() endif() -endif () + set(ADDON_HEADER "\"${ADDON}.h\"") + set(ADDON_LIST_INCLUDES "${ADDON_LIST_INCLUDES}#include ${ADDON_HEADER}\n") +endforeach() +# Begin substituting information into the bli_addon.h.in file, outputting +# to bli_addon.h and store it in build directory of the current project. +configure_file(build/cmake/bli_addon.h.in ${PROJECT_BINARY_DIR}/bli_addon.h) -# Enable LP64/ILP64 -if (BLIS_ENABLE_ILP64) - set(BLIS_BLAS_INT_TYPE_SIZE TRUE) - set (BLAS_INT_TYPE_SIZE "64") - add_definitions(-DF2C_ENABLE_ILP64) -else () - set(BLIS_BLAS_INT_TYPE_SIZE TRUE) - set (BLAS_INT_TYPE_SIZE "32") -endif () +#-------------------------------------------- +# Collect directory paths for blis.h +#-------------------------------------------- +# Variable ALL_HEADER_PATHS_LIST is equivalent to ALL_H99_DIRPATHS in Make system. 
+# Practically, we collect the required directory paths into a list, which we +# append as we add the corresponding subdirectories. This variable will be +# transformed into a string and will be used to generate the flatten blis.h header. +set(ALL_HEADER_PATHS_LIST "") +# Track files to set dependencies for blis.h. +set(ALL_HEADER_FILES_LIST "") -if (ENABLE_TRSM_PREINVERSION) - set(BLIS_ENABLE_TRSM_PREINVERSION TRUE) -else() - add_definitions(-DBLIS_DISABLE_TRSM_PREINVERSION) -endif() +# Include functionality that returns header paths. +include(${CMAKE_SOURCE_DIR}/build/cmake/subdir_helper_functions.cmake) -if (ENABLE_INT_TYPE_SIZE) - set(BLIS_INT_TYPE_SIZE TRUE) - set (INT_TYPE_SIZE "64") -else () - set(BLIS_INT_TYPE_SIZE TRUE) - set (INT_TYPE_SIZE "32") -endif () +# If the CONFIG_LIST does not already contain the CONFIG_NAME (i.e., +# if CONFIG_NAME is an umbrella family), add in the corresponding +# directory. (In the next step, we will loop over the actual sub- +# configurations and add them as well.) +list(FIND CONFIG_LIST ${BLIS_CONFIG_FAMILY} IS_UMBRELLA) +if(${IS_UMBRELLA} STREQUAL "-1") + # Collect all subdirectory paths that have at least one file with suffix in ALL_H99_SUFS list. + get_dirpaths_with_suffixes(${BLIS_CONFIG_FAMILY}_HEADER_PATHS ${CMAKE_SOURCE_DIR}/config/${BLIS_CONFIG_FAMILY} "${ALL_H99_SUFS}") + # Collect all files in the subdirectories. + get_filepaths_with_suffixes(${BLIS_CONFIG_FAMILY}_HEADER_FILES ${CMAKE_SOURCE_DIR}/config/${BLIS_CONFIG_FAMILY} "${ALL_H99_SUFS}") +endif() +list(APPEND ALL_HEADER_PATHS_LIST "${${BLIS_CONFIG_FAMILY}_HEADER_PATHS}") +list(APPEND ALL_HEADER_FILES_LIST "${${BLIS_CONFIG_FAMILY}_HEADER_FILES}") -if (BLIS_ENABLE_ILP64 AND NOT ENABLE_INT_TYPE_SIZE) - message(FATAL_ERROR "for ILP64 we must enable ENABLE_INT_TYPE_SIZE with BLIS_INT_TYPE_SIZE = 64 ") -endif () +# Get header directory paths for each of the sub-configurations present +# in the configuration list. +foreach(CONF ${CONFIG_LIST}) + get_dirpaths_with_suffixes(config_${CONF}_HEADER_PATHS ${CMAKE_SOURCE_DIR}/config/${CONF} "${ALL_H99_SUFS}") + list(APPEND ALL_HEADER_PATHS_LIST "${config_${CONF}_HEADER_PATHS}") + get_filepaths_with_suffixes(config_${CONF}_FILES_PATHS ${CMAKE_SOURCE_DIR}/config/${CONF} "${ALL_H99_SUFS}") + list(APPEND ALL_HEADER_FILES_LIST "${config_${CONF}_HEADER_FILES}") +endforeach() -if (ENABLE_VERBOSE) - set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON" FORCE) -endif () +# Get header directory paths for each of the kernels present +# in the kernel list. +foreach(KERN ${KERNEL_LIST}) + # Collect all subdirectory paths that have at least one file with suffix in ALL_H99_SUFS list. + get_dirpaths_with_suffixes(kernels_${KERN}_HEADER_PATHS ${CMAKE_SOURCE_DIR}/kernels/${KERN} "${ALL_H99_SUFS}") + list(APPEND ALL_HEADER_PATHS_LIST "${kernels_${KERN}_HEADER_PATHS}") + get_filepaths_with_suffixes(kernels_${KERN}_HEADER_FILES ${CMAKE_SOURCE_DIR}/kernels/${KERN} "${ALL_H99_SUFS}") + list(APPEND ALL_HEADER_PATHS_FILES "${kernels_${KERN}_HEADER_FILES}") +endforeach() -if (ENABLE_JRIR_RR) - message("Round robin thread method enabled") - set(BLIS_ENABLE_JRIR_RR TRUE) - set(BLIS_ENABLE_JRIR_SLAB FALSE) -elseif (ENABLE_JRIR_SLAB) - message("SLAB thread method enabled") - set(BLIS_ENABLE_JRIR_SLAB TRUE) - set(BLIS_ENABLE_JRIR_RR FALSE) -else () - message("Unsupported method of thread partitioning in jr and ir loops") -endif () +# Get header directory paths for framework directory. 
+get_dirpaths_with_suffixes(frame_HEADER_PATHS ${CMAKE_SOURCE_DIR}/frame "${ALL_H99_SUFS}") +list(APPEND ALL_HEADER_PATHS_LIST "${frame_HEADER_PATHS}") +get_filepaths_with_suffixes(frame_HEADER_FILES ${CMAKE_SOURCE_DIR}/frame "${ALL_H99_SUFS}") +list(APPEND ALL_HEADER_FILES_LIST "${frame_HEADER_FILES}") -if (ENABLE_PBA_POOLS) - set(BLIS_ENABLE_PBA_POOLS TRUE) -endif () +# Get header directory paths for AOCL DTL logs directory. +get_dirpaths_with_suffixes(aocl_dtl_HEADER_PATHS ${CMAKE_SOURCE_DIR}/aocl_dtl "${ALL_H99_SUFS}") +list(APPEND ALL_HEADER_PATHS_LIST "${aocl_dtl_HEADER_PATHS}") +get_filepaths_with_suffixes(aocl_dtl_HEADER_FILES ${CMAKE_SOURCE_DIR}/aocl_dtl "${ALL_H99_SUFS}") +list(APPEND ALL_HEADER_FILES_LIST "${aocl_dtl_FILES_PATHS}") -if (ENABLE_SBA_POOLS) - set(BLIS_ENABLE_SBA_POOLS TRUE) -endif () +# Get a copy of the header paths without including the addons and the sandbox. +set(FRAME_HEADER_DIRPATHS_LIST ${ALL_HEADER_PATHS_LIST}) -if (ENABLE_MEM_TRACING) - set(BLIS_ENABLE_MEM_TRACING FALSE) -endif () +# Get header directory paths for each of the addons. +foreach(ADDON ${ENABLE_ADDON}) + get_dirpaths_with_suffixes(addon_${ADDON}_HEADER_PATHS ${CMAKE_SOURCE_DIR}/addon/${ADDON} "${ALL_H99_SUFS}") + list(APPEND ALL_HEADER_PATHS_LIST "${addon_${ADDON}_HEADER_PATHS}") + get_filepaths_with_suffixes(addon_${ADDON}_HEADER_FILES ${CMAKE_SOURCE_DIR}/addon/${ADDON} "${ALL_H99_SUFS}") + list(APPEND ALL_HEADER_FILES_LIST "${addon_${ADDON}_HEADER_FILES}") +endforeach() -if (ENABLE_BLAS) - add_definitions(-DBLIS_ENABLE_BLAS) - set(BLIS_ENABLE_BLAS TRUE) -else () - add_definitions(-DBLIS_DISABLE_BLAS) - set(BLIS_ENABLE_BLAS FALSE) -endif () +# Pick up generated bli_config.h and bli_addon.h that get generated in +# current build directory. +list(PREPEND ALL_HEADER_PATHS_LIST ${PROJECT_BINARY_DIR}/) +list(PREPEND ALL_HEADER_FILES_LIST ${PROJECT_BINARY_DIR}/bli_config.h) +if(NOT (ENABLE_ADDON STREQUAL "")) + list(PREPEND ALL_HEADER_FILES_LIST ${PROJECT_BINARY_DIR}/bli_addon.h) +endif() -if (ENABLE_CBLAS) - add_definitions(-DBLIS_ENABLE_CBLAS) - set(BLIS_ENABLE_CBLAS TRUE) - if (NOT ENABLE_BLAS) - # Force BLAS layer when CBLAS is enabled - add_definitions(-DBLIS_ENABLE_BLAS) - set(BLIS_ENABLE_BLAS TRUE) - endif () -else () - add_definitions(-DBLIS_DISABLE_CBLAS) - set(BLIS_ENABLE_CBLAS FALSE) -endif () +# Create a string out of this list so that it can be processed by flatten-headers.py. +list(JOIN ALL_HEADER_PATHS_LIST " " ALL_HEADER_PATHS_STRING) -if (ENABLE_BLASTEST) - add_definitions(-DBLIS_ENABLE_BLAS) - add_definitions(-DBLIS_ENABLE_CBLAS) +#-------------------------------------------- +# Consolidated blis.h header creation +#-------------------------------------------- +# Creating a directory for the generated flatten headers. +file(MAKE_DIRECTORY ${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}) +# Flatten header python script file which expand header contents in blis.h. 
+add_custom_command(OUTPUT ${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}/blis.h + COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/build/flatten-headers.py -c -v1 + "${CMAKE_SOURCE_DIR}/frame/include/blis.h" + "${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}/blis.h" + "${PROJECT_BINARY_DIR}/include" + "${ALL_HEADER_PATHS_STRING}" + COMMENT "Generating monolithic blis header file: ${CMAKE_SOURCE_DIR}/include/${BLIS_CONFIG_FAMILY}/blis.h" + DEPENDS ${ALL_HEADER_FILES_LIST} + ) +add_custom_target(flat-header DEPENDS ${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}/blis.h) +#-------------------------------------------- +# Consolidated cblas.h header creation +#-------------------------------------------- +# Flatten header python script file which expand header contents in cblas.h. +if(ENABLE_CBLAS) + add_custom_command(OUTPUT ${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}/cblas.h + COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/build/flatten-headers.py -c -v1 + "${CMAKE_SOURCE_DIR}/frame/compat/cblas/src/cblas.h" + "${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}/cblas.h" + "${PROJECT_BINARY_DIR}/${include}" + "${ALL_HEADER_PATHS_STRING}" + COMMENT "Generating monolithic cblas header file: ${CMAKE_SOURCE_DIR}/include/${BLIS_CONFIG_FAMILY}/cblas.h" + DEPENDS ${ALL_HEADER_FILES_LIST} + ) + add_custom_target(flat-cblas-header DEPENDS ${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}/cblas.h) endif() -if (ENABLE_TESTCPP_TESTING) - add_definitions(-DBLIS_ENABLE_BLAS) - add_definitions(-DBLIS_ENABLE_CBLAS) -endif () +#-------------------------------------------- +# Default linker definitions +#-------------------------------------------- +# NOTE: This section needs to reside before the inclusion of make_defs.mk +# files (just below), as most configurations' make_defs.mk don't tinker +# with things like LDFLAGS, but some do (or may), in which case they can +# manually override whatever they need. -if (ENABLE_MIXED_DT) - set(BLIS_ENABLE_MIXED_DT TRUE) -endif () +# Define the external libraries we may potentially need at link-time. +# Add libm only on Linux and only if Intel compiler is not used. +if((NOT WIN32) AND (NOT ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Intel"))) + set(LIBM -lm) +endif() +set(LIBMEMKIND -lmemkind) -if (ENABLE_MIXED_DT_EXTRA_MEM) - set(BLIS_ENABLE_MIXED_DT_EXTRA_MEM TRUE) -endif () +# Default linker flags. +# NOTE: -lpthread is needed unconditionally because BLIS uses pthread_once() +# to initialize itself in a thread-safe manner. The one exception to this +# rule: if --disable-system is given at configure-time, LIBPTHREAD is empty. +if(NOT WIN32) + set(LDFLAGS ${LIBM} ${LIBPTHREAD}) +endif() +# Add libmemkind to the link-time flags, if it was enabled at configure-time. +if(ENABLE_MEMKIND STREQUAL "yes") + list(APPEND LDFLAGS ${LIBMEMKIND}) +endif() -if (ENABLE_SUP_HANDLING) - set(BLIS_ENABLE_SUP_HANDLING TRUE) -endif () +#-------------------------------------------- +# Configuration-agnostic flags +#-------------------------------------------- +# --- Warning flags --- -if (ENABLE_MEMKIND) - set(BLIS_ENABLE_MEMKIND FALSE) -endif () +# Disable unused function warnings and stop compiling on first error for +# all compilers that accept such options: gcc, clang, and icc. +set(CWARNFLAGS -Wno-unused-function -Wfatal-errors) +if(NOT WIN32) + list(PREPEND CWARNFLAGS -Wall) +endif() -if (ENABLE_PRAGMA_OMP_SIMD) - set(BLIS_ENABLE_PRAGMA_OMP_SIMD TRUE) -endif () +# Disable tautological comparision warnings in clang. 
+if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + list(APPEND CWARNFLAGS -Wno-tautological-compare -Wno-pass-failed) +endif() -if (ENABLE_SANDBOX) - set(BLIS_ENABLE_SANDBOX FALSE) -endif () +# Add extra warning flags for Windows builds. +if(WIN32) + list(APPEND CWARNFLAGS -Wno-unused-variable -Wno-deprecated-declarations) +endif() -include_directories(${PROJECT_SOURCE_DIR}/external/msvc) -add_definitions(-D_CRT_SECURE_NO_WARNINGS) -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /MD ") -#add_definitions(-DBLIS_IS_BUILDING_LIBRARY) -if(NOT BUILD_SHARED_LIBS) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /MT ") -add_definitions(-DBLIS_IS_BUILDING_LIBRARY) -endif() - -if(ENABLE_MULTITHREADING) - if(BUILD_SHARED_LIBS) - set(LIB_NAME "${PROJECT_NAME}-MT-dll") - elseif(NOT BUILD_SHARED_LIBS) - set(LIB_NAME "${PROJECT_NAME}-MT") - endif() - if(ENABLE_OPENMP) - find_package(OpenMP) - if (OPENMP_FOUND) - set(BLIS_ENABLE_OPENMP TRUE) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") - set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") +#Setting up the correct Windows Runtime Library. +if(WIN32) + cmake_policy(SET CMP0091 NEW) + if(BUILD_SHARED_LIBS) + set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>DLL") else() - message (FATAL_ERROR "Openmp Not Found") + set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>") endif() - endif() -else() - if(BUILD_SHARED_LIBS) - set(LIB_NAME "${PROJECT_NAME}-dll") - elseif(NOT BUILD_SHARED_LIBS) - set(LIB_NAME "${PROJECT_NAME}") - endif() endif() -if(DISABLE_BLIS_ARCH_TYPE) - set(BLIS_DISABLE_BLIS_ARCH_TYPE TRUE) - set(BLIS_DISABLE_BLIS_MODEL_TYPE TRUE) -else() - set(BLIS_DISABLE_BLIS_ARCH_TYPE FALSE) - set(BLIS_DISABLE_BLIS_MODEL_TYPE FALSE) +# --- Symbol exporting flags (shared libraries only) -- + +# NOTE: These flags are only applied when building BLIS and not used by +# applications. + +# Determine default export behavior / visibility of symbols for gcc, icc and clang. +if(NOT WIN32) + if(EXPORT_SHARED STREQUAL "all") + # Export all symbols by default. + set(BUILD_SYMFLAGS -fvisibility=default) + else() # ifeq ($(EXPORT_SHARED),public) + # Hide all symbols by default and export only those that have been annotated + # as needing to be exported. + set(BUILD_SYMFLAGS -fvisibility=hidden) + endif() endif() -if(RENAME_BLIS_ARCH_TYPE) - set(__blis_arch_type_name TRUE) - set(rename_blis_arch_type "${RENAME_BLIS_ARCH_TYPE}") -else() - set(__blis_arch_type_name TRUE) - set(rename_blis_arch_type "BLIS_ARCH_TYPE") +# --- C Preprocessor flags --- +# Enable clock_gettime() in time.h. +set(CPPROCFLAGS -D_POSIX_C_SOURCE=200112L) + +# --- Threading flags --- +# NOTE: We don't have to explicitly omit -pthread when --disable-system is given +# since that option forces --enable-threading=none, and thus -pthread never gets +# added to begin with. 
+if(NOT WIN32) + if(THREADING_MODEL STREQUAL "pthreads") + set(CTHREADFLAGS "-pthread") + endif() endif() -if(RENAME_BLIS_MODEL_TYPE) - set(__blis_model_type_name TRUE) - set(rename_blis_model_type "${RENAME_BLIS_MODEL_TYPE}") -else() - set(__blis_model_type_name TRUE) - set(rename_blis_model_type "BLIS_MODEL_TYPE") +# --- #pragma omp simd flags (used for reference kernels only) --- +if(PRAGMA_OMP_SIMD) + if(WIN32) + set(COMPSIMDFLAGS /openmp:experimental) + else() + set(COMPSIMDFLAGS -fopenmp-simd) + endif() endif() -find_package(Doxygen) -set(W_DIR "${CMAKE_CURRENT_SOURCE_DIR}/docs") -if(NOT (DOXYGEN_FOUND)) - message(STATUS "Doxygen not found please install and try again.") -else() - execute_process(COMMAND doxygen Doxyfile - WORKING_DIRECTORY ${W_DIR} - COMMAND_ECHO STDOUT) +#-------------------------------------------- +# Compiler include path definitions +#-------------------------------------------- +# Obtain a list of header files #included inside of the bli_cntx_ref.c file. +# Due to the way that bli_cntx_ref.c uses headers and macros, paths to these +# files will be needed when compiling bli_cntx_ref.c with the monolithic header. + +# Read content of bli_cntx_ref.c and put it in REF_KER_HEADERS_TEMP. +file(STRINGS ${CMAKE_SOURCE_DIR}/ref_kernels/bli_cntx_ref.c REF_KER_HEADERS_TEMP) +# Only keep the lines where there are includes. +list(FILTER REF_KER_HEADERS_TEMP INCLUDE REGEX "\#include") +# REF_KER_HEADERS has a list of all files that are included in bli_cntx_ref.c. +set(REF_KER_HEADERS "") +foreach(header ${REF_KER_HEADERS_TEMP}) + string(REGEX MATCH "\#include [\"<]\([a-zA-Z0-9\_\.\/\-]*\)[\">].*" helper ${header}) + list(APPEND REF_KER_HEADERS ${CMAKE_MATCH_1}) +endforeach() +# Remove blis.h from the list. +list(FILTER REF_KER_HEADERS EXCLUDE REGEX "blis.h") +set(REF_KER_H_PATHS "") +foreach(header_name ${REF_KER_HEADERS}) + foreach(header_dir ${FRAME_HEADER_DIRPATHS_LIST}) + if(EXISTS ${header_dir}/${header_name}) + list(APPEND REF_KER_H_PATHS ${header_dir}) + break() + endif() + endforeach() +endforeach() +# Remove duplicates, if they exist. +list(REMOVE_DUPLICATES REF_KER_H_PATHS) + +# Create list of include directories, to be used while creating the library. +# NOTE: We no longer need every header path in the source tree since we +# now #include the monolithic/flattened blis.h instead. +set(CINFLAGS ${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}) +list(APPEND CINFLAGS ${REF_KER_H_PATHS}) +# Then add frame/include since it's needed for bli_oapi_w[o]_cntx.h. +list(APPEND CINFLAGS ${CMAKE_SOURCE_DIR}/frame/include) +# If CBLAS is enabled, we also include the path to the cblas.h directory so +# that the compiler will be able to find cblas.h as the CBLAS source code is +# being compiled. 
+if(ENABLE_CBLAS) + set(CBLAS_H_DIRPATH "") + foreach(header_dir ${FRAME_HEADER_DIRPATHS_LIST}) + if(EXISTS ${header_dir}/cblas.h) + list(APPEND CBLAS_H_DIRPATH ${header_dir}) + break() + endif() + endforeach() + list(APPEND CINFLAGS ${CBLAS_H_DIRPATH}) endif() -if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/docs/html/index.html) - message(STATUS "Documentation generated successfully, to view documentation open docs/html/index.html .") -else() - message(STATUS "Document generation failed.") -endif() - -set(CMAKE_BUILD_TYPE ${CMAKE_CONFIGURATION_TYPES}) - -#print configurations -message("---cmake configurations---") -message(CMAKE_C_COMPILER_ID : ${CMAKE_C_COMPILER_ID}) -message(CMAKE_BUILD_TYPE : ${CMAKE_BUILD_TYPE}) -message(BLIS_ENABLE_OPENMP : ${BLIS_ENABLE_OPENMP}) -message(BLIS_ENABLE_JRIR_SLAB : ${BLIS_ENABLE_JRIR_SLAB}) -message(BLIS_ENABLE_JRIR_RR : ${BLIS_ENABLE_JRIR_RR}) -message(BLIS_ENABLE_PBA_POOLS : ${BLIS_ENABLE_PBA_POOLS}) -message(BLIS_ENABLE_SBA_POOLS : ${BLIS_ENABLE_SBA_POOLS}) -message(BLIS_ENABLE_MEM_TRACING : ${BLIS_ENABLE_MEM_TRACING}) -message(BLIS_INT_TYPE_SIZE : ${BLIS_INT_TYPE_SIZE}) -message(BLIS_BLAS_INT_TYPE_SIZE : ${BLIS_BLAS_INT_TYPE_SIZE}) -message(BLIS_ENABLE_BLAS : ${BLIS_ENABLE_BLAS}) -message(BLIS_ENABLE_CBLAS : ${BLIS_ENABLE_CBLAS}) -message(BLIS_ENABLE_MIXED_DT : ${BLIS_ENABLE_MIXED_DT}) -message(BLIS_ENABLE_MIXED_DT_EXTRA_MEM : ${BLIS_ENABLE_MIXED_DT_EXTRA_MEM}) -message(BLIS_ENABLE_SUP_HANDLING : ${BLIS_ENABLE_SUP_HANDLING}) -message(BLIS_ENABLE_MEMKIND : ${BLIS_ENABLE_MEMKIND}) -message(BLIS_ENABLE_PRAGMA_OMP_SIMD : ${BLIS_ENABLE_PRAGMA_OMP_SIMD}) -message(BLIS_ENABLE_SANDBOX : ${BLIS_ENABLE_SANDBOX}) -message(BLIS_ENABLE_SHARED : ${BLIS_ENABLE_SHARED}) -message(DISABLE_BLIS_ARCH_TYPE : ${DISABLE_BLIS_ARCH_TYPE}) -message(RENAME_BLIS_ARCH_TYPE : ${RENAME_BLIS_ARCH_TYPE}) -message(RENAME_BLIS_MODEL_TYPE : ${RENAME_BLIS_MODEL_TYPE}) - -SET(ENABLE_SIMD_FLAGS "none" CACHE STRING "Set compiler SIMD flags") -SET_PROPERTY(CACHE ENABLE_SIMD_FLAGS PROPERTY STRINGS none SSE2 AVX AVX2) - -if(${ENABLE_SIMD_FLAGS} MATCHES "AVX2") - add_definitions(/arch:AVX2) -elseif(${ENABLE_SIMD_FLAGS} MATCHES "AVX") - add_definitions(/arch:AVX) -elseif(${ENABLE_SIMD_FLAGS} MATCHES "SSE2") - add_definitions(/arch:SSE2) -endif() - -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W0 ") -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Oi") -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /MP") -set(INTR_GENERAL_LINK_FLAGS "${INTR_GENERAL_LINK_FLAGS} /RELEGE") - -add_definitions(-D_CRT_SECURE_NO_DEPRECATE) - -#add_definitions(-DBLIS_OS_WINDOWS) -add_definitions(-D_MSC_VER) -if (${AOCL_BLIS_FAMILY} STREQUAL "amdzen") -else() -add_definitions(-DBLIS_CNAME=${TARGET_ARCH}) -endif() -# Generate the bli_config.h header file -configure_file (build/bli_win_config.h.in ${CMAKE_SOURCE_DIR}/bli_config.h @ONLY) - -include_directories(${CMAKE_SOURCE_DIR}/aocl_dtl) -include_directories(${CMAKE_SOURCE_DIR}/.) 
-include_directories(${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}) -include_directories(${CMAKE_SOURCE_DIR}/frame/include) -include_directories(${CMAKE_SOURCE_DIR}/frame/include/level0) -include_directories(${CMAKE_SOURCE_DIR}/frame/include/level0/1e) -include_directories(${CMAKE_SOURCE_DIR}/frame/include/level0/1m) -include_directories(${CMAKE_SOURCE_DIR}/frame/include/level0/1r) -include_directories(${CMAKE_SOURCE_DIR}/frame/include/level0/bb) -include_directories(${CMAKE_SOURCE_DIR}/frame/include/level0/io) -include_directories(${CMAKE_SOURCE_DIR}/frame/include/level0/ri) -include_directories(${CMAKE_SOURCE_DIR}/frame/include/level0/ri3) -include_directories(${CMAKE_SOURCE_DIR}/frame/include/level0/rih) -include_directories(${CMAKE_SOURCE_DIR}/frame/include/level0/ro) -include_directories(${CMAKE_SOURCE_DIR}/frame/include/level0/rpi) -include_directories(${CMAKE_SOURCE_DIR}/frame/thread) -include_directories(${CMAKE_SOURCE_DIR}/frame/base) -include_directories(${CMAKE_SOURCE_DIR}/frame/base/cast) -include_directories(${CMAKE_SOURCE_DIR}/frame/base/check) -include_directories(${CMAKE_SOURCE_DIR}/frame/base/noopt) -include_directories(${CMAKE_SOURCE_DIR}/frame/base/proj) -include_directories(${CMAKE_SOURCE_DIR}/frame/0) -include_directories(${CMAKE_SOURCE_DIR}/frame/0/copysc) -include_directories(${CMAKE_SOURCE_DIR}/frame/1) -include_directories(${CMAKE_SOURCE_DIR}/frame/1d) -include_directories(${CMAKE_SOURCE_DIR}/frame/1f) -include_directories(${CMAKE_SOURCE_DIR}/frame/1m) -include_directories(${CMAKE_SOURCE_DIR}/frame/1m/packm) -include_directories(${CMAKE_SOURCE_DIR}/frame/1m/unpackm) -include_directories(${CMAKE_SOURCE_DIR}/frame/2) -include_directories(${CMAKE_SOURCE_DIR}/frame/2/gemv) -include_directories(${CMAKE_SOURCE_DIR}/frame/2/ger) -include_directories(${CMAKE_SOURCE_DIR}/frame/2/hemv) -include_directories(${CMAKE_SOURCE_DIR}/frame/2/her) -include_directories(${CMAKE_SOURCE_DIR}/frame/2/her2) -include_directories(${CMAKE_SOURCE_DIR}/frame/2/symv) -include_directories(${CMAKE_SOURCE_DIR}/frame/2/syr) -include_directories(${CMAKE_SOURCE_DIR}/frame/2/syr2) -include_directories(${CMAKE_SOURCE_DIR}/frame/2/trmv) -include_directories(${CMAKE_SOURCE_DIR}/frame/2/trsv) -include_directories(${CMAKE_SOURCE_DIR}/frame/3) -include_directories(${CMAKE_SOURCE_DIR}/frame/3/gemm) -include_directories(${CMAKE_SOURCE_DIR}/frame/3/gemm/ind) -include_directories(${CMAKE_SOURCE_DIR}/frame/3/gemmt) -include_directories(${CMAKE_SOURCE_DIR}/frame/3/hemm) -include_directories(${CMAKE_SOURCE_DIR}/frame/3/her2k) -include_directories(${CMAKE_SOURCE_DIR}/frame/3/herk) -include_directories(${CMAKE_SOURCE_DIR}/frame/3/symm) -include_directories(${CMAKE_SOURCE_DIR}/frame/3/syr2k) -include_directories(${CMAKE_SOURCE_DIR}/frame/3/syrk) -include_directories(${CMAKE_SOURCE_DIR}/frame/3/trmm) -include_directories(${CMAKE_SOURCE_DIR}/frame/3/trmm3) -include_directories(${CMAKE_SOURCE_DIR}/frame/3/trsm) -include_directories(${CMAKE_SOURCE_DIR}/frame/compat) -include_directories(${CMAKE_SOURCE_DIR}/frame/compat/cblas) -include_directories(${CMAKE_SOURCE_DIR}/frame/compat/cblas/f77_sub) -include_directories(${CMAKE_SOURCE_DIR}/frame/compat/cblas/src) -include_directories(${CMAKE_SOURCE_DIR}/frame/compat/check) -include_directories(${CMAKE_SOURCE_DIR}/frame/compat/f2c) -include_directories(${CMAKE_SOURCE_DIR}/frame/compat/f2c/util) -include_directories(${CMAKE_SOURCE_DIR}/frame/ind) -include_directories(${CMAKE_SOURCE_DIR}/frame/ind/cntx) -include_directories(${CMAKE_SOURCE_DIR}/frame/ind/oapi) 
-include_directories(${CMAKE_SOURCE_DIR}/frame/ind/tapi) -include_directories(${CMAKE_SOURCE_DIR}/frame/ind/ukernels) -include_directories(${CMAKE_SOURCE_DIR}/frame/util) -include_directories(${CMAKE_SOURCE_DIR}/config/generic) -include_directories(${CMAKE_SOURCE_DIR}/config/zen) -include_directories(${CMAKE_SOURCE_DIR}/config/zen2) -include_directories(${CMAKE_SOURCE_DIR}/config/zen3) -include_directories(${CMAKE_SOURCE_DIR}/config/zen4) -if(${AOCL_BLIS_FAMILY} STREQUAL "amdzen") - include_directories(${CMAKE_BINARY_DIR}/ref_kernels/generic) - include_directories(${CMAKE_BINARY_DIR}/ref_kernels/zen) - include_directories(${CMAKE_BINARY_DIR}/ref_kernels/zen2) - include_directories(${CMAKE_BINARY_DIR}/ref_kernels/zen3) - include_directories(${CMAKE_BINARY_DIR}/ref_kernels/zen4) -endif() -include_directories(${CMAKE_SOURCE_DIR}/ref_kernels) -include_directories(${CMAKE_SOURCE_DIR}/kernels) -include_directories(${CMAKE_SOURCE_DIR}/kernels/haswell) -include_directories(${CMAKE_SOURCE_DIR}/kernels/haswell/3) -include_directories(${CMAKE_SOURCE_DIR}/kernels/haswell/3/sup) -include_directories(${CMAKE_SOURCE_DIR}/kernels/haswell/3/sup/d6x8) -include_directories(${CMAKE_SOURCE_DIR}/kernels/zen) -include_directories(${CMAKE_SOURCE_DIR}/kernels/zen/1) -include_directories(${CMAKE_SOURCE_DIR}/kernels/zen/1f) -include_directories(${CMAKE_SOURCE_DIR}/kernels/zen/1m) -include_directories(${CMAKE_SOURCE_DIR}/kernels/zen/2) -include_directories(${CMAKE_SOURCE_DIR}/kernels/zen/3) -include_directories(${CMAKE_SOURCE_DIR}/kernels/zen/3/sup) -include_directories(${CMAKE_SOURCE_DIR}/kernels/zen2) -include_directories(${CMAKE_SOURCE_DIR}/kernels/zen4) -include_directories(${CMAKE_SOURCE_DIR}/kernels/skx) -include_directories(${CMAKE_SOURCE_DIR}/kernels/skx/3) -file(GLOB headers ${CMAKE_SOURCE_DIR}/*.h) - -# Monolithic Header generation -find_package(PythonLibs 3 REQUIRED) - -string(APPEND HEADER_PATH -if(${AOCL_BLIS_FAMILY} STREQUAL "zen") - " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen/" - " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen/" - " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/haswell/" -elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen2") - " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen2/" - " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen/" - " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/haswell/" - " ${CMAKE_CURRENT_SOURCE_DIR}/config/amdzen/" - " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen/" - " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen2/" - " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen3/" - " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen4/" - " ${CMAKE_CURRENT_SOURCE_DIR}/config/generic/" - " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen/" - " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/haswell/" -endif () - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/0/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/0/copysc/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/1/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/1d/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/1f/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/1m/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/1m/packm/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/1m/unpackm/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/2/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/2/gemv/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/2/ger/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/2/hemv/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/2/her/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/2/her2/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/2/symv/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/2/syr/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/2/syr2/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/2/trmv/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/2/trsv/" - " 
${CMAKE_CURRENT_SOURCE_DIR}/frame/3/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/3/gemm/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/3/gemm/ind/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/3/gemmt/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/3/hemm/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/3/her2k/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/3/herk/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/3/symm/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/3/syr2k/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/3/syrk/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/3/trmm/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/3/trmm3/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/3/trsm/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/base/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/base/cast/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/base/check/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/base/noopt/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/base/proj/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/compat/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/compat/cblas/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/compat/cblas/f77_sub/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/compat/cblas/src/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/compat/check/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/compat/f2c/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/compat/f2c/util/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/include/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/include/level0/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/include/level0/1e/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/include/level0/1m/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/include/level0/1r/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/include/level0/bb/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/include/level0/io/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/include/level0/ri/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/include/level0/ri3/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/include/level0/rih/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/include/level0/ro/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/include/level0/rpi/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/ind/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/ind/cntx/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/ind/oapi/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/ind/tapi/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/ind/ukernels/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/thread/" - " ${CMAKE_CURRENT_SOURCE_DIR}/frame/util/" - " ${CMAKE_CURRENT_SOURCE_DIR}/aocl_dtl/" - " ${CMAKE_CURRENT_SOURCE_DIR}/" -) -file(MAKE_DIRECTORY ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}) +#-------------------------------------------- +# Special preprocessor macro definitions +#-------------------------------------------- +# Define a C preprocessor macro to communicate the current version so that it +# can be embedded into the library and queried later. +set(VERS_DEF -DBLIS_VERSION_STRING="${VERSION_STRING}") -# Flatten header python script file which expand header contents in blis.h -set(FLATTEN_PY "${CMAKE_SOURCE_DIR}/build/flatten-headers.py") -set(BLIS_H "blis.h") +# Define a C preprocessor flag that is *only* defined when BLIS is being +# compiled. (In other words, an application that #includes blis.h will not +# get this cpp macro.) +set(BUILD_CPPFLAGS -DBLIS_IS_BUILDING_LIBRARY) -# Arguements for python script -set(C_COMMENT "-c") -set(VERBOSE "-v1") -set(INPUT "${CMAKE_SOURCE_DIR}/frame/include/${BLIS_H}") -set(OUTPUT "${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/${BLIS_H}") -set(TEMP_DIR "${INCLUDE}") -set(DIR_H_PATH "${HEADER_PATH}") +#-------------------------------------------- +# Add CMakeLists.txt from directories +#-------------------------------------------- +# Add config subdirectory. 
+add_subdirectory(config) +# Add kernel subdirectory. +add_subdirectory(kernels) +# Add framework directory. +add_subdirectory(frame) +# Add AOCL DTL logs directory. +add_subdirectory(aocl_dtl) +# Add subdirectory for each of the addons. +list(LENGTH ENABLE_ADDON addon_list_size) +if(addon_list_size GREATER 0) + add_subdirectory(addon) +endif() -# Run python script to generate monolithic header at configuration time -execute_process( - COMMAND ${PYTHON_EXE} ${FLATTEN_PY} "${C_COMMENT}" "${VERBOSE}" "${INPUT}" "${OUTPUT}" "${TEMP_DIR}" "${DIR_H_PATH}" - RESULT_VARIABLE CMD_RESULT - OUTPUT_VARIABLE CMD_OUTPUT) -message( STATUS "Generating monolithic header file :" ${CMD_OUTPUT}) - -# Logic to generate the cblas.h in include folder. -set(CBLAS_H "cblas.h") -# Arguements for python script -set(C_COMMENT "-c") -set(VERBOSE "-v1") -set(INPUT "${CMAKE_SOURCE_DIR}/frame/compat/cblas/src/${CBLAS_H}") -set(OUTPUT "${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/${CBLAS_H}") -set(TEMP_DIR "${INCLUDE}") -set(DIR_H_PATH "${HEADER_PATH}") - -# Run python script to generate monolithic header at configuration time -execute_process( - COMMAND ${PYTHON_EXE} ${FLATTEN_PY} "${C_COMMENT}" "${VERBOSE}" "${INPUT}" "${OUTPUT}" "${TEMP_DIR}" "${DIR_H_PATH}" - RESULT_VARIABLE CMD_RESULT - OUTPUT_VARIABLE CMD_OUTPUT) -message( STATUS "Generating monolithic cblas header file :" ${CMD_OUTPUT}) - -# setting the blis version string -file (STRINGS "version" BLIS_VERSION) -set(BLIS_VERSION_STRING ${BLIS_VERSION}) -string(TIMESTAMP BUILD_DATE "%Y%m%d") -add_definitions(-DBLIS_VERSION_STRING="AOCL-BLIS ${BLIS_VERSION_STRING} Build ${BUILD_DATE}") - -# Set object libraries created in kernels directory to be added into BLIS library. -set(OBJECT_LIBRARIES - $ - $ - $ - $ - $ - $ - $ - $ - $ +# Collect all object libraries that are required to build the blis library. +set(OBJECT_LIBRARIES "") +# Add objects from config. +foreach(conf ${CONFIG_LIST}) + list(APPEND OBJECT_LIBRARIES $) +endforeach() +# Add objects from kernels. +foreach(ker ${KERNEL_LIST}) + if(TARGET ${ker}_KERNELS) + list(APPEND OBJECT_LIBRARIES $) + endif() +endforeach() +# Add objects for reference kernels. +foreach(conf ${CONFIG_LIST}) + list(APPEND OBJECT_LIBRARIES $) + list(APPEND OBJECT_LIBRARIES $) +endforeach() +# Add objects for frame. +list(APPEND OBJECT_LIBRARIES $) +# Add objects for aocl-dtl. +list(APPEND OBJECT_LIBRARIES $) +# Add objects for addons. +foreach(addon ${ENABLE_ADDON}) + if(TARGET ${addon}_C99_ADDON) + list(APPEND OBJECT_LIBRARIES $) + endif() + if(TARGET ${addon}_C99_KERNEL_ADDON) + list(APPEND OBJECT_LIBRARIES $) + endif() + if(TARGET ${addon}_CXX_ADDON) + list(APPEND OBJECT_LIBRARIES $) + endif() +endforeach() + +#-------------------------------------------- +# Building BLIS Library +#-------------------------------------------- +# Public blis headers. +set(BLIS_PUBLIC_HEADERS + ${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}/blis.h + # Include AMD's C++ template header files in the list of headers + # to install. + ${CMAKE_SOURCE_DIR}/vendor/cpp/blis.hh + ${CMAKE_SOURCE_DIR}/vendor/cpp/cblas.hh ) -# Ammend the list of object libraries to include zen4 paths as appropriate. 
-if(${TARGET_ARCH} STREQUAL zen4 OR - ${TARGET_ARCH} STREQUAL amdzen) - set(OBJECT_LIBRARIES ${OBJECT_LIBRARIES} - $ - $ - $ - $ - $ - $ - ) +if(ENABLE_CBLAS) + list(APPEND BLIS_PUBLIC_HEADERS ${PROJECT_BINARY_DIR}/include/${BLIS_CONFIG_FAMILY}/cblas.h) endif() -if(BUILD_SHARED_LIBS) - add_library("${PROJECT_NAME}" SHARED ${CMAKE_SOURCE_DIR}/bli_config.h - ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h - ${headers} - ${OBJECT_LIBRARIES} - ) - if(ENABLE_OPENMP) - target_link_libraries("${PROJECT_NAME}" PRIVATE OpenMP::OpenMP_CXX) - endif() - target_compile_definitions("${PROJECT_NAME}" PUBLIC -DBLIS_IS_BUILDING_LIBRARY) - set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") -endif() -if(NOT BUILD_SHARED_LIBS) - add_library("${PROJECT_NAME}" STATIC ${CMAKE_SOURCE_DIR}/bli_config.h - ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h - ${headers} - ${OBJECT_LIBRARIES} - ) - if(ENABLE_OPENMP) - set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}" STATIC_LIBRARY_OPTIONS "${OpenMP_libomp_LIBRARY}") +# --- Library name and local paths --- +# From old CMake +if(WIN32) + add_definitions(-D_CRT_SECURE_NO_WARNINGS) + add_definitions(-D_CRT_SECURE_NO_DEPRECATE) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Oi") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /MP${CMake_MSVC_PARALLEL}") + set(INTR_GENERAL_LINK_FLAGS "${INTR_GENERAL_LINK_FLAGS} /RELEGE") + add_definitions(-DEXPMODULE) +endif() + +# Set up the library name. +if(WIN32) + set(LIBBLIS AOCL-LibBlis-Win) +else() + set(LIBBLIS blis) +endif() + +# Append if threading is required. +if(NOT (THREADING_MODEL STREQUAL "no")) + if(WIN32) + string(APPEND LIBBLIS -MT) else() - set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") + string(APPEND LIBBLIS -mt) + endif() +endif() + +if(BUILD_SHARED_LIBS) + if(WIN32) + string(APPEND LIBBLIS -dll) + endif() + # Build shared library. + add_library(libblis SHARED ${OBJECT_LIBRARIES}) + target_link_libraries(libblis PRIVATE ${LDFLAGS}) + set_target_properties(libblis PROPERTIES LINKER_LANGUAGE C VERSION ${VERSION} SOVERSION ${SO_VERSION_MAJOR}) + set_target_properties(libblis PROPERTIES POSITION_INDEPENDENT_CODE ON) + if(THREADING_MODEL STREQUAL "openmp") + target_link_libraries(libblis PRIVATE OpenMP::OpenMP_C) endif() +else() + # Build static library. + add_library(libblis STATIC ${OBJECT_LIBRARIES}) + set_target_properties(libblis PROPERTIES LINKER_LANGUAGE C) +endif() +add_dependencies(libblis flat-header) +if(ENABLE_CBLAS) + add_dependencies(libblis flat-cblas-header) endif() +# Add headers as a property to the library. +set_target_properties(libblis PROPERTIES PUBLIC_HEADER "${BLIS_PUBLIC_HEADERS}") +set_target_properties(libblis PROPERTIES OUTPUT_NAME ${LIBBLIS}) -link_directories(${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) -add_definitions(-DEXPMODULE) +# Install targets. +install(TARGETS libblis LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/lib + ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/lib + RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/lib + PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_PREFIX}/include) -add_subdirectory(config) -add_subdirectory(ref_kernels) -add_subdirectory(kernels) -add_subdirectory(frame) -add_subdirectory(aocl_dtl) -add_subdirectory(test) -add_subdirectory(testsuite) -add_subdirectory(bench) -if(ENABLE_TESTCPP_TESTING) - add_subdirectory(vendor/testcpp) +# --- Primary targets --- +add_custom_target(libs DEPENDS libblis) + +# Multiple BLIS API testing targets. 
Result files are generated in ${CMAKE_BINARY_DIR}/testsuite. +add_subdirectory(testsuite EXCLUDE_FROM_ALL) + +# Check results of BLIS CPP Template tests +add_subdirectory(vendor/testcpp EXCLUDE_FROM_ALL) + +# Add BLAS tests if BLAS interface is enabled. +if(ENABLE_BLAS) + add_subdirectory(blastest EXCLUDE_FROM_ALL) endif() -if (ENABLE_BLASTEST) - add_subdirectory(blastest) + +# Add generic testing target `test`. +set(available_testsuites checkblis) +if(ENABLE_BLAS) + list(APPEND available_testsuites checkblas) +endif() +add_custom_target(test DEPENDS ${available_testsuites}) + +# Add generic testing target `check`. +set(available_testsuites checkblis-fast) +if(ENABLE_BLAS) + list(APPEND available_testsuites checkblas) endif() +add_custom_target(check DEPENDS ${available_testsuites}) \ No newline at end of file diff --git a/CREDITS b/CREDITS index fd0bcb5b32..bcbc889ce0 100644 --- a/CREDITS +++ b/CREDITS @@ -42,6 +42,7 @@ but many others have contributed code and feedback, including Shivaprashanth H (Global Edge) Jean-Michel Hautbois @jhautbois Ian Henriksen @insertinterestingnamehere (The University of Texas at Austin) + Greg Henry (Intel) Minh Quan Ho @hominhquan Matthew Honnibal @honnibal Stefan Husmann @stefanhusmann @@ -50,9 +51,11 @@ but many others have contributed code and feedback, including Tony Kelman @tkelman Lee Killough @leekillough (Cray) Mike Kistler @mkistler (IBM, Austin Research Laboratory) + Ivan Korostelev @ivan23kor (University of Alberta) Kyungmin Lee @kyungminlee (Ohio State University) Michael Lehn @michael-lehn Shmuel Levine @ShmuelLevine + @lschork2 Dave Love @loveshack Tze Meng Low (The University of Texas at Austin) Ye Luo @ye-luo (Argonne National Laboratory) @@ -92,6 +95,7 @@ but many others have contributed code and feedback, including Paul Springer @springer13 (RWTH Aachen University) Adam J. Stewart @adamjstewart (University of Illinois at Urbana-Champaign) Vladimir Sukarev + Chengguo Sun @chengguosun Santanu Thangaraj (AMD) Nicholai Tukanov @nicholaiTukanov (The University of Texas at Austin) Rhys Ulerich @RhysU (The University of Texas at Austin) @@ -99,6 +103,7 @@ but many others have contributed code and feedback, including Meghana Vankadari @Meghana-vankadari (AMD) Kiran Varaganti @kvaragan (AMD) Natalia Vassilieva (Hewlett Packard Enterprise) + Andrew Wildman @awild82 (University of Washington) Zhang Xianyi @xianyi (Chinese Academy of Sciences) Benda Xu @heroxbd Guodong Xu @docularxu (Linaro.org) @@ -106,6 +111,7 @@ but many others have contributed code and feedback, including Costas Yamin @cosstas Chenhan Yu @ChenhanYu (The University of Texas at Austin) Roman Yurchak @rth (Symerio) + Stefano Zampini @stefanozampini M. Zhou @cdluminate BLIS's development was partially funded by grants from industry diff --git a/LICENSE b/LICENSE index be24a09734..f05ca1125c 100644 --- a/LICENSE +++ b/LICENSE @@ -15,7 +15,7 @@ copyright info. All parties provide their portions of the code under the Copyright (C) 2018, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP -Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. +Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/Makefile b/Makefile index 0a1a4646ad..4c4c01ffd0 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ # libraries. 
# # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -320,6 +320,7 @@ BLASTEST_INPUT_PATH := $(DIST_PATH)/$(BLASTEST_DIR)/input # The location of the BLAS test suite object directory. BASE_OBJ_BLASTEST_PATH := $(BASE_OBJ_PATH)/$(BLASTEST_DIR) +BASE_EXE_BLASTEST_PATH := $(BASE_OBJ_BLASTEST_PATH)/$(MK_USE_LIB) # The locations of the BLAS test suite source code (f2c and drivers). BLASTEST_F2C_SRC_PATH := $(DIST_PATH)/$(BLASTEST_DIR)/f2c @@ -347,7 +348,7 @@ BLASTEST_DRV_BASES := $(basename $(notdir $(BLASTEST_DRV_OBJS))) # The binary executable driver names. BLASTEST_DRV_BINS := $(addsuffix .x,$(BLASTEST_DRV_BASES)) -BLASTEST_DRV_BIN_PATHS := $(addprefix $(BASE_OBJ_BLASTEST_PATH)/,$(BLASTEST_DRV_BINS)) +BLASTEST_DRV_BIN_PATHS := $(addprefix $(BASE_EXE_BLASTEST_PATH)/,$(BLASTEST_DRV_BINS)) # Binary executable driver "run-" names BLASTEST_DRV_BINS_R := $(addprefix run-,$(BLASTEST_DRV_BASES)) @@ -393,6 +394,7 @@ TESTSUITE_SALT_OPS_PATH := $(DIST_PATH)/$(TESTSUITE_DIR)/$(TESTSUITE_SALT_OPS) # directory. TESTSUITE_SRC_PATH := $(DIST_PATH)/$(TESTSUITE_DIR)/src BASE_OBJ_TESTSUITE_PATH := $(BASE_OBJ_PATH)/$(TESTSUITE_DIR) +BASE_EXE_TESTSUITE_PATH := $(BASE_OBJ_PATH)/$(TESTSUITE_DIR)/$(MK_USE_LIB) # Convert source file paths to object file paths by replacing the base source # directories with the base object directories, and also replacing the source @@ -414,7 +416,7 @@ MK_TESTSUITE_OBJS := $(sort \ # unusual environments (e.g. ARM) can run the testsuite through some other # binary. See .travis.yml for details on how the variable is employed in # practice. -TESTSUITE_BIN := test_$(LIBBLIS).x +TESTSUITE_BIN := $(BASE_EXE_TESTSUITE_PATH)/test_$(LIBBLIS).x TESTSUITE_WRAPPER ?= # The location of the script that checks the BLIS testsuite output. @@ -504,7 +506,7 @@ endif flat-header: check-env $(BLIS_H_FLAT) -$(BLIS_H_FLAT): $(FRAME_H99_FILES) +$(BLIS_H_FLAT): $(ALL_H99_FILES) ifeq ($(ENABLE_VERBOSE),yes) $(FLATTEN_H) -c -v1 $(BLIS_H_SRC_PATH) $@ "./$(INCLUDE_DIR)" "$(ALL_H99_DIRPATHS)" else @@ -820,7 +822,7 @@ blastest-bin: check-env blastest-f2c $(BLASTEST_DRV_BIN_PATHS) blastest-run: $(BLASTEST_DRV_BINS_R) # f2c object file rule. -$(BASE_OBJ_BLASTEST_PATH)/%.o: $(BLASTEST_F2C_SRC_PATH)/%.c +$(BASE_OBJ_BLASTEST_PATH)/%.o: $(BLASTEST_F2C_SRC_PATH)/%.c $(BLIS_H_FLAT) ifeq ($(ENABLE_VERBOSE),yes) $(CC) $(call get-user-cflags-for,$(CONFIG_NAME)) $(BLAT_CFLAGS) -c $< -o $@ else @@ -829,7 +831,7 @@ else endif # driver object file rule. -$(BASE_OBJ_BLASTEST_PATH)/%.o: $(BLASTEST_DRV_SRC_PATH)/%.c +$(BASE_OBJ_BLASTEST_PATH)/%.o: $(BLASTEST_DRV_SRC_PATH)/%.c $(BLIS_H_FLAT) ifeq ($(ENABLE_VERBOSE),yes) $(CC) $(call get-user-cflags-for,$(CONFIG_NAME)) $(BLAT_CFLAGS) -c $< -o $@ else @@ -850,7 +852,8 @@ endif # first argument: the base name of the BLAS test driver. 
define make-blat-rule -$(BASE_OBJ_BLASTEST_PATH)/$(1).x: $(BASE_OBJ_BLASTEST_PATH)/$(1).o $(BLASTEST_F2C_LIB) $(LIBBLIS_LINK) +$(BASE_EXE_BLASTEST_PATH)/$(1).x: $(BASE_OBJ_BLASTEST_PATH)/$(1).o $(BLASTEST_F2C_LIB) $(LIBBLIS_LINK) + @mkdir -p $(BASE_EXE_BLASTEST_PATH) ifeq ($(ENABLE_VERBOSE),yes) $(LINKER) $(BASE_OBJ_BLASTEST_PATH)/$(1).o $(BLASTEST_F2C_LIB) $(LIBBLIS_LINK) $(LDFLAGS) -o $$@ else @@ -864,12 +867,12 @@ $(foreach name, $(BLASTEST_DRV_BASES), $(eval $(call make-blat-rule,$(name)))) # A rule to run ?blat1.x driver files. define make-run-blat1-rule -run-$(1): $(BASE_OBJ_BLASTEST_PATH)/$(1).x +run-$(1): $(BASE_EXE_BLASTEST_PATH)/$(1).x ifeq ($(ENABLE_VERBOSE),yes) - $(TESTSUITE_WRAPPER) $(BASE_OBJ_BLASTEST_PATH)/$(1).x > out.$(1) + $(TESTSUITE_WRAPPER) $(BASE_EXE_BLASTEST_PATH)/$(1).x > out.$(1) else @echo "Running $(1).x > 'out.$(1)'" - @$(TESTSUITE_WRAPPER) $(BASE_OBJ_BLASTEST_PATH)/$(1).x > out.$(1) + @$(TESTSUITE_WRAPPER) $(BASE_EXE_BLASTEST_PATH)/$(1).x > out.$(1) endif endef @@ -878,12 +881,12 @@ $(foreach name, $(BLASTEST_DRV1_BASES), $(eval $(call make-run-blat1-rule,$(name # A rule to run ?blat2.x and ?blat3.x driver files. define make-run-blat23-rule -run-$(1): $(BASE_OBJ_BLASTEST_PATH)/$(1).x +run-$(1): $(BASE_EXE_BLASTEST_PATH)/$(1).x ifeq ($(ENABLE_VERBOSE),yes) - $(TESTSUITE_WRAPPER) $(BASE_OBJ_BLASTEST_PATH)/$(1).x < $(BLASTEST_INPUT_PATH)/$(1).in + $(TESTSUITE_WRAPPER) $(BASE_EXE_BLASTEST_PATH)/$(1).x < $(BLASTEST_INPUT_PATH)/$(1).in else @echo "Running $(1).x < '$(BLASTEST_INPUT_PATH)/$(1).in' (output to 'out.$(1)')" - @$(TESTSUITE_WRAPPER) $(BASE_OBJ_BLASTEST_PATH)/$(1).x < $(BLASTEST_INPUT_PATH)/$(1).in + @$(TESTSUITE_WRAPPER) $(BASE_EXE_BLASTEST_PATH)/$(1).x < $(BLASTEST_INPUT_PATH)/$(1).in endif endef @@ -916,7 +919,7 @@ testsuite: testsuite-run testsuite-bin: check-env $(TESTSUITE_BIN) # Object file rule. -$(BASE_OBJ_TESTSUITE_PATH)/%.o: $(TESTSUITE_SRC_PATH)/%.c +$(BASE_OBJ_TESTSUITE_PATH)/%.o: $(TESTSUITE_SRC_PATH)/%.c $(BLIS_H_FLAT) ifeq ($(ENABLE_VERBOSE),yes) $(CC) $(call get-user-cflags-for,$(CONFIG_NAME)) -c $< -o $@ else @@ -926,6 +929,7 @@ endif # Testsuite binary rule. $(TESTSUITE_BIN): $(MK_TESTSUITE_OBJS) $(LIBBLIS_LINK) + @mkdir -p $(BASE_EXE_TESTSUITE_PATH) ifeq ($(ENABLE_VERBOSE),yes) $(LINKER) $(MK_TESTSUITE_OBJS) $(LIBBLIS_LINK) $(LDFLAGS) -o $@ else @@ -936,13 +940,13 @@ endif # A rule to run the testsuite using the normal input.* files. 
testsuite-run: testsuite-bin ifeq ($(ENABLE_VERBOSE),yes) - $(TESTSUITE_WRAPPER) ./$(TESTSUITE_BIN) -g $(TESTSUITE_CONF_GEN_PATH) \ + $(TESTSUITE_WRAPPER) $(TESTSUITE_BIN) -g $(TESTSUITE_CONF_GEN_PATH) \ -o $(TESTSUITE_CONF_OPS_PATH) \ > $(TESTSUITE_OUT_FILE) else @echo "Running $(TESTSUITE_BIN) with output redirected to '$(TESTSUITE_OUT_FILE)'" - @$(TESTSUITE_WRAPPER) ./$(TESTSUITE_BIN) -g $(TESTSUITE_CONF_GEN_PATH) \ + @$(TESTSUITE_WRAPPER) $(TESTSUITE_BIN) -g $(TESTSUITE_CONF_GEN_PATH) \ -o $(TESTSUITE_CONF_OPS_PATH) \ > $(TESTSUITE_OUT_FILE) endif @@ -1285,7 +1289,7 @@ ifeq ($(IS_CONFIGURED),yes) ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $(BLASTEST_F2C_OBJS) $(BLASTEST_DRV_OBJS) - $(RM_F) $(BLASTEST_F2C_LIB) - - $(RM_F) $(BLASTEST_DRV_BIN_PATHS) + - $(RM_RF) $(BASE_OBJ_BLASTEST_PATH)/{shared,static} - $(RM_F) $(addprefix out.,$(BLASTEST_DRV_BASES)) else @echo "Removing object files from $(BASE_OBJ_BLASTEST_PATH)" @@ -1293,7 +1297,7 @@ else @echo "Removing libf2c.a from $(BASE_OBJ_BLASTEST_PATH)" @- $(RM_F) $(BLASTEST_F2C_LIB) @echo "Removing binaries from $(BASE_OBJ_BLASTEST_PATH)" - @- $(RM_F) $(BLASTEST_DRV_BIN_PATHS) + @- $(RM_RF) $(BASE_OBJ_BLASTEST_PATH)/{shared,static} @echo "Removing driver output files 'out.*'" @- $(RM_F) $(addprefix out.,$(BLASTEST_DRV_BASES)) endif # ENABLE_VERBOSE @@ -1328,13 +1332,13 @@ cleanblistesttop: ifeq ($(IS_CONFIGURED),yes) ifeq ($(ENABLE_VERBOSE),yes) - $(RM_F) $(MK_TESTSUITE_OBJS) - - $(RM_F) $(TESTSUITE_BIN) + - $(RM_RF) $(BASE_OBJ_TESTSUITE_PATH)/{shared,static} - $(RM_F) $(TESTSUITE_OUT_FILE) else @echo "Removing object files from $(BASE_OBJ_TESTSUITE_PATH)" @- $(RM_F) $(MK_TESTSUITE_OBJS) @echo "Removing binary $(TESTSUITE_BIN)" - @- $(RM_F) $(TESTSUITE_BIN) + @- $(RM_RF) $(BASE_OBJ_TESTSUITE_PATH)/{shared,static} @echo "Removing $(TESTSUITE_OUT_FILE)" @- $(RM_F) $(TESTSUITE_OUT_FILE) endif # ENABLE_VERBOSE @@ -1344,13 +1348,13 @@ cleanblistestdir: ifeq ($(IS_CONFIGURED),yes) ifeq ($(ENABLE_VERBOSE),yes) - $(FIND) $(TESTSUITE_DIR)/$(OBJ_DIR) -name "*.o" | $(XARGS) $(RM_F) - - $(RM_F) $(TESTSUITE_DIR)/$(TESTSUITE_BIN) + - $(RM_RF) $(BASE_OBJ_TESTSUITE_PATH)/{shared,static} - $(MAKE) -C $(VEND_TESTCPP_DIR) clean else @echo "Removing object files from $(TESTSUITE_DIR)/$(OBJ_DIR)" @- $(FIND) $(TESTSUITE_DIR)/$(OBJ_DIR) -name "*.o" | $(XARGS) $(RM_F) - @echo "Removing binary $(TESTSUITE_DIR)/$(TESTSUITE_BIN)" - @- $(RM_F) $(TESTSUITE_DIR)/$(TESTSUITE_BIN) + @echo "Removing binary $(TESTSUITE_BIN)" + @- $(RM_RF) $(BASE_OBJ_TESTSUITE_PATH)/{shared,static} @$(MAKE) -C $(VEND_TESTCPP_DIR) clean endif # ENABLE_VERBOSE endif # IS_CONFIGURED diff --git a/addon/CMakeLists.txt b/addon/CMakeLists.txt new file mode 100644 index 0000000000..073a3fb75b --- /dev/null +++ b/addon/CMakeLists.txt @@ -0,0 +1,206 @@ +##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## + +# Writing a function that will be used to generate the required object +# libraries for the required addons. +function(generate_addon_targets addon_target) + # Collect all subdirectory paths that have at least one file with suffix in ADDON_C99_SUFS list. + get_filepaths_with_suffixes(LOCAL_SOURCE_C99_FILES "${CMAKE_CURRENT_SOURCE_DIR}/${addon_target}" "${ADDON_C99_SUFS}") + # We want to break the files above in 2 categories, files in kernel directory and the rest. + # Only list files in kernel directory. 
+ set(LOCAL_KERNEL_FILES_C99 ${LOCAL_SOURCE_FILES}) + list(FILTER LOCAL_KERNEL_FILES_C99 INCLUDE REGEX ${addon_target}/kernels/) + # All C99 files, except of the ones in kernels directory. + list(REMOVE_ITEM LOCAL_SOURCE_C99_FILES ${LOCAL_KERNEL_FILES_C99}) + + # Collect all subdirectory paths that have at least one file with suffix in ADDON_H99_SUFS list. + get_dirpaths_with_suffixes(CADDONINCFLAGS "${CMAKE_CURRENT_SOURCE_DIR}/${addon_target}" "${ADDON_H99_SUFS}") + + # Only generate the object library if there is at least one source file. + list(LENGTH LOCAL_SOURCE_C99_FILES size) + if(size GREATER 0) + # Create an object library using the source file list above. + add_library(${addon_target}_C99_ADDON + OBJECT + ${LOCAL_SOURCE_C99_FILES} + ) + # Include the corresponding make_defs.cmake that holds the required compiler options. + include(${CMAKE_SOURCE_DIR}/config/${BLIS_CONFIG_FAMILY}/make_defs.cmake) + # Use PRIVATE keyword for option setting since we do not want the properties to propagate in other targets. + # mimicing get-addon-c99flags-for + target_compile_options(${addon_target}_C99_ADDON + PRIVATE + # load-var-for,COPTFLAGS + ${COPTFLAGS} + # get-noopt-cflags-for + ${CDBGFLAGS} + # get-noopt-cflags-for + ${CWARNFLAGS} + # get-noopt-cflags-for + ${CMISCFLAGS} + # get-noopt-cflags-for + ${CLANGFLAGS} + # in get-addon-c99flags-for + ${BUILD_SYMFLAGS} + ) + target_compile_definitions(${addon_target}_C99_ADDON + PRIVATE + # in get-noopt-cflags-for + ${CPPROCFLAGS} + # in get-noopt-cflags-for + ${VERS_DEF} + # in get-addon-c99flags-for + ${BUILD_CPPFLAGS} + ) + target_include_directories(${addon_target}_C99_ADDON + BEFORE + PRIVATE + # in get-noopt-cflags-for + ${CINFLAGS} + # in get-addon-c99flags-for + ${CADDONINCFLAGS} + ) + if(THREADING_MODEL STREQUAL "openmp") + # Equivalent to CTHREADFLAGS in get-noopt-cflags-for + target_link_libraries(${addon_target}_C99_ADDON PRIVATE OpenMP::OpenMP_C) + elseif(THREADING_MODEL STREQUAL "pthreads") + # in get-noopt-cflags-for + target_compile_options(${addon_target}_C99_ADDON PRIVATE ${CTHREADFLAGS}) + endif() + if(BUILD_SHARED_LIBS) + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(${addon_target}_C99_ADDON PROPERTIES POSITION_INDEPENDENT_CODE ON) + endif() + add_dependencies(${addon_target}_C99_ADDON flat-header) + # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. + set_target_properties(${addon_target}_C99_ADDON PROPERTIES FOLDER object-libs-targets) + endif() + + # Only generate the object library if there is at least one source file. + list(LENGTH LOCAL_KERNEL_FILES_C99 size) + if(size GREATER 0) + # Create an object library using the kernel source file list above. + add_library(${addon_target}_C99_KERNEL_ADDON + OBJECT + ${LOCAL_KERNEL_FILES_C99} + ) + # Include the corresponding make_defs.cmake that holds the required compiler options. + include(${CMAKE_SOURCE_DIR}/config/${BLIS_CONFIG_FAMILY}/make_defs.cmake) + # Use PRIVATE keyword for option setting since we do not want the properties to propagate in other targets. 
+ # mimicing get-addon-c99flags-for + target_compile_options(${addon_target}_C99_KERNEL_ADDON + PRIVATE + # load-var-for,CKOPTFLAGS + ${CKOPTFLAGS} + # load-var-for,CKVECFLAGS + ${CKVECFLAGS} + # get-noopt-cflags-for + ${CDBGFLAGS} + # get-noopt-cflags-for + ${CWARNFLAGS} + # get-noopt-cflags-for + ${CMISCFLAGS} + # get-noopt-cflags-for + ${CLANGFLAGS} + # in get-addon-kernel-c99flags-for + ${BUILD_SYMFLAGS} + ) + target_compile_definitions(${addon_target}_C99_KERNEL_ADDON + PRIVATE + # in get-noopt-cflags-for + ${CPPROCFLAGS} + # in get-noopt-cflags-for + ${VERS_DEF} + # in get-addon-kernel-c99flags-for + ${BUILD_CPPFLAGS} + ) + target_include_directories(${addon_target}_C99_KERNEL_ADDON + BEFORE + PRIVATE + # in get-noopt-cflags-for + ${CINFLAGS} + # in get-addon-kernel-c99flags-for + ${CADDONINCFLAGS} + ) + if(THREADING_MODEL STREQUAL "openmp") + # Equivalent to CTHREADFLAGS in get-noopt-cflags-for + target_link_libraries(${addon_target}_C99_KERNEL_ADDON PRIVATE OpenMP::OpenMP_C) + elseif(THREADING_MODEL STREQUAL "pthreads") + # in get-noopt-cflags-for + target_compile_options(${addon_target}_C99_KERNEL_ADDON PRIVATE ${CTHREADFLAGS}) + endif() + if(BUILD_SHARED_LIBS) + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(${addon_target}_C99_KERNEL_ADDON PROPERTIES POSITION_INDEPENDENT_CODE ON) + endif() + add_dependencies(${addon_target}_C99_KERNEL_ADDON flat-header) + # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. + set_target_properties(${addon_target}_C99_KERNEL_ADDON PROPERTIES FOLDER object-libs-targets) + endif() + + # Collect all subdirectory paths that have at least one file with suffix in ADDON_CXX_SUFS list. + get_filepaths_with_suffixes(LOCAL_SOURCE_CXX_FILES "${CMAKE_CURRENT_SOURCE_DIR}/${addon_target}" "${ADDON_CXX_SUFS}") + + # Only generate the object library if there is at least one source file. + list(LENGTH LOCAL_SOURCE_CXX_FILES size) + if(size GREATER 0) + # Create an object library using the source file list above. + add_library(${addon_target}_CXX_ADDON + OBJECT + ${LOCAL_SOURCE_CXX_FILES} + ) + + # Use PRIVATE keyword for option setting since we do not want the properties to propagate in other targets. 
+ # mimicing get-addon-cxxflags-for + target_compile_options(${addon_target}_CXX_ADDON + PRIVATE + # load-var-for,COPTFLAGS + ${COPTFLAGS} + # get-noopt-cxxflags-for + ${CDBGFLAGS} + # get-noopt-cxxflags-for + ${CWARNFLAGS} + # get-noopt-cxxflags-for + ${CMISCFLAGS} + # get-noopt-cxxflags-for + ${CXXLANGFLAGS} + # in get-addon-cxxflags-for + ${BUILD_SYMFLAGS} + ) + target_compile_definitions(${addon_target}_CXX_ADDON + PRIVATE + # in get-noopt-cflags-for + ${CPPROCFLAGS} + # in get-noopt-cflags-for + ${VERS_DEF} + # in get-addon-cxxflags-for + ${BUILD_CPPFLAGS} + ) + target_include_directories(${addon_target}_CXX_ADDON + BEFORE + PRIVATE + # in get-noopt-cflags-for + ${CINFLAGS} + # in get-addon-cxxflags-for + ${CADDONINCFLAGS} + ) + if(THREADING_MODEL STREQUAL "openmp") + # Equivalent to CTHREADFLAGS in get-noopt-cflags-for + target_link_libraries(${addon_target}_CXX_ADDON PRIVATE OpenMP::OpenMP_C) + elseif(THREADING_MODEL STREQUAL "pthreads") + # in get-noopt-cflags-for + target_compile_options(${addon_target}_CXX_ADDON PRIVATE ${CTHREADFLAGS}) + endif() + if(BUILD_SHARED_LIBS) + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(${addon_target}_CXX_ADDON PROPERTIES POSITION_INDEPENDENT_CODE ON) + endif() + add_dependencies(${addon_target}_CXX_ADDON flat-header) + # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. + set_target_properties(${addon_target}_CXX_ADDON PROPERTIES FOLDER object-libs-targets) + endif() +endfunction() + +# Generate targets for each of the addons. +foreach(ADDON ${ENABLE_ADDON}) + generate_addon_targets(${ADDON}) +endforeach() diff --git a/addon/aocl_gemm/aocl_gemm.h b/addon/aocl_gemm/aocl_gemm.h index 44de4ac658..027f895591 100644 --- a/addon/aocl_gemm/aocl_gemm.h +++ b/addon/aocl_gemm/aocl_gemm.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -43,7 +43,7 @@ #include "lpgemm_post_ops.h" #include "lpgemm_kernels.h" #include "lpgemm_utils_kernels.h" -#include "lpgemm_packb_bf16.h" +#include "lpgemm_pack_bf16.h" #include "lpgemm_packb_s16.h" #include "lpgemm_packa.h" #include "lpgemm_packb.h" diff --git a/addon/aocl_gemm/aocl_gemm_bf16_utils.c b/addon/aocl_gemm/aocl_gemm_bf16_utils.c index fd9d3be1f7..de709e8f90 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16_utils.c +++ b/addon/aocl_gemm/aocl_gemm_bf16_utils.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -85,12 +85,34 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32) AOCL_GEMM_REORDER(bfloat16, bf16bf16f32of32) { + trans_t blis_trans; + + /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ + bli_param_map_netlib_to_blis_trans( trans, &blis_trans ); + if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || - ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) ) + ( k <= 0 ) || ( n <= 0 ) || ( bli_is_notrans( blis_trans ) && ( ldb < n ) ) || + ( bli_is_trans( blis_trans ) && ( ldb < k ) ) ) { return; // Error. } + inc_t rs_b, cs_b; + if( ( order == 'r') || ( order == 'R' ) ) + { + rs_b = bli_is_notrans( blis_trans ) ? ldb : 1; + cs_b = bli_is_notrans( blis_trans ) ? 1 : ldb; + } + else if ( ( order == 'c' ) || ( order == 'C' ) ) + { + rs_b = bli_is_notrans( blis_trans ) ? 1 : ldb; + cs_b = bli_is_notrans( blis_trans ) ? ldb : 1; + } + else + { + return; // Error + } + // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it. if ( bli_cpuid_is_avx512bf16_supported() == FALSE ) { @@ -117,7 +139,7 @@ AOCL_GEMM_REORDER(bfloat16, bf16bf16f32of32) // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; bli_rntm_init_from_global( &rntm_g ); - bli_membrk_rntm_set_membrk( &rntm_g ); + bli_pba_rntm_set_pba( &rntm_g ); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 ); @@ -128,7 +150,8 @@ AOCL_GEMM_REORDER(bfloat16, bf16bf16f32of32) // Create dummy original b obj; lpgemm_obj_t b; b.storage.aligned_buffer = ( void* )input_buf_addr; - b.rs = ldb; + b.rs = rs_b; + b.cs = cs_b; b.width = n; b.length = k; diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c index 0e0f93e191..897facfbda 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,6 +34,7 @@ #include "blis.h" #include "aocl_gemm_interface_apis.h" +#include "aocl_gemm_check.h" #include "lpgemm_types.h" #include "lpgemm_post_ops.h" #include "lpgemm_thread_decor_openmp.h" @@ -73,57 +74,42 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) // Set MC, NC, KC, NR, MR. aocl_lpgemm_init_global_cntx(); - // Null check for pointers. - if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) - { - return; // Error. - } + // check for validity of params. + AOCL_GEMM_CHECK + ( + "bf16bf16f32obf16", + order, transa, transb, + m, n, k, + a, lda, mem_format_a, + b, ldb, mem_format_b, + c, ldc + ); /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); - /* Perform BLAS parameter checking. */ - // Transpose not supported. - if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || - ( blis_transb != BLIS_NO_TRANSPOSE ) ) - { - return; // Error. - } - - // Sanitize order input. - char order_use = - ( ( order == 'r' ) || ( order == 'R' ) || - ( order == 'c' ) || ( order == 'C' ) ) ? 
- order : 'r'; + bool is_row_major = ( ( order == 'r' ) || ( order == 'R' ) ); + bool is_column_major = ( ( order == 'c' ) || ( order == 'C' ) ); - bool is_row_major = ( ( order_use == 'r' ) || ( order_use == 'R' ) ); - bool is_column_major = ( ( order_use == 'c' ) || ( order_use == 'C' ) ); + inc_t rs_a = lda; + inc_t cs_a = 1; - // Row major input expected with leading dimensions >= row stride. - if ( ( is_row_major == TRUE ) && - ( ( lda < k ) || ( ldb < n ) || ( ldc < n ) ) ) + if ( bli_is_trans( blis_transa ) ) { - return; // Error. - } - // Column major input expected with leading dimensions >= column stride. - else if ( ( is_column_major == TRUE ) && - ( ( lda < m ) || ( ldb < k ) || ( ldc < m ) ) ) - { - return; // Error. + rs_a = 1; + cs_a = lda; } - // Check if dimensions are valid. - if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || - ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) + inc_t rs_b = ldb; + inc_t cs_b = 1; + + if( bli_is_trans( blis_transb ) ) { - return; // Error. + rs_b = 1; + cs_b = ldb; } - const inc_t rs_a = lda; - const inc_t cs_a = 1; - const inc_t rs_b = ldb; - const inc_t cs_b = 1; const inc_t rs_c = ldc; const inc_t cs_c = 1; @@ -133,6 +119,21 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + // Reorder is not supported for A matrix + if( ( is_row_major == TRUE ) && ( mtag_a == REORDERED ) ) + { + bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ ); + return; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + // Reorder is not supported for column major matrices. + else if ( ( is_column_major == TRUE ) && ( ( mtag_b == REORDERED ) || ( mtag_a == REORDERED ) ) ) + { + bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ ); + return; + } + + // From 5-loop function point of view, // B matrix needs to be packed in a certain format in order to be loaded // and used in bf16 instrution. As such the mtag_b always needs to be either // packed or reordered. B matrix as it is (unpacked) cannot be used, and @@ -147,30 +148,34 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) mtag_a = PACK; } - // Only unpacked A supported now. - if ( ( is_row_major == TRUE ) && ( mtag_a != UNPACKED ) ) + // From 5-loop function point of view, + // A matrix when in column major storage needs to be packed to row-major + // storage as kernel expects A matrix to be in row-major format. + if( ( is_row_major == TRUE ) && ( bli_is_trans(blis_transa ) ) ) { - return; // Error. + mtag_a = PACK; } - // Inputs swapped in column major, B becomes A from kernel point of view. - else if ( ( is_column_major == TRUE ) && ( mtag_b != UNPACKED ) ) + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ( ( is_column_major == TRUE ) && ( bli_is_trans(blis_transb ) ) ) { - return; // Error. + mtag_b = PACK; } // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list + err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order_use ) + ( void* )c, ( void* )( &order ) ); + if( err != BLIS_SUCCESS ) return; + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. 
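The hunk above removes the old "transpose not supported" early return and instead folds transa/transb into the operand strides: the lpgemm kernels consume a row-major view, so a transposed input is described by swapping its row and column strides rather than by copying it. The standalone C sketch below restates that mapping; it covers both storage orders as in the reorder hunk earlier, and the row-major case used by the GEMM entry points. The helper name map_strides and the use of long in place of inc_t are illustrative assumptions, not part of the patch.

#include <stdio.h>
#include <stdbool.h>

/* Illustrative only: mirrors the stride logic in the hunks above, where the
   kernel always reads a row-major view and transpose / column-major inputs
   are expressed purely through ( rs, cs ). long stands in for inc_t here. */
typedef long inc_t;

static void map_strides( char order, char trans, inc_t ld,
                         inc_t* rs, inc_t* cs )
{
    bool is_row_major = ( order == 'r' ) || ( order == 'R' );
    bool is_notrans   = ( trans == 'n' ) || ( trans == 'N' );

    if ( is_row_major )
    {
        /* Row major: a non-transposed matrix strides ld between rows;
           a transposed matrix is read column-wise, so the roles swap. */
        *rs = is_notrans ? ld : 1;
        *cs = is_notrans ? 1  : ld;
    }
    else
    {
        /* Column major: the opposite assignment, as in the reorder hunk. */
        *rs = is_notrans ? 1  : ld;
        *cs = is_notrans ? ld : 1;
    }
}

int main( void )
{
    inc_t rs, cs;
    map_strides( 'r', 't', 64, &rs, &cs );    /* row major, transposed, lda = 64 */
    printf( "rs = %ld, cs = %ld\n", rs, cs ); /* prints rs = 1, cs = 64 */
    return 0;
}

Expressing the transpose through strides is what lets the same packing and kernel paths serve all four order/trans combinations without an explicit out-of-place transpose of A or B.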
rntm_t rntm_g; bli_rntm_init_from_global( &rntm_g ); - bli_membrk_rntm_set_membrk( &rntm_g ); + bli_pba_rntm_set_pba( &rntm_g ); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 ); @@ -186,7 +191,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) ( float* )c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, TRUE + post_op_list, BF16 ); } else @@ -199,7 +204,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) ( float* )c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, TRUE + post_op_list, BF16 ); } #else @@ -214,7 +219,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) ( float* )c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, TRUE + post_op_list, BF16 ); } else @@ -227,7 +232,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) ( float* )c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, TRUE + post_op_list, BF16 ); } #endif diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c index ca8b160220..0ca2602898 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,6 +34,7 @@ #include "blis.h" #include "aocl_gemm_interface_apis.h" +#include "aocl_gemm_check.h" #include "lpgemm_types.h" #include "lpgemm_post_ops.h" #include "lpgemm_thread_decor_openmp.h" @@ -73,58 +74,42 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) // Set MC, NC, KC, NR, MR. aocl_lpgemm_init_global_cntx(); - // Null check for pointers. - if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) - { - return; // Error. - } +// check for validity of params. + AOCL_GEMM_CHECK + ( + "bf16bf16f32obf16", + order, transa, transb, + m, n, k, + a, lda, mem_format_a, + b, ldb, mem_format_b, + c, ldc + ); /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); - /* Perform BLAS parameter checking. */ - // Transpose not supported. - if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || - ( blis_transb != BLIS_NO_TRANSPOSE ) ) - { - return; // Error. - } - - // Sanitize order input. - char order_use = - ( ( order == 'r' ) || ( order == 'R' ) || - ( order == 'c' ) || ( order == 'C' ) ) ? - order : 'r'; + bool is_row_major = ( ( order == 'r' ) || ( order == 'R' ) ); + bool is_column_major = ( ( order == 'c' ) || ( order == 'C' ) ); - bool is_row_major = ( ( order_use == 'r' ) || ( order_use == 'R' ) ); - bool is_column_major = ( ( order_use == 'c' ) || ( order_use == 'C' ) ); + // The strides are set assuming a row major kernel. + inc_t rs_a = lda; + inc_t cs_a = 1; - // Row major input expected with leading dimensions >= row stride. - if ( ( is_row_major == TRUE ) && - ( ( lda < k ) || ( ldb < n ) || ( ldc < n ) ) ) + if ( bli_is_trans( blis_transa ) ) { - return; // Error. - } - // Column major input expected with leading dimensions >= column stride. 
- else if ( ( is_column_major == TRUE ) && - ( ( lda < m ) || ( ldb < k ) || ( ldc < m ) ) ) - { - return; // Error. + rs_a = 1; + cs_a = lda; } - // Check if dimensions are valid. - if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || - ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) + inc_t rs_b = ldb; + inc_t cs_b = 1; + + if( bli_is_trans( blis_transb ) ) { - return; // Error. + rs_b = 1; + cs_b = ldb; } - - // The strides are set assuming a row major kernel. - const inc_t rs_a = lda; - const inc_t cs_a = 1; - const inc_t rs_b = ldb; - const inc_t cs_b = 1; const inc_t rs_c = ldc; const inc_t cs_c = 1; @@ -134,12 +119,21 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); - if ( ( is_column_major == TRUE ) && ( mtag_b == REORDERED ) ) + // Reorder is not supported for A matrix + if( ( is_row_major == TRUE ) && ( mtag_a == REORDERED ) ) + { + bli_print_msg(" Reordering of A matrix is not supported in row major case.", __FILE__, __LINE__ ); + return; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + // Reorder is not supported for column major matrices. + else if ( ( is_column_major == TRUE ) && ( ( mtag_b == REORDERED ) || ( mtag_a == REORDERED ) ) ) { - // Reorder not supported with column major inputs. + bli_print_msg(" Reordering of column major matrices is not supported.", __FILE__, __LINE__ ); return; } + // From 5-loop function point of view // B matrix needs to be packed in a certain format in order to be loaded // and used in bf16 instrution. As such the mtag_b always needs to be either // packed or reordered. B matrix as it is (unpacked) cannot be used, and @@ -154,30 +148,34 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) mtag_a = PACK; } - // Only unpacked A supported now. - if ( ( is_row_major == TRUE ) && ( mtag_a != UNPACKED ) ) + // From 5-loop function point of view, + // A matrix when in column major storage needs to be packed to row-major + // storage as kernel expects A matrix to be in row-major format. + if( ( is_row_major == TRUE ) && ( bli_is_trans(blis_transa ) ) ) { - return; // Error. + mtag_a = PACK; } - // Inputs swapped in column major, B becomes A from kernel point of view. - else if ( ( is_column_major == TRUE ) && ( mtag_b != UNPACKED ) ) + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ( ( is_column_major == TRUE ) && ( bli_is_trans(blis_transb ) ) ) { - return; // Error. + mtag_b = PACK; } // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list + err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order_use ) + ( void* )c, ( void* )( &order ) ); + if( err != BLIS_SUCCESS ) return; + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. 
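Both bf16 entry points above make the same storage decisions before dispatch: pre-reordered A is rejected in the row-major path, any pre-reordered operand is rejected in the column-major path (where A and B swap roles from the kernel's point of view), and a transposed operand is handled by tagging it for on-the-fly packing. A compact restatement follows; the mtag_t enum and adjust_mtags helper are hypothetical stand-ins for the library's types, used only to illustrate the decisions visible in the hunks above.

#include <stdbool.h>
#include <stdio.h>

/* Illustrative sketch of the mtag adjustments made above; mtag_t and
   adjust_mtags are stand-ins, not the library's actual types or functions. */
typedef enum { UNPACKED_T, PACK_T, REORDERED_T } mtag_t;

/* Returns false where the hunks above print a message and return early. */
static bool adjust_mtags( bool row_major, bool transa, bool transb,
                          mtag_t* mtag_a, mtag_t* mtag_b )
{
    /* Reordered A is not supported in the row-major path. */
    if ( row_major && ( *mtag_a == REORDERED_T ) )
        return false;
    /* No reordered operand is supported in the column-major path. */
    if ( !row_major && ( ( *mtag_a == REORDERED_T ) || ( *mtag_b == REORDERED_T ) ) )
        return false;

    /* A transpose is resolved by packing the affected operand. */
    if ( row_major && transa )
        *mtag_a = PACK_T;
    else if ( !row_major && transb )
        *mtag_b = PACK_T;

    return true;
}

int main( void )
{
    mtag_t a = UNPACKED_T, b = UNPACKED_T;
    bool ok = adjust_mtags( true /* row major */, true /* transa */, false, &a, &b );
    printf( "ok=%d mtag_a=%d mtag_b=%d\n", ok, a, b ); /* ok=1 mtag_a=1 (PACK_T) mtag_b=0 */
    return 0;
}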
rntm_t rntm_g; bli_rntm_init_from_global( &rntm_g ); - bli_membrk_rntm_set_membrk( &rntm_g ); + bli_pba_rntm_set_pba( &rntm_g ); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 ); @@ -193,7 +191,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, F32 ); } else @@ -206,7 +204,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, F32 ); } #else @@ -221,7 +219,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, F32 ); } else @@ -234,7 +232,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, F32 ); } #endif diff --git a/addon/aocl_gemm/aocl_gemm_check.h b/addon/aocl_gemm/aocl_gemm_check.h new file mode 100644 index 0000000000..a49fb78007 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_check.h @@ -0,0 +1,104 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +// yet to add validity check for postops +#define AOCL_GEMM_CHECK( op_str, \ + order, transa, transb, \ + m, n, k, \ + a, lda, mtag_a, \ + b, ldb, mtag_b, \ + c, ldc \ + ) \ +{ \ + int32_t info = 0; \ + bool col_stored, row_stored; \ + bool nota, notb, ta, tb; \ + \ + col_stored = ( order == 'c' ) || ( order == 'C' ); \ + row_stored = ( order == 'r' ) || ( order == 'R' ); \ + \ + nota = ( transa == 'n' ) || ( transa == 'N' ); \ + notb = ( transb == 'n' ) || ( transb == 'N' ); \ + \ + ta = ( transa == 't' ) || ( transa == 'T' ); \ + tb = ( transb == 't' ) || ( transb == 'T' ); \ + \ + if( ( order != 'r') && ( order != 'R' ) && ( order != 'c' ) && ( order != 'C' ) ) \ + info = 1; \ + else if( ( transa != 'n' ) && ( transa != 'N' ) && ( transa != 't' ) && ( transa != 'T' ) ) \ + info = 2; \ + else if( ( transb != 'n' ) && ( transb != 'N' ) && ( transb != 't' ) && ( transb != 'T' ) ) \ + info = 3; \ + else if ( m <= 0 ) \ + info = 4; \ + else if ( n <= 0 ) \ + info = 5; \ + else if ( k <= 0 ) \ + info = 6; \ + else if ( a == NULL ) \ + info = 8; \ + else if ( row_stored && ( ( nota && ( lda < k ) ) || ( ta && ( lda < m ) ) ) ) \ + info = 9; \ + else if ( col_stored && ( ( nota && ( lda < m ) ) || ( ta && ( lda < k ) ) ) ) \ + info = 9; \ + else if ( ( mtag_a != 'n' ) && ( mtag_a != 'N' ) && \ + ( mtag_a != 'p' ) && ( mtag_a != 'P' ) && \ + ( mtag_a != 'r' ) && ( mtag_a != 'R' ) ) \ + info = 10; \ + else if ( b == NULL ) \ + info = 11; \ + else if ( row_stored && ( ( notb && ( ldb < n ) ) || ( tb && ( ldb < k ) ) ) ) \ + info = 12; \ + else if ( col_stored && ( ( notb && ( ldb < k ) ) || ( tb && ( ldb < n ) ) ) ) \ + info = 12; \ + else if ( ( mtag_b != 'n' ) && ( mtag_b != 'N' ) && \ + ( mtag_b != 'p' ) && ( mtag_b != 'P' ) && \ + ( mtag_b != 'r' ) && ( mtag_b != 'R' ) ) \ + info = 13; \ + else if ( c == NULL ) \ + info = 15; \ + else if ( row_stored && ( ldc < n ) ) \ + info = 16; \ + else if ( col_stored && ( ldc < m ) ) \ + info = 16; \ + \ + if( info != 0 ) \ + { \ + char print_msg[ 100 ]; \ + \ + sprintf( print_msg, "** On entry to %6s, parameter number %2i had an illegal value", op_str, info); \ + bli_print_msg(print_msg, __FILE__, __LINE__); \ + return; \ + } \ +} diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c index f3eed1aa65..107b651b71 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,6 +34,7 @@ #include "blis.h" #include "aocl_gemm_interface_apis.h" +#include "aocl_gemm_check.h" #include "lpgemm_types.h" #include "lpgemm_post_ops.h" #include "lpgemm_thread_decor_openmp.h" @@ -64,13 +65,16 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(s), transa, transb, m, n, k,\ (void*)&alpha, lda, ldb, (void*)&beta, ldc); - // Null check for pointers. - if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) - { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ - "Invalid pointers provided for input parameters."); - return; // Error. - } + // check for validity of params. 
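The new AOCL_GEMM_CHECK macro above assigns every argument a BLAS-style position number and reports the first offending one through bli_print_msg before returning. To make the numbering concrete, here is a standalone sketch of a slice of those checks (order, dimensions, and the ldb rule under transpose); check_params is an illustrative stand-in, not the macro itself.

#include <stdio.h>
#include <stdbool.h>

/* Illustrative stand-in for part of the AOCL_GEMM_CHECK logic above: returns
   the 1-based parameter number of the first invalid argument, or 0 when the
   subset of checks shown here passes. */
static int check_params( char order, char transb,
                         long m, long n, long k, long ldb )
{
    bool row_stored = ( order  == 'r' ) || ( order  == 'R' );
    bool col_stored = ( order  == 'c' ) || ( order  == 'C' );
    bool notb       = ( transb == 'n' ) || ( transb == 'N' );
    bool tb         = ( transb == 't' ) || ( transb == 'T' );

    if ( !row_stored && !col_stored )                                 return 1;
    if ( !notb && !tb )                                               return 3;
    if ( m <= 0 )                                                     return 4;
    if ( n <= 0 )                                                     return 5;
    if ( k <= 0 )                                                     return 6;
    /* ldb is parameter 12: it must cover n (no transpose) or k (transpose)
       in row-major storage, and the reverse in column-major storage. */
    if ( row_stored && ( ( notb && ( ldb < n ) ) || ( tb && ( ldb < k ) ) ) ) return 12;
    if ( col_stored && ( ( notb && ( ldb < k ) ) || ( tb && ( ldb < n ) ) ) ) return 12;
    return 0;
}

int main( void )
{
    /* Row major, no transpose, ldb smaller than n: parameter 12 is flagged,
       matching the message the macro above would emit. */
    int info = check_params( 'r', 'n', 16, 32, 8, 16 );
    printf( "info = %d\n", info ); /* prints info = 12 */
    return 0;
}

Note how the ldb bound switches between n and k depending on both the storage order and transb; the per-function checks that this macro replaces only handled the no-transpose case.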
+ AOCL_GEMM_CHECK + ( + "f32f32f32of32", + order, transa, transb, + m, n, k, + a, lda, mem_format_a, + b, ldb, mem_format_b, + c, ldc + ); /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); @@ -86,36 +90,8 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) return; // Error. } - // Sanitize order input. - char order_use = - ( ( order == 'r' ) || ( order == 'R' ) || - ( order == 'c' ) || ( order == 'C' ) ) ? - order : 'r'; - - bool is_row_major = ( ( order_use == 'r' ) || ( order_use == 'R' ) ); - bool is_column_major = ( ( order_use == 'c' ) || ( order_use == 'C' ) ); - - // Row major input expected with leading dimensions >= row stride. - if ( ( is_row_major == TRUE ) && - ( ( lda < k ) || ( ldb < n ) || ( ldc < n ) ) ) - { - return; // Error. - } - // Column major input expected with leading dimensions >= column stride. - else if ( ( is_column_major == TRUE ) && - ( ( lda < m ) || ( ldb < k ) || ( ldc < m ) ) ) - { - return; // Error. - } - - // Check if dimensions are valid. - if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || - ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) - { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ - "Invalid matrix dimensions."); - return; // Error. - } + bool is_row_major = ( ( order == 'r' ) || ( order == 'R' ) ); + bool is_column_major = ( ( order == 'c' ) || ( order == 'C' ) ); // The strides are set assuming a row major kernel. const inc_t rs_a = lda; @@ -168,17 +144,19 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list + err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order_use ) + ( void* )c, ( void* )( &order ) ); + if( err != BLIS_SUCCESS ) return; + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; bli_rntm_init_from_global( &rntm_g ); - bli_membrk_rntm_set_membrk( &rntm_g ); + bli_pba_rntm_set_pba( &rntm_g ); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( F32F32F32OF32 ); @@ -197,7 +175,7 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, F32 ); } else @@ -210,7 +188,7 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, F32 ); } #else @@ -229,7 +207,7 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, F32 ); } else @@ -242,7 +220,7 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, F32 ); } #endif diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c index 2116e418af..3b801ce0db 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/addon/aocl_gemm/aocl_gemm_interface_apis.h b/addon/aocl_gemm/aocl_gemm_interface_apis.h index 718c0c3de2..7009cf1e2e 100644 --- a/addon/aocl_gemm/aocl_gemm_interface_apis.h +++ b/addon/aocl_gemm/aocl_gemm_interface_apis.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,6 +42,8 @@ #define AOCL_GEMM_GET_REORDER_BUF_SIZE(LP_SFX) \ BLIS_EXPORT_ADDON siz_t aocl_get_reorder_buf_size_ ## LP_SFX \ ( \ + const char order, \ + const char trans, \ const char mat_type, \ const dim_t k, \ const dim_t n \ @@ -60,6 +62,8 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s16os16); #define AOCL_GEMM_REORDER(B_type,LP_SFX) \ BLIS_EXPORT_ADDON void aocl_reorder_ ## LP_SFX \ ( \ + const char order, \ + const char trans, \ const char mat_type, \ const B_type* input_buf_addr, \ B_type* reorder_buf_addr, \ @@ -106,6 +110,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16); AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32); AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8); AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8); +AOCL_GEMM_MATMUL(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8); AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16); AOCL_GEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32); AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8); diff --git a/addon/aocl_gemm/aocl_gemm_post_ops.h b/addon/aocl_gemm/aocl_gemm_post_ops.h index 70084e741a..dbf869fae1 100644 --- a/addon/aocl_gemm/aocl_gemm_post_ops.h +++ b/addon/aocl_gemm/aocl_gemm_post_ops.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s16os16.c b/addon/aocl_gemm/aocl_gemm_s8s8s16os16.c index ca5ee12fc2..e9533536ab 100644 --- a/addon/aocl_gemm/aocl_gemm_s8s8s16os16.c +++ b/addon/aocl_gemm/aocl_gemm_s8s8s16os16.c @@ -34,6 +34,7 @@ #include "blis.h" #include "aocl_gemm_interface_apis.h" +#include "aocl_gemm_check.h" #include "lpgemm_types.h" #include "lpgemm_5loop_interface_apis.h" #include "lpgemm_config.h" @@ -60,11 +61,16 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) // Set MC, NC, KC, NR, MR. aocl_lpgemm_init_global_cntx(); - // Null check for pointers. - if ((a == NULL) || (b == NULL) || (c == NULL)) - { - return; // Error. - } + // check for validity of params. + AOCL_GEMM_CHECK + ( + "s8s8s16os16", + order, transa, transb, + m, n, k, + a, lda, mem_format_a, + b, ldb, mem_format_b, + c, ldc + ); /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ bli_param_map_netlib_to_blis_trans(transa, &blis_transa); @@ -75,31 +81,16 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || ( blis_transb != BLIS_NO_TRANSPOSE ) ) { + bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); return; // Error. } - // Sanitize order input. - char order_use = - ( ( order == 'r' ) || ( order == 'R' ) || - ( order == 'c' ) || ( order == 'C' ) ) ? - order : 'r'; - if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + if ( ( order != 'r' ) && ( order != 'R' ) ) { + bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ ); return; // Only row major supported. } - // Row major input expected with leading dimensions equal to row stride. - if ((lda != k) || (ldb != n) || (ldc != n)) - { - return; // Error. - } - - // Check if dimensions are valid. - if ((m <= 0) || (n <= 0) || (k <= 0) || (lda <= 0) || (ldb <= 0) || (ldc <= 0)) - { - return; // Error. - } - const inc_t rs_a = lda; const inc_t cs_a = 1; const inc_t rs_b = ldb; @@ -125,22 +116,25 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) // Only unpacked A supported now. if (mtag_a != UNPACKED) { + bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); return; // Error. } // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list + err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order_use ) + ( void* )c, ( void* )( &order ) ); + if( err != BLIS_SUCCESS ) return; + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; bli_rntm_init_from_global(&rntm_g); - bli_membrk_rntm_set_membrk(&rntm_g); + bli_pba_rntm_set_pba(&rntm_g); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S16OS16 ); @@ -153,7 +147,7 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, S16 ); #else lpgemm_s8s8s16o16_thread_decorator @@ -164,7 +158,7 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, S16 ); #endif } diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s16os16_utils.c b/addon/aocl_gemm/aocl_gemm_s8s8s16os16_utils.c index 92a2663944..2d02416c6c 100644 --- a/addon/aocl_gemm/aocl_gemm_s8s8s16os16_utils.c +++ b/addon/aocl_gemm/aocl_gemm_s8s8s16os16_utils.c @@ -118,7 +118,7 @@ AOCL_GEMM_REORDER(int8_t,s8s8s16os16) // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; bli_rntm_init_from_global(&rntm_g); - bli_membrk_rntm_set_membrk(&rntm_g); + bli_pba_rntm_set_pba(&rntm_g); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S16OS16 ); diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s16os8.c b/addon/aocl_gemm/aocl_gemm_s8s8s16os8.c index a036612c82..8b30c51801 100644 --- a/addon/aocl_gemm/aocl_gemm_s8s8s16os8.c +++ b/addon/aocl_gemm/aocl_gemm_s8s8s16os8.c @@ -34,6 +34,7 @@ #include "blis.h" #include "aocl_gemm_interface_apis.h" +#include "aocl_gemm_check.h" #include "lpgemm_types.h" #include "lpgemm_5loop_interface_apis.h" #include "lpgemm_config.h" @@ -60,11 +61,16 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) // Set MC, NC, KC, NR, MR. aocl_lpgemm_init_global_cntx(); - // Null check for pointers. 
- if ((a == NULL) || (b == NULL) || (c == NULL)) - { - return; // Error. - } + // check for validity of params. + AOCL_GEMM_CHECK + ( + "s8s8s16os8", + order, transa, transb, + m, n, k, + a, lda, mem_format_a, + b, ldb, mem_format_b, + c, ldc + ); /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans(transa, &blis_transa); @@ -75,31 +81,16 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || ( blis_transb != BLIS_NO_TRANSPOSE ) ) { + bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); return; // Error. } - // Sanitize order input. - char order_use = - ( ( order == 'r' ) || ( order == 'R' ) || - ( order == 'c' ) || ( order == 'C' ) ) ? - order : 'r'; - if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + if ( ( order != 'r' ) && ( order != 'R' ) ) { + bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ ); return; // Only row major supported. } - // Row major input expected with leading dimensions equal to row stride. - if ((lda != k) || (ldb != n) || (ldc != n)) - { - return; // Error. - } - - // Check if dimensions are valid. - if ((m <= 0) || (n <= 0) || (k <= 0) || (lda <= 0) || (ldb <= 0) || (ldc <= 0)) - { - return; // Error. - } - const inc_t rs_a = lda; const inc_t cs_a = 1; const inc_t rs_b = ldb; @@ -125,22 +116,25 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) // Only unpacked A supported now. if (mtag_a != UNPACKED) { + bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); return; // Error. } // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list + err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order_use ) + ( void* )c, ( void* )( &order ) ); + if( err != BLIS_SUCCESS ) return; + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; bli_rntm_init_from_global(&rntm_g); - bli_membrk_rntm_set_membrk(&rntm_g); + bli_pba_rntm_set_pba(&rntm_g); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S16OS16 ); @@ -153,7 +147,7 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) ( int16_t* )c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, TRUE + post_op_list, S8 ); #else lpgemm_s8s8s16o16_thread_decorator @@ -164,7 +158,7 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) ( int16_t* )c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, TRUE + post_op_list, S8 ); #endif } diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s32os32.c b/addon/aocl_gemm/aocl_gemm_s8s8s32os32.c index b9ddecdba5..413de3f543 100644 --- a/addon/aocl_gemm/aocl_gemm_s8s8s32os32.c +++ b/addon/aocl_gemm/aocl_gemm_s8s8s32os32.c @@ -34,6 +34,7 @@ #include "blis.h" #include "aocl_gemm_interface_apis.h" +#include "aocl_gemm_check.h" #include "lpgemm_types.h" #include "lpgemm_post_ops.h" #include "lpgemm_thread_decor_openmp.h" @@ -60,11 +61,16 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) // Set MC, NC, KC, NR, MR. aocl_lpgemm_init_global_cntx(); - // Null check for pointers. - if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) - { - return; // Error. - } + // check for validity of params. 
+ AOCL_GEMM_CHECK + ( + "s8s8s32os32", + order, transa, transb, + m, n, k, + a, lda, mem_format_a, + b, ldb, mem_format_b, + c, ldc + ); /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); @@ -75,32 +81,16 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || ( blis_transb != BLIS_NO_TRANSPOSE ) ) { + bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); return; // Error. } - // Sanitize order input. - char order_use = - ( ( order == 'r' ) || ( order == 'R' ) || - ( order == 'c' ) || ( order == 'C' ) ) ? - order : 'r'; - if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + if ( ( order != 'r' ) && ( order != 'R' ) ) { + bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ ); return; // Only row major supported. } - // Row major input expected with leading dimensions equal to row stride. - if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) - { - return; // Error. - } - - // Check if dimensions are valid. - if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || - ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) - { - return; // Error. - } - const inc_t rs_a = lda; const inc_t cs_a = 1; const inc_t rs_b = ldb; @@ -126,22 +116,25 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) // Only unpacked A supported now. if ( mtag_a != UNPACKED ) { + bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); return; // Error. } // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list + err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order_use ) + ( void* )c, ( void* )( &order ) ); + if( err != BLIS_SUCCESS ) return; + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; bli_rntm_init_from_global( &rntm_g ); - bli_membrk_rntm_set_membrk( &rntm_g ); + bli_pba_rntm_set_pba( &rntm_g ); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 ); @@ -154,7 +147,7 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, S32 ); #else lpgemm_s8s8s32o32_thread_decorator @@ -165,7 +158,7 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, S32 ); #endif } diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s32os32_utils.c b/addon/aocl_gemm/aocl_gemm_s8s8s32os32_utils.c index 4c41d8e184..ef4484aee5 100644 --- a/addon/aocl_gemm/aocl_gemm_s8s8s32os32_utils.c +++ b/addon/aocl_gemm/aocl_gemm_s8s8s32os32_utils.c @@ -118,7 +118,7 @@ AOCL_GEMM_REORDER(int8_t,s8s8s32os32) // that in the case that a runtime is passed in, we make a local copy. 
rntm_t rntm_g; bli_rntm_init_from_global( &rntm_g ); - bli_membrk_rntm_set_membrk( &rntm_g ); + bli_pba_rntm_set_pba( &rntm_g ); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 ); diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s32os8.c b/addon/aocl_gemm/aocl_gemm_s8s8s32os8.c index 7abc392a4e..5e7f3ec71c 100644 --- a/addon/aocl_gemm/aocl_gemm_s8s8s32os8.c +++ b/addon/aocl_gemm/aocl_gemm_s8s8s32os8.c @@ -34,6 +34,7 @@ #include "blis.h" #include "aocl_gemm_interface_apis.h" +#include "aocl_gemm_check.h" #include "lpgemm_types.h" #include "lpgemm_post_ops.h" #include "lpgemm_thread_decor_openmp.h" @@ -60,11 +61,16 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8) // Set MC, NC, KC, NR, MR. aocl_lpgemm_init_global_cntx(); - // Null check for pointers. - if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) - { - return; // Error. - } + // check for validity of params. + AOCL_GEMM_CHECK + ( + "s8s8s32os8", + order, transa, transb, + m, n, k, + a, lda, mem_format_a, + b, ldb, mem_format_b, + c, ldc + ); /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); @@ -75,32 +81,16 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8) if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || ( blis_transb != BLIS_NO_TRANSPOSE ) ) { + bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); return; // Error. } - // Sanitize order input. - char order_use = - ( ( order == 'r' ) || ( order == 'R' ) || - ( order == 'c' ) || ( order == 'C' ) ) ? - order : 'r'; - if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + if ( ( order != 'r' ) && ( order != 'R' ) ) { + bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ ); return; // Only row major supported. } - // Row major input expected with leading dimensions equal to row stride. - if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) - { - return; // Error. - } - - // Check if dimensions are valid. - if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || - ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) - { - return; // Error. - } - const inc_t rs_a = lda; const inc_t cs_a = 1; const inc_t rs_b = ldb; @@ -126,22 +116,25 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8) // Only unpacked A supported now. if ( mtag_a != UNPACKED ) { + bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); return; // Error. } // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list + err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order_use ) + ( void* )c, ( void* )( &order ) ); + if( err != BLIS_SUCCESS ) return; + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. 
rntm_t rntm_g; bli_rntm_init_from_global( &rntm_g ); - bli_membrk_rntm_set_membrk( &rntm_g ); + bli_pba_rntm_set_pba( &rntm_g ); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 ); @@ -154,7 +147,7 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8) ( int32_t* )c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, TRUE + post_op_list, S8 ); #else lpgemm_s8s8s32o32_thread_decorator @@ -165,7 +158,7 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8) ( int32_t* )c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, TRUE + post_op_list, S8 ); #endif } diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c index f851a283d5..c0614c643b 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,6 +34,7 @@ #include "blis.h" #include "aocl_gemm_interface_apis.h" +#include "aocl_gemm_check.h" #include "lpgemm_types.h" #include "lpgemm_5loop_interface_apis.h" #include "lpgemm_config.h" @@ -60,11 +61,16 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) // Set MC, NC, KC, NR, MR. aocl_lpgemm_init_global_cntx(); - // Null check for pointers. - if ((a == NULL) || (b == NULL) || (c == NULL)) - { - return; // Error. - } + // check for validity of params. + AOCL_GEMM_CHECK + ( + "u8s8s16os16", + order, transa, transb, + m, n, k, + a, lda, mem_format_a, + b, ldb, mem_format_b, + c, ldc + ); /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans(transa, &blis_transa); @@ -75,31 +81,16 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || ( blis_transb != BLIS_NO_TRANSPOSE ) ) { + bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); return; // Error. } - // Sanitize order input. - char order_use = - ( ( order == 'r' ) || ( order == 'R' ) || - ( order == 'c' ) || ( order == 'C' ) ) ? - order : 'r'; - if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + if ( ( order != 'r' ) && ( order != 'R' ) ) { + bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ ); return; // Only row major supported. } - // Row major input expected with leading dimensions equal to row stride. - if ((lda != k) || (ldb != n) || (ldc != n)) - { - return; // Error. - } - - // Check if dimensions are valid. - if ((m <= 0) || (n <= 0) || (k <= 0) || (lda <= 0) || (ldb <= 0) || (ldc <= 0)) - { - return; // Error. - } - const inc_t rs_a = lda; const inc_t cs_a = 1; const inc_t rs_b = ldb; @@ -125,22 +116,25 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) // Only unpacked A supported now. if (mtag_a != UNPACKED) { + bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); return; // Error. } // Convert post op struct to post op linked list format. 
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list + err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order_use ) + ( void* )c, ( void* )( &order ) ); + if( err != BLIS_SUCCESS ) return; + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; bli_rntm_init_from_global(&rntm_g); - bli_membrk_rntm_set_membrk(&rntm_g); + bli_pba_rntm_set_pba(&rntm_g); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S16OS16 ); @@ -153,7 +147,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, S16 ); #else lpgemm_u8s8s16o16_thread_decorator @@ -164,7 +158,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, S16 ); #endif } diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c index 98d8828f22..fd0c64203f 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -117,7 +117,7 @@ AOCL_GEMM_REORDER(int8_t,u8s8s16os16) // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; bli_rntm_init_from_global(&rntm_g); - bli_membrk_rntm_set_membrk(&rntm_g); + bli_pba_rntm_set_pba(&rntm_g); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S16OS16 ); diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c index c4ca0ac572..e8d7b9d146 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,6 +34,7 @@ #include "blis.h" #include "aocl_gemm_interface_apis.h" +#include "aocl_gemm_check.h" #include "lpgemm_types.h" #include "lpgemm_5loop_interface_apis.h" #include "lpgemm_config.h" @@ -60,11 +61,16 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) // Set MC, NC, KC, NR, MR. aocl_lpgemm_init_global_cntx(); - // Null check for pointers. - if ((a == NULL) || (b == NULL) || (c == NULL)) - { - return; // Error. - } + // check for validity of params. + AOCL_GEMM_CHECK + ( + "u8s8s16os8", + order, transa, transb, + m, n, k, + a, lda, mem_format_a, + b, ldb, mem_format_b, + c, ldc + ); /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ bli_param_map_netlib_to_blis_trans(transa, &blis_transa); @@ -75,31 +81,16 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || ( blis_transb != BLIS_NO_TRANSPOSE ) ) { + bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); return; // Error. } - // Sanitize order input. - char order_use = - ( ( order == 'r' ) || ( order == 'R' ) || - ( order == 'c' ) || ( order == 'C' ) ) ? - order : 'r'; - if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + if ( ( order != 'r' ) && ( order != 'R' ) ) { + bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ ); return; // Only row major supported. } - // Row major input expected with leading dimensions equal to row stride. - if ((lda != k) || (ldb != n) || (ldc != n)) - { - return; // Error. - } - - // Check if dimensions are valid. - if ((m <= 0) || (n <= 0) || (k <= 0) || (lda <= 0) || (ldb <= 0) || (ldc <= 0)) - { - return; // Error. - } - const inc_t rs_a = lda; const inc_t cs_a = 1; const inc_t rs_b = ldb; @@ -125,22 +116,25 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) // Only unpacked A supported now. if (mtag_a != UNPACKED) { + bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); return; // Error. } // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list + err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order_use ) + ( void* )c, ( void* )( &order ) ); + if( err != BLIS_SUCCESS ) return; + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; bli_rntm_init_from_global(&rntm_g); - bli_membrk_rntm_set_membrk(&rntm_g); + bli_pba_rntm_set_pba(&rntm_g); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S16OS16 ); @@ -153,7 +147,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) ( int16_t* )c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, TRUE + post_op_list, S8 ); #else lpgemm_u8s8s16o16_thread_decorator @@ -164,7 +158,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) ( int16_t* )c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, TRUE + post_op_list, S8 ); #endif } diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16ou8.c b/addon/aocl_gemm/aocl_gemm_u8s8s16ou8.c new file mode 100644 index 0000000000..fef861be1e --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16ou8.c @@ -0,0 +1,164 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "aocl_gemm_check.h" +#include "lpgemm_types.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_post_ops.h" + +AOCL_GEMM_MATMUL(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) + { + bli_print_msg(" AVX2 ISA not supported by processor, " + "cannot perform u8s8s16 gemm.", __FILE__, __LINE__ ); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // check for validity of params. + AOCL_GEMM_CHECK + ( + "u8s8s16ou8", + order, transa, transb, + m, n, k, + a, lda, mem_format_a, + b, ldb, mem_format_b, + c, ldc + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(transb, &blis_transb); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); + return; // Error. + } + + if ( ( order != 'r' ) && ( order != 'R' ) ) + { + bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ ); + return; // Only row major supported. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + const inc_t cs_c = 1; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); + bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in VNNI instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if (mtag_b == UNPACKED) + { + mtag_b = PACK; + } + + // Only unpacked A supported now. + if (mtag_a != UNPACKED) + { + bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); + return; // Error. + } + + // Convert post op struct to post op linked list format. 
+ lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + err_t err = lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order ) + ); + + if( err != BLIS_SUCCESS ) return; + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global(&rntm_g); + bli_pba_rntm_set_pba(&rntm_g); + + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S16OS16 ); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_u8s8s16o16_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( int16_t* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, U8 + ); +#else + lpgemm_u8s8s16o16_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( int16_t* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, U8 + ); +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c index 5580001d69..d89e6861c3 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,6 +34,7 @@ #include "blis.h" #include "aocl_gemm_interface_apis.h" +#include "aocl_gemm_check.h" #include "lpgemm_types.h" #include "lpgemm_post_ops.h" #include "lpgemm_thread_decor_openmp.h" @@ -60,11 +61,16 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) // Set MC, NC, KC, NR, MR. aocl_lpgemm_init_global_cntx(); - // Null check for pointers. - if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) - { - return; // Error. - } + // check for validity of params. + AOCL_GEMM_CHECK + ( + "u8s8s32os32", + order, transa, transb, + m, n, k, + a, lda, mem_format_a, + b, ldb, mem_format_b, + c, ldc + ); /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); @@ -75,32 +81,16 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || ( blis_transb != BLIS_NO_TRANSPOSE ) ) { + bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); return; // Error. } - // Sanitize order input. - char order_use = - ( ( order == 'r' ) || ( order == 'R' ) || - ( order == 'c' ) || ( order == 'C' ) ) ? - order : 'r'; - if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + if ( ( order != 'r' ) && ( order != 'R' ) ) { + bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ ); return; // Only row major supported. } - // Row major input expected with leading dimensions equal to row stride. - if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) - { - return; // Error. - } - - // Check if dimensions are valid. - if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || - ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) - { - return; // Error. - } - const inc_t rs_a = lda; const inc_t cs_a = 1; const inc_t rs_b = ldb; @@ -126,22 +116,25 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) // Only unpacked A supported now. 
if ( mtag_a != UNPACKED ) { + bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); return; // Error. } // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list + err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order_use ) + ( void* )c, ( void* )( &order ) ); + if( err != BLIS_SUCCESS ) return; + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; bli_rntm_init_from_global( &rntm_g ); - bli_membrk_rntm_set_membrk( &rntm_g ); + bli_pba_rntm_set_pba( &rntm_g ); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 ); @@ -154,7 +147,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, S32 ); #else lpgemm_u8s8s32o32_thread_decorator @@ -165,7 +158,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, FALSE + post_op_list, S32 ); #endif } diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c index 20f0b322d9..b62c294cc6 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -117,7 +117,7 @@ AOCL_GEMM_REORDER(int8_t,u8s8s32os32) // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; bli_rntm_init_from_global( &rntm_g ); - bli_membrk_rntm_set_membrk( &rntm_g ); + bli_pba_rntm_set_pba( &rntm_g ); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 ); diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c index 55f062ee8f..6dab94b1fc 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,6 +34,7 @@ #include "blis.h" #include "aocl_gemm_interface_apis.h" +#include "aocl_gemm_check.h" #include "lpgemm_types.h" #include "lpgemm_post_ops.h" #include "lpgemm_thread_decor_openmp.h" @@ -60,11 +61,16 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) // Set MC, NC, KC, NR, MR. aocl_lpgemm_init_global_cntx(); - // Null check for pointers. - if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) - { - return; // Error. - } + // check for validity of params. + AOCL_GEMM_CHECK + ( + "u8s8s32os8", + order, transa, transb, + m, n, k, + a, lda, mem_format_a, + b, ldb, mem_format_b, + c, ldc + ); /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); @@ -75,32 +81,16 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || ( blis_transb != BLIS_NO_TRANSPOSE ) ) { + bli_print_msg(" Transpose of matrices is not supported.", __FILE__, __LINE__ ); return; // Error. } - // Sanitize order input. - char order_use = - ( ( order == 'r' ) || ( order == 'R' ) || - ( order == 'c' ) || ( order == 'C' ) ) ? - order : 'r'; - if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + if ( ( order != 'r' ) && ( order != 'R' ) ) { + bli_print_msg(" Operation only supports row-major matrices.", __FILE__, __LINE__ ); return; // Only row major supported. } - // Row major input expected with leading dimensions equal to row stride. - if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) - { - return; // Error. - } - - // Check if dimensions are valid. - if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || - ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) - { - return; // Error. - } - const inc_t rs_a = lda; const inc_t cs_a = 1; const inc_t rs_b = ldb; @@ -126,22 +116,25 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) // Only unpacked A supported now. if ( mtag_a != UNPACKED ) { + bli_print_msg(" A matrix needs to be unpacked.", __FILE__, __LINE__ ); return; // Error. } // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list + err_t err = lpgemm_translate_to_post_ops_list ( post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order_use ) + ( void* )c, ( void* )( &order ) ); + if( err != BLIS_SUCCESS ) return; + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; bli_rntm_init_from_global( &rntm_g ); - bli_membrk_rntm_set_membrk( &rntm_g ); + bli_pba_rntm_set_pba( &rntm_g ); lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 ); @@ -154,7 +147,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) ( int32_t* )c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, TRUE + post_op_list, S8 ); #else lpgemm_u8s8s32o32_thread_decorator @@ -165,7 +158,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) ( int32_t* )c, rs_c, cs_c, alpha, beta, &rntm_g, lcntx_g, - post_op_list, TRUE + post_op_list, S8 ); #endif } diff --git a/addon/aocl_gemm/config/lpgemm_config.c b/addon/aocl_gemm/config/lpgemm_config.c index 0dad8c88a7..ca1020e324 100644 --- a/addon/aocl_gemm/config/lpgemm_config.c +++ b/addon/aocl_gemm/config/lpgemm_config.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -37,7 +37,7 @@ #include "lpgemm_func_map.h" #include "lpgemm_blksz_map.h" #include "lpgemm_kernels.h" -#include "lpgemm_packb_bf16.h" +#include "lpgemm_pack_bf16.h" #include "lpgemm_packb_s16.h" #include "lpgemm_packa.h" #include "lpgemm_packb.h" diff --git a/addon/aocl_gemm/config/lpgemm_config.h b/addon/aocl_gemm/config/lpgemm_config.h index 91863e416a..87020d0c3d 100644 --- a/addon/aocl_gemm/config/lpgemm_config.h +++ b/addon/aocl_gemm/config/lpgemm_config.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/addon/aocl_gemm/config/lpgemm_func_map.h b/addon/aocl_gemm/config/lpgemm_func_map.h index 864f84aef2..875a211985 100644 --- a/addon/aocl_gemm/config/lpgemm_func_map.h +++ b/addon/aocl_gemm/config/lpgemm_func_map.h @@ -56,7 +56,7 @@ #define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI_BF16 \ PAMACRO(U8S8S16OS16, NULL) \ PAMACRO(U8S8S32OS32, packa_k64_u8s8s32o32) \ - PAMACRO(BF16BF16F32OF32, NULL) \ + PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \ PAMACRO(S8S8S32OS32, packa_k64_s8s8s32os32) \ PAMACRO(S8S8S16OS16, NULL) \ @@ -84,7 +84,7 @@ #define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI \ PAMACRO(U8S8S16OS16, NULL) \ PAMACRO(U8S8S32OS32, packa_k64_u8s8s32o32) \ - PAMACRO(BF16BF16F32OF32, NULL) \ + PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \ PAMACRO(S8S8S32OS32, packa_k64_s8s8s32os32) \ PAMACRO(S8S8S16OS16, NULL) \ @@ -112,7 +112,7 @@ #define LPGEMM_PACKA_FUNC_MAP_AVX512 \ PAMACRO(U8S8S16OS16, NULL) \ PAMACRO(U8S8S32OS32, packa_k64_u8s8s32o32) \ - PAMACRO(BF16BF16F32OF32, NULL) \ + PAMACRO(BF16BF16F32OF32, packa_mr16_bf16bf16f32of32) \ PAMACRO(S8S8S32OS32, packa_k64_s8s8s32os32) \ PAMACRO(S8S8S16OS16, NULL) \ diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c index 1ece1db727..5a0201443b 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,12 +34,14 @@ #include "blis.h" #include "lpgemm_5loop_interface_apis.h" -#include "lpgemm_packb_bf16.h" +#include "lpgemm_pack_bf16.h" #include "lpgemm_kernels.h" #include "lpgemm_utils.h" #include "lpgemm_thrinfo_utils.h" #include "lpgemm_config.h" + + // Kernel function prototypes typedef void (*lpgemm_rowvar_bf16) ( @@ -73,6 +75,7 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) const int16_t* a_use = NULL; dim_t cs_a_use = cs_a; + dim_t rs_a_use = rs_a; dim_t a_block_stride = 0; const int16_t* b_use = NULL; @@ -86,8 +89,11 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) // Pack buffer for B. 
bfloat16* pack_b_buffer_bf16; + bfloat16* pack_a_buffer_bf16; mem_t mem_b = BLIS_MEM_INITIALIZER; + mem_t mem_a = BLIS_MEM_INITIALIZER; siz_t mem_b_size_req = 0; + siz_t mem_a_size_req = 0; dim_t packb_min_NR = 16; // Temporary buffer for C accumulation when downscaling is required. @@ -109,7 +115,8 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) bool is_first_k = FALSE; lpgemm_post_op_attr post_ops_attr; - if ( c_downscale == TRUE ) + post_ops_attr.c_stor_type = c_downscale; + if ( c_downscale < F32 ) { post_ops_attr.buf_downscale = c; } @@ -149,12 +156,12 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) ); } - if ( c_downscale == FALSE ) + if ( c_downscale == F32 ) { c_use_jc = c + jc; } // Temp accumulaton buffer for C allocation. - else if ( c_downscale == TRUE ) + else if ( c_downscale < F32 ) { // Buffer memory is only required if output needs to be // persisted across iterations of the pc/KC loop. @@ -167,7 +174,7 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) lpgemm_alloc_mem_panel ( - mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + mem_scale_c_size_req, BLIS_BUFFER_FOR_GEN_USE, &mem_scale_c, rntm ); @@ -254,11 +261,11 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) if ( ( jc_packb_end > jc_packb_start ) && ( jc_packb_start < ( jc + nc0 ) ) ) { - ( ( packb_bf16 )lcntx->packb_fun_ptr ) + ( ( pack_bf16 )lcntx->packb_fun_ptr ) ( pack_b_buffer_bf16 + ( jc_packb_start * kc0_updated ), ( b + ( rs_b * pc ) + ( cs_b * jc ) + - ( cs_b * jc_packb_start ) ), rs_b, + ( cs_b * jc_packb_start ) ), rs_b, cs_b, ( jc_packb_end - jc_packb_start ), kc0, &rs_b_use, &cs_b_use ); @@ -297,7 +304,7 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) // Only per thread C matrix is stored in temp buffer, so both // per thread jc and ic start should be normalized to zero. - if ( c_downscale == TRUE ) + if ( c_downscale < F32 ) { c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) ); } @@ -315,6 +322,31 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) // Non bf16 based kernel requires update to this code. 
cs_a_use = 2; a_block_stride = rs_a; + rs_a_use = rs_a; + } + else if ( mtag_a == PACK ) + { + + mem_a_size_req = sizeof( bfloat16 ) * mc0 * kc0; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK, + &mem_a, rntm + ); + + pack_a_buffer_bf16 = + ( bfloat16* ) bli_mem_buffer( &mem_a ); + + ( ( pack_bf16 )lcntx->packa_fun_ptr ) + ( + pack_a_buffer_bf16, + ( a + ( rs_a * ic ) + ( cs_a * pc )), rs_a, cs_a, + mc0, kc0, + &rs_a_use, &cs_a_use + ); + a_use = pack_a_buffer_bf16; + a_block_stride = rs_a_use; } for ( dim_t jr = 0; jr < nc0; jr += NR ) @@ -330,7 +362,7 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) ( ( lpgemm_rowvar_bf16 )lcntx->kern_fun_ptr ) ( mc0, nr0, kc0, - a_use, rs_a, cs_a_use, a_block_stride, + a_use, rs_a_use, cs_a_use, a_block_stride, ( b_use + ( jr * kc0_updated ) ), rs_b_use, cs_b_use, ( c_use_ic + jr ), rs_c_use, 1, alpha, beta0, @@ -360,15 +392,22 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) { if ( bli_mem_is_alloc( &mem_b ) ) { - bli_membrk_release( rntm, &mem_b ); + bli_pba_release( rntm, &mem_b ); } } } - if ( c_downscale == TRUE ) + if( mtag_a == PACK ) + { + if ( bli_mem_is_alloc( &mem_a ) ) + { + bli_pba_release(rntm, &mem_a); + } + } + if ( c_downscale < F32 ) { if ( bli_mem_is_alloc( &mem_scale_c ) ) { - bli_membrk_release( rntm, &mem_scale_c ); + bli_pba_release( rntm, &mem_scale_c ); } } } diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c index b90d339664..99c17b909f 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,7 +35,7 @@ #include "blis.h" #include "lpgemm_utils.h" #include "lpgemm_reorder_bf16.h" -#include "lpgemm_packb_bf16.h" +#include "lpgemm_pack_bf16.h" #include "lpgemm_config.h" #include "aocl_bf16_type.h" @@ -53,6 +53,7 @@ void reorderb_nr64_bf16bf16f32of32 // Extracting the matrix properties from the lpgemm object dim_t rs_b = b->rs; + dim_t cs_b = b->cs; dim_t n = b->width; dim_t k = b->length; @@ -148,14 +149,14 @@ void reorderb_nr64_bf16bf16f32of32 // st = ( jc_cur_loop * k ) // + ( n_sub_updated * pc ) // + ( NC' * kc0_updated) - ( ( packb_bf16 )lcntx->packb_fun_ptr ) + ( ( pack_bf16 )lcntx->packb_fun_ptr ) ( - ( ( ( bfloat16* )b_reorder->storage.aligned_buffer ) + - ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + - ( jc_cur_loop_rem * kc0_updated ) ), + ( ( bfloat16* )b_reorder->storage.aligned_buffer ) + + ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ), ( ( ( bfloat16* )b->storage.aligned_buffer ) + - ( rs_b * pc ) + jc ), - rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder + ( rs_b * pc ) + (jc * cs_b)), + rs_b, cs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder ); } diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h index 42c8cb9ef6..d9fddedb6e 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c index 1864d78330..61e8cf8654 100644 --- a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c +++ b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c @@ -150,7 +150,8 @@ LPGEMM_5LOOP(float,float,float,f32f32f32of32) bool is_first_k = FALSE; lpgemm_post_op_attr post_ops_attr; - if ( c_downscale == TRUE ) + post_ops_attr.c_stor_type = c_downscale; + if ( c_downscale < F32 ) { post_ops_attr.buf_downscale = c; } @@ -395,7 +396,7 @@ LPGEMM_5LOOP(float,float,float,f32f32f32of32) { if ( bli_mem_is_alloc( &mem_b ) ) { - bli_membrk_release( rntm, &mem_b ); + bli_pba_release( rntm, &mem_b ); } } } @@ -403,7 +404,7 @@ LPGEMM_5LOOP(float,float,float,f32f32f32of32) { if ( bli_mem_is_alloc( &mem_a ) ) { - bli_membrk_release( rntm, &mem_a ); + bli_pba_release( rntm, &mem_a ); } } } diff --git a/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h index 62fc678faa..a0920edaf3 100644 --- a/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h +++ b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -62,7 +62,7 @@ void lpgemm_rowvar_ ## LP_SFX \ lpgemm_thrinfo_t* thread, \ lpgemm_cntx_t* lcntx, \ lpgemm_post_op* post_op_list, \ - bool c_downscale \ + AOCL_STORAGE_TYPE c_downscale \ ) \ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32); diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.c b/addon/aocl_gemm/frame/lpgemm_post_ops.c index fffe14c0f8..92f5849c20 100644 --- a/addon/aocl_gemm/frame/lpgemm_post_ops.c +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -55,7 +55,7 @@ BLIS_INLINE void lpgemm_set_node_params post_op_node->next = NULL; } -void lpgemm_translate_to_post_ops_list +err_t lpgemm_translate_to_post_ops_list ( aocl_post_op* post_op_unparsed, lpgemm_post_op* post_op_list, @@ -70,7 +70,7 @@ void lpgemm_translate_to_post_ops_list post_op_list, POST_OPS_DISABLE, NULL, NULL, NULL, NULL, FALSE ); - return; + return BLIS_SUCCESS; } if ( ( post_op_unparsed->seq_length > AOCL_MAX_POST_OPS ) ) @@ -80,7 +80,7 @@ void lpgemm_translate_to_post_ops_list post_op_list, POST_OPS_DISABLE, NULL, NULL, NULL, NULL, FALSE ); - return; //Error, seq length exceeds max post ops permitted. + return BLIS_SUCCESS; //Error, seq length exceeds max post ops permitted. } dim_t e_i = 0; //Multiple eltwise supported. @@ -110,6 +110,11 @@ void lpgemm_translate_to_post_ops_list tmp_code = POST_OPS_RELU; break; case PRELU: + if( ( post_op_unparsed->eltwise + e_i )->algo.alpha == NULL ) + { + bli_print_msg(" Post_op.alpha is NULL. Exiting..", __FILE__, __LINE__ ); + return BLIS_NULL_POINTER; + } tmp_code = POST_OPS_RELU_SCALE; break; case GELU_TANH: @@ -119,6 +124,12 @@ void lpgemm_translate_to_post_ops_list tmp_code = POST_OPS_GELU_ERF; break; case CLIP: + if( ( ( post_op_unparsed->eltwise + e_i )->algo.alpha == NULL ) || + ( ( post_op_unparsed->eltwise + e_i )->algo.beta == NULL ) ) + { + bli_print_msg(" Post_op.clip min or max value is NULL. Exiting..", __FILE__, __LINE__ ); + return BLIS_NULL_POINTER; + } tmp_code = POST_OPS_CLIP; break; default: @@ -137,6 +148,11 @@ void lpgemm_translate_to_post_ops_list } break; case BIAS: + if( post_op_unparsed->bias.bias == NULL ) + { + bli_print_msg(" Post_op.bias is NULL. Exiting..", __FILE__, __LINE__ ); + return BLIS_NULL_POINTER; + } lpgemm_set_node_params ( ( post_op_list + i ), POST_OPS_BIAS, @@ -145,6 +161,12 @@ void lpgemm_translate_to_post_ops_list ); break; case SCALE: + if( ( post_op_unparsed->sum.scale_factor == NULL ) || + ( post_op_unparsed->sum.zero_point == NULL ) ) + { + bli_print_msg(" Post_op.scale scale_factor or zero_point is NULL. 
Exiting..", __FILE__, __LINE__ ); + return BLIS_NULL_POINTER; + } lpgemm_set_node_params ( ( post_op_list + i ), POST_OPS_DOWNSCALE, @@ -163,4 +185,5 @@ void lpgemm_translate_to_post_ops_list ( post_op_list + i )->next = ( post_op_list + i + 1); } } + return BLIS_SUCCESS; } diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.h b/addon/aocl_gemm/frame/lpgemm_post_ops.h index 7509e57a39..ed1d3ed86b 100644 --- a/addon/aocl_gemm/frame/lpgemm_post_ops.h +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -70,12 +70,13 @@ typedef struct lpgemm_post_op_attr_t void* buf_downscale; bool is_first_k; bool is_last_k; + AOCL_STORAGE_TYPE c_stor_type; dim_t b_sum_offset; int32_t* b_col_sum_vec; int16_t* b_col_sum_vec_s16; } lpgemm_post_op_attr; -void lpgemm_translate_to_post_ops_list +err_t lpgemm_translate_to_post_ops_list ( aocl_post_op* post_op_unparsed, lpgemm_post_op* post_op_list, diff --git a/addon/aocl_gemm/frame/lpgemm_types.h b/addon/aocl_gemm/frame/lpgemm_types.h index b700c03878..28f210a067 100644 --- a/addon/aocl_gemm/frame/lpgemm_types.h +++ b/addon/aocl_gemm/frame/lpgemm_types.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,6 +42,24 @@ typedef enum INT32 = 2 } AOCL_ARRAY_TYPE; +// Enum to denote the storage data type (output matrix). +// It is expected that the enum entries are in ascending order of +// storage data type size. +typedef enum +{ + S8 = 0, + U8 = 1, + S16 = 2, + U16 = 3, + BF16 = 4, + S32 = 5, + U32 = 6, + F32 = 7, + S64 = 8, + U64 = 9, + F64 = 10 +} AOCL_STORAGE_TYPE; + // Enum name template:A_mat_type ## B_mat_type ## Accumulate_type ## C_mat_type. typedef enum { diff --git a/addon/aocl_gemm/frame/s8s8s16/lpgemm_s8s8s16.c b/addon/aocl_gemm/frame/s8s8s16/lpgemm_s8s8s16.c index 86ee194eb5..974ff4f3eb 100644 --- a/addon/aocl_gemm/frame/s8s8s16/lpgemm_s8s8s16.c +++ b/addon/aocl_gemm/frame/s8s8s16/lpgemm_s8s8s16.c @@ -116,7 +116,8 @@ LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16) bool is_first_k = FALSE; lpgemm_post_op_attr post_ops_attr; - if ( c_downscale == TRUE ) + post_ops_attr.c_stor_type = c_downscale; + if ( c_downscale < S16 ) { post_ops_attr.buf_downscale = c; } @@ -156,12 +157,12 @@ LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16) ); } - if ( c_downscale == FALSE ) + if ( c_downscale == S16 ) { c_use_jc = c + jc; } // Temp accumulaton buffer for C allocation. - else if ( c_downscale == TRUE ) + else if ( c_downscale < S16 ) { // Buffer memory is only required if output needs to be // persisted across iterations of the pc/KC loop. 
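
Across these hunks the boolean c_downscale flag is replaced by the AOCL_STORAGE_TYPE enum introduced in lpgemm_types.h above, whose entries are deliberately ordered by ascending storage size. A comparison such as `c_downscale < S16` (or `< S32`, `< F32` in the other data paths) therefore reads as "the requested output type is narrower than the accumulation type, so the temporary C buffer and downscale step are needed", while equality means the output is stored directly. A small self-contained sketch of that idiom follows; the enum values are copied from the patch, the helper and the sample calls are illustrative only.

#include <stdio.h>
#include <stdbool.h>

/* Copied from the new AOCL_STORAGE_TYPE enum: entries are ordered by
   ascending storage size so that '<' compares type widths. */
typedef enum
{
    S8 = 0, U8 = 1, S16 = 2, U16 = 3, BF16 = 4,
    S32 = 5, U32 = 6, F32 = 7, S64 = 8, U64 = 9, F64 = 10
} AOCL_STORAGE_TYPE;

/* Illustrative helper (not part of the patch): true when the output
   storage type is narrower than the kernel's accumulation type, i.e.
   when the temporary C accumulation buffer and downscale path are used. */
static bool needs_downscale( AOCL_STORAGE_TYPE c_stor, AOCL_STORAGE_TYPE acc )
{
    return c_stor < acc;
}

int main( void )
{
    /* u8s8s16ou8: accumulate in S16, store U8 -> downscale required. */
    printf( "u8s8s16ou8 downscale: %d\n", needs_downscale( U8, S16 ) );
    /* bf16bf16f32of32: accumulate in F32, store F32 -> no downscale. */
    printf( "bf16bf16f32of32 downscale: %d\n", needs_downscale( F32, F32 ) );
    return 0;
}
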
@@ -174,7 +175,7 @@ LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16) lpgemm_alloc_mem_panel ( - mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + mem_scale_c_size_req, BLIS_BUFFER_FOR_GEN_USE, &mem_scale_c, rntm ); @@ -329,7 +330,7 @@ LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16) // Only per thread C matrix is stored in temp buffer, so both // per thread jc and ic start should be normalized to zero. - if ( c_downscale == TRUE ) + if ( c_downscale < S16 ) { c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) ); } @@ -388,15 +389,15 @@ LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16) { if (bli_mem_is_alloc(&mem_b)) { - bli_membrk_release(rntm, &mem_b); + bli_pba_release(rntm, &mem_b); } } } - if ( c_downscale == TRUE ) + if ( c_downscale < S16 ) { if ( bli_mem_is_alloc( &mem_scale_c ) ) { - bli_membrk_release( rntm, &mem_scale_c ); + bli_pba_release( rntm, &mem_scale_c ); } } } diff --git a/addon/aocl_gemm/frame/s8s8s32/lpgemm_s8s8s32.c b/addon/aocl_gemm/frame/s8s8s32/lpgemm_s8s8s32.c index 98b8081b51..21fa102fd4 100644 --- a/addon/aocl_gemm/frame/s8s8s32/lpgemm_s8s8s32.c +++ b/addon/aocl_gemm/frame/s8s8s32/lpgemm_s8s8s32.c @@ -123,7 +123,8 @@ LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32) bool is_first_k = FALSE; lpgemm_post_op_attr post_ops_attr; - if ( c_downscale == TRUE ) + post_ops_attr.c_stor_type = c_downscale; + if ( c_downscale < S32 ) { post_ops_attr.buf_downscale = c; } @@ -163,12 +164,12 @@ LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32) ); } - if ( c_downscale == FALSE ) + if ( c_downscale == S32 ) { c_use_jc = c + jc; } // Temp accumulaton buffer for C allocation. - else if ( c_downscale == TRUE ) + else if ( c_downscale < S32 ) { // Buffer memory is only required if output needs to be // persisted across iterations of the pc/KC loop. @@ -181,7 +182,7 @@ LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32) lpgemm_alloc_mem_panel ( - mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + mem_scale_c_size_req, BLIS_BUFFER_FOR_GEN_USE, &mem_scale_c, rntm ); @@ -335,7 +336,7 @@ LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32) // Only per thread C matrix is stored in temp buffer, so both // per thread jc and ic start should be normalized to zero. 
- if ( c_downscale == TRUE ) + if ( c_downscale < S32 ) { c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) ); } @@ -426,7 +427,7 @@ LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32) { if ( bli_mem_is_alloc( &mem_b ) ) { - bli_membrk_release( rntm, &mem_b ); + bli_pba_release( rntm, &mem_b ); } } } @@ -434,14 +435,14 @@ LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32) { if ( bli_mem_is_alloc( &mem_a ) ) { - bli_membrk_release( rntm, &mem_a ); + bli_pba_release( rntm, &mem_a ); } } - if ( c_downscale == TRUE ) + if ( c_downscale < S32 ) { if ( bli_mem_is_alloc( &mem_scale_c ) ) { - bli_membrk_release( rntm, &mem_scale_c ); + bli_pba_release( rntm, &mem_scale_c ); } } } diff --git a/addon/aocl_gemm/frame/s8s8s32/lpgemm_utils_s8.h b/addon/aocl_gemm/frame/s8s8s32/lpgemm_utils_s8.h index e91d0f8816..474d07ff2f 100644 --- a/addon/aocl_gemm/frame/s8s8s32/lpgemm_utils_s8.h +++ b/addon/aocl_gemm/frame/s8s8s32/lpgemm_utils_s8.h @@ -123,7 +123,7 @@ BLIS_INLINE void lpgemm_alloc_mem_panel { if ( bli_mem_is_unalloc( mem ) ) { - bli_membrk_acquire_m + bli_pba_acquire_m ( rntm_l, size_req, @@ -136,8 +136,8 @@ BLIS_INLINE void lpgemm_alloc_mem_panel siz_t mem_size = bli_mem_size( mem ); if ( mem_size < size_req ) { - bli_membrk_release( rntm_l, mem ); - bli_membrk_acquire_m + bli_pba_release( rntm_l, mem ); + bli_pba_acquire_m ( rntm_l, size_req, diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c index 32615afc9e..ad0e7f10d5 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -268,7 +268,7 @@ BLIS_INLINE void lpgemm_adjust_ic_jc_ways } } -BLIS_INLINE void lpgemm_u8s8s16o16_get_threading +BLIS_INLINE void lpgemm_s16o16_get_threading ( dim_t* n_threads, dim_t* ic_ways, @@ -276,7 +276,8 @@ BLIS_INLINE void lpgemm_u8s8s16o16_get_threading dim_t m, dim_t n, dim_t k, - rntm_t* rntm_g + rntm_t* rntm_g, + AOCL_OPERATION_TYPE op_type ) { *n_threads = bli_rntm_num_threads( rntm_g ); @@ -295,19 +296,47 @@ BLIS_INLINE void lpgemm_u8s8s16o16_get_threading else if ( ( *n_threads ) > 1 ) { - dim_t NR = lpgemm_get_block_size_NR_global_cntx( U8S8S16OS16 ); + dim_t NR = lpgemm_get_block_size_NR_global_cntx( op_type ); + dim_t MR = lpgemm_get_block_size_MR_global_cntx( op_type ); + dim_t mr_blks = ( m + MR - 1 ) / MR; + dim_t nr_blks = ( n + NR - 1 ) / NR; if ( n <= NR ) { - // If n is less than micro panel dimension, allocating all threads - // to ic resulted in gains. - ( *ic_ways ) = ( *n_threads ); + ( *ic_ways ) = ( mr_blks < ( *n_threads ) ) ? mr_blks : ( *n_threads ); ( *jc_ways ) = 1; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( m <= MR ) + { + ( *jc_ways ) = ( nr_blks < ( *n_threads ) ) ? nr_blks : ( *n_threads ); + ( *ic_ways ) = 1; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); } else { // If BLIS_NUM_THREADS are set, generate jc,ic from the same. 
bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); + if ( ( mr_blks < ( *ic_ways ) ) && ( nr_blks < ( *jc_ways ) ) ) + { + ( *ic_ways ) = mr_blks; + ( *jc_ways ) = nr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( mr_blks < ( *ic_ways ) ) + { + ( *ic_ways ) = mr_blks; + dim_t rem_jc_ways = ( dim_t )( ( *n_threads ) / ( *ic_ways ) ); + ( *jc_ways ) = ( rem_jc_ways < nr_blks ) ? rem_jc_ways : nr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( nr_blks < ( *jc_ways ) ) + { + ( *jc_ways ) = nr_blks; + dim_t rem_ic_ways = ( dim_t )( ( *n_threads ) / ( *jc_ways ) ); + ( *ic_ways ) = ( rem_ic_ways < mr_blks ) ? rem_ic_ways : mr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } } } else @@ -320,7 +349,27 @@ BLIS_INLINE void lpgemm_u8s8s16o16_get_threading } } -BLIS_INLINE void lpgemm_u8s8s32o32_get_threading +BLIS_INLINE void lpgemm_u8s8s16o16_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm_g + ) +{ + lpgemm_s16o16_get_threading + ( + n_threads, + ic_ways, jc_ways, + m, n, k, rntm_g, + U8S8S16OS16 + ); +} + +BLIS_INLINE void lpgemm_s8s8s16o16_get_threading ( dim_t* n_threads, dim_t* ic_ways, @@ -330,6 +379,27 @@ BLIS_INLINE void lpgemm_u8s8s32o32_get_threading dim_t k, rntm_t* rntm_g ) +{ + lpgemm_s16o16_get_threading + ( + n_threads, + ic_ways, jc_ways, + m, n, k, rntm_g, + S8S8S16OS16 + ); +} + +BLIS_INLINE void lpgemm_s32o32_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm_g, + AOCL_OPERATION_TYPE op_type + ) { *n_threads = bli_rntm_num_threads( rntm_g ); *jc_ways = bli_rntm_jc_ways( rntm_g ); @@ -347,26 +417,55 @@ BLIS_INLINE void lpgemm_u8s8s32o32_get_threading else if ( ( *n_threads ) > 1 ) { - dim_t NR = lpgemm_get_block_size_NR_global_cntx( U8S8S32OS32 ); - dim_t MR = lpgemm_get_block_size_MR_global_cntx( U8S8S32OS32 ); + dim_t NR = lpgemm_get_block_size_NR_global_cntx( op_type ); + dim_t MR = lpgemm_get_block_size_MR_global_cntx( op_type ); + dim_t mr_blks = ( m + MR - 1 ) / MR; + dim_t nr_blks = ( n + NR - 1 ) / NR; if ( n <= NR ) { - // If n is less than micro panel dimension, allocating all threads - // to ic resulted in gains. - ( *ic_ways ) = ( *n_threads ); + ( *ic_ways ) = ( mr_blks < ( *n_threads ) ) ? mr_blks : ( *n_threads ); ( *jc_ways ) = 1; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( m <= MR ) + { + ( *jc_ways ) = ( nr_blks < ( *n_threads ) ) ? nr_blks : ( *n_threads ); + ( *ic_ways ) = 1; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); } else { // If BLIS_NUM_THREADS are set, generate jc,ic from the same. bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); - - lpgemm_pnl_wrk_heur_adjust_ic_jc_ways - ( - MR, NR, m, n, - n_threads, ic_ways, jc_ways - ); + if ( ( mr_blks < ( *ic_ways ) ) && ( nr_blks < ( *jc_ways ) ) ) + { + ( *ic_ways ) = mr_blks; + ( *jc_ways ) = nr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( mr_blks < ( *ic_ways ) ) + { + ( *ic_ways ) = mr_blks; + dim_t rem_jc_ways = ( dim_t )( ( *n_threads ) / ( *ic_ways ) ); + ( *jc_ways ) = ( rem_jc_ways < nr_blks ) ? rem_jc_ways : nr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( nr_blks < ( *jc_ways ) ) + { + ( *jc_ways ) = nr_blks; + dim_t rem_ic_ways = ( dim_t )( ( *n_threads ) / ( *jc_ways ) ); + ( *ic_ways ) = ( rem_ic_ways < mr_blks ) ? 
rem_ic_ways : mr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else + { + lpgemm_pnl_wrk_heur_adjust_ic_jc_ways + ( + MR, NR, m, n, + n_threads, ic_ways, jc_ways + ); + } } } else @@ -379,6 +478,46 @@ BLIS_INLINE void lpgemm_u8s8s32o32_get_threading } } +BLIS_INLINE void lpgemm_u8s8s32o32_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm_g + ) +{ + lpgemm_s32o32_get_threading + ( + n_threads, + ic_ways, jc_ways, + m, n, k, rntm_g, + U8S8S32OS32 + ); +} + +BLIS_INLINE void lpgemm_s8s8s32o32_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm_g + ) +{ + lpgemm_s32o32_get_threading + ( + n_threads, + ic_ways, jc_ways, + m, n, k, rntm_g, + S8S8S32OS32 + ); +} + BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading ( dim_t* n_threads, @@ -408,24 +547,53 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ); dim_t MR = lpgemm_get_block_size_MR_global_cntx( BF16BF16F32OF32 ); + dim_t mr_blks = ( m + MR - 1 ) / MR; + dim_t nr_blks = ( n + NR - 1 ) / NR; if ( n <= NR ) { - // If n is less than micro panel dimension, allocating all threads - // to ic resulted in gains. - ( *ic_ways ) = ( *n_threads ); + ( *ic_ways ) = ( mr_blks < ( *n_threads ) ) ? mr_blks : ( *n_threads ); ( *jc_ways ) = 1; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( m <= MR ) + { + ( *jc_ways ) = ( nr_blks < ( *n_threads ) ) ? nr_blks : ( *n_threads ); + ( *ic_ways ) = 1; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); } else { // If BLIS_NUM_THREADS are set, generate jc,ic from the same. bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); - - lpgemm_pnl_wrk_heur_adjust_ic_jc_ways - ( - MR, NR, m, n, - n_threads, ic_ways, jc_ways - ); + if ( ( mr_blks < ( *ic_ways ) ) && ( nr_blks < ( *jc_ways ) ) ) + { + ( *ic_ways ) = mr_blks; + ( *jc_ways ) = nr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( mr_blks < ( *ic_ways ) ) + { + ( *ic_ways ) = mr_blks; + dim_t rem_jc_ways = ( dim_t )( ( *n_threads ) / ( *ic_ways ) ); + ( *jc_ways ) = ( rem_jc_ways < nr_blks ) ? rem_jc_ways : nr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( nr_blks < ( *jc_ways ) ) + { + ( *jc_ways ) = nr_blks; + dim_t rem_ic_ways = ( dim_t )( ( *n_threads ) / ( *jc_ways ) ); + ( *ic_ways ) = ( rem_ic_ways < mr_blks ) ? rem_ic_ways : mr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else + { + lpgemm_pnl_wrk_heur_adjust_ic_jc_ways + ( + MR, NR, m, n, + n_threads, ic_ways, jc_ways + ); + } } } else @@ -485,91 +653,54 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading } else if ( ( *n_threads ) > 1 ) { - // If BLIS_NUM_THREADS are set, generate jc,ic from the same. - bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); - - lpgemm_adjust_ic_jc_ways - ( - m, n, k, - MC, NC, KC, MR, NR, - n_threads, ic_ways, jc_ways, 5 - ); - } - else - { - // Setting all the values to 1 in case n_threads <= 1. This ensures - // the threading parameters are valid. - *n_threads = 1; - *jc_ways = 1; - *ic_ways = 1; - } - - // Native -> SUP path. 
- const dim_t m_ic = m / ( *ic_ways ); - const dim_t n_jc = n / ( *jc_ways ); - const dim_t page_size = bli_info_get_page_size(); - const dim_t page_size_b_floatx2 = - 2 * ( page_size / sizeof( float ) ); - - if ( ( m >= MT ) && ( n >= NT ) && ( k >= KT ) ) - { - if ( ( k > page_size_b_floatx2 ) || - ( ( k <= page_size_b_floatx2 ) && - ( m_ic > MT_2 ) && ( n_jc >= NT ) ) ) - { - bli_rntm_set_pack_b( 1, rntm_g ); - bli_rntm_set_pack_a( 1, rntm_g ); - } - } -} - -BLIS_INLINE void lpgemm_s8s8s32o32_get_threading - ( - dim_t* n_threads, - dim_t* ic_ways, - dim_t* jc_ways, - dim_t m, - dim_t n, - dim_t k, - rntm_t* rntm_g - ) -{ - *n_threads = bli_rntm_num_threads( rntm_g ); - *jc_ways = bli_rntm_jc_ways( rntm_g ); - *ic_ways = bli_rntm_ic_ways( rntm_g ); - - if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) ) - { - // If BLIS_IC_NT or JC_NT are set. - // Default cases. - *ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1; - *jc_ways = ( ( *jc_ways ) > 0 ) ? ( *jc_ways ) : 1; - - *n_threads = ( *jc_ways ) * ( *ic_ways ); - } - else if ( ( *n_threads ) > 1 ) - { - - dim_t NR = lpgemm_get_block_size_NR_global_cntx( S8S8S32OS32 ); - dim_t MR = lpgemm_get_block_size_MR_global_cntx( S8S8S32OS32 ); + dim_t mr_blks = ( m + MR - 1 ) / MR; + dim_t nr_blks = ( n + NR - 1 ) / NR; if ( n <= NR ) { - // If n is less than micro panel dimension, allocating all threads - // to ic resulted in gains. - ( *ic_ways ) = ( *n_threads ); + ( *ic_ways ) = ( mr_blks < ( *n_threads ) ) ? mr_blks : ( *n_threads ); ( *jc_ways ) = 1; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( m <= MR ) + { + ( *jc_ways ) = ( nr_blks < ( *n_threads ) ) ? nr_blks : ( *n_threads ); + ( *ic_ways ) = 1; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); } else { // If BLIS_NUM_THREADS are set, generate jc,ic from the same. bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); - - lpgemm_pnl_wrk_heur_adjust_ic_jc_ways - ( - MR, NR, m, n, - n_threads, ic_ways, jc_ways - ); + if ( ( mr_blks < ( *ic_ways ) ) && ( nr_blks < ( *jc_ways ) ) ) + { + ( *ic_ways ) = mr_blks; + ( *jc_ways ) = nr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( mr_blks < ( *ic_ways ) ) + { + ( *ic_ways ) = mr_blks; + dim_t rem_jc_ways = ( dim_t )( ( *n_threads ) / ( *ic_ways ) ); + ( *jc_ways ) = ( rem_jc_ways < nr_blks ) ? rem_jc_ways : nr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else if ( nr_blks < ( *jc_ways ) ) + { + ( *jc_ways ) = nr_blks; + dim_t rem_ic_ways = ( dim_t )( ( *n_threads ) / ( *jc_ways ) ); + ( *ic_ways ) = ( rem_ic_ways < mr_blks ) ? rem_ic_ways : mr_blks; + ( *n_threads ) = ( *ic_ways ) * ( *jc_ways ); + } + else + { + lpgemm_adjust_ic_jc_ways + ( + m, n, k, + MC, NC, KC, MR, NR, + n_threads, ic_ways, jc_ways, 5 + ); + } } } else @@ -580,61 +711,25 @@ BLIS_INLINE void lpgemm_s8s8s32o32_get_threading *jc_ways = 1; *ic_ways = 1; } -} - -BLIS_INLINE void lpgemm_s8s8s16o16_get_threading - ( - dim_t* n_threads, - dim_t* ic_ways, - dim_t* jc_ways, - dim_t m, - dim_t n, - dim_t k, - rntm_t* rntm_g - ) -{ - *n_threads = bli_rntm_num_threads( rntm_g ); - *jc_ways = bli_rntm_jc_ways( rntm_g ); - *ic_ways = bli_rntm_ic_ways( rntm_g ); - if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) ) - { - // If BLIS_IC_NT or JC_NT are set. - // Default cases. - *ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1; - *jc_ways = ( ( *jc_ways ) > 0 ) ? ( *jc_ways ) : 1; + // Native -> SUP path. 
+ const dim_t m_ic = m / ( *ic_ways ); + const dim_t n_jc = n / ( *jc_ways ); + const dim_t page_size = bli_info_get_page_size(); + const dim_t page_size_b_floatx2 = + 2 * ( page_size / sizeof( float ) ); - *n_threads = ( *jc_ways ) * ( *ic_ways ); - } - else if ( ( *n_threads ) > 1 ) + if ( ( m >= MT ) && ( n >= NT ) && ( k >= KT ) ) { - - dim_t NR = lpgemm_get_block_size_NR_global_cntx( S8S8S16OS16 ); - - if ( n <= NR ) - { - // If n is less than micro panel dimension, allocating all threads - // to ic resulted in gains. - ( *ic_ways ) = ( *n_threads ); - ( *jc_ways ) = 1; - } - else + if (((k <= page_size_b_floatx2 ) && ( m_ic > MT_2 ) && ( n_jc >= NT ) ) || + ((bli_cpuid_is_avx512_supported() == FALSE ) && (k > page_size_b_floatx2))) { - // If BLIS_NUM_THREADS are set, generate jc,ic from the same. - bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); + bli_rntm_set_pack_b( 1, rntm_g ); + bli_rntm_set_pack_a( 1, rntm_g ); } } - else - { - // Setting all the values to 1 in case n_threads <= 1. This ensures - // the threading parameters are valid. - *n_threads = 1; - *jc_ways = 1; - *ic_ways = 1; - } } - #define GEN_LPGEMM_OPENMP_DECORATOR(A_type,B_type,C_type,LPGEMM_SFX) \ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ ( \ @@ -657,7 +752,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ rntm_t* rntm_g, \ lpgemm_cntx_t* lcntx, \ lpgemm_post_op* post_op_list, \ - bool c_downscale \ + AOCL_STORAGE_TYPE c_downscale \ ) \ { \ dim_t n_threads; \ @@ -676,14 +771,15 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ /* Set the packing block allocator field of the rntm. This will be * inherited by all of the child threads when they make local copies of * the rntm below.*/ \ - bli_membrk_rntm_set_membrk( rntm_g ); \ + bli_pba_rntm_set_pba( rntm_g ); \ \ thrcomm_t static_lpgemm_comms[BLIS_LPGEMM_NUM_STATIC_COMMS]; \ thrcomm_t* cur_lpgemm_comms = static_lpgemm_comms; \ + err_t bli_errors = BLIS_SUCCESS; \ \ if ( jc_ways > BLIS_LPGEMM_NUM_STATIC_COMMS ) \ { \ - cur_lpgemm_comms = bli_malloc_intl( jc_ways * sizeof( thrcomm_t ) ); \ + cur_lpgemm_comms = bli_malloc_intl( jc_ways * sizeof( thrcomm_t ), &bli_errors ); \ } \ for ( dim_t i = 0; i < jc_ways; ++i ) \ { \ @@ -758,7 +854,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ rntm_t* rntm_g, \ lpgemm_cntx_t* lcntx, \ lpgemm_post_op* post_op_list, \ - bool c_downscale \ + AOCL_STORAGE_TYPE c_downscale \ ) \ { \ dim_t n_threads = 1; \ @@ -770,7 +866,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ /* Set the packing block allocator field of the rntm. This will be * inherited by all of the child threads when they make local copies of * the rntm below.*/ \ - bli_membrk_rntm_set_membrk( rntm_g ); \ + bli_pba_rntm_set_pba( rntm_g ); \ \ thrcomm_t static_lpgemm_comm; \ thrcomm_t* cur_lpgemm_comm = &static_lpgemm_comm; \ diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h index 80c657b230..4fd0a12bff 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -63,7 +63,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ rntm_t* rntm_g, \ lpgemm_cntx_t* lcntx, \ lpgemm_post_op* post_op_list, \ - bool c_downscale \ + AOCL_STORAGE_TYPE c_downscale \ ); \ GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16) @@ -97,7 +97,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ rntm_t* rntm_g, \ lpgemm_cntx_t* lcntx, \ lpgemm_post_op* post_op_list, \ - bool c_downscale \ + AOCL_STORAGE_TYPE c_downscale \ ); \ GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16) diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c index 2786117131..c0c1a29e7b 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h index 65647d9903..7a87bd6d56 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c index 5a03493a44..5e4740a952 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -113,7 +113,8 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) bool is_first_k = FALSE; lpgemm_post_op_attr post_ops_attr; - if ( c_downscale == TRUE ) + post_ops_attr.c_stor_type = c_downscale; + if ( c_downscale < S16 ) { post_ops_attr.buf_downscale = c; } @@ -153,12 +154,12 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) ); } - if ( c_downscale == FALSE ) + if ( c_downscale == S16 ) { c_use_jc = c + jc; } // Temp accumulaton buffer for C allocation. - else if ( c_downscale == TRUE ) + else if ( c_downscale < S16 ) { // Buffer memory is only required if output needs to be // persisted across iterations of the pc/KC loop. 
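The refactored *_get_threading helpers in lpgemm_thread_decor_openmp.c above all apply the same new rule: the ic/jc thread factorization is capped by the number of MR-row and NR-column micro-panels, with surplus ways folded back into the other dimension, so no thread is left holding an empty panel range and n_threads shrinks to match. A condensed sketch of that capping step, assuming dim_t and bli_thread_partition_2x2() behave as in BLIS; the function name is illustrative only:

    #include "blis.h"

    static void cap_ways_by_panel_count
         (
           dim_t  m,  dim_t NR,
           dim_t  n,  dim_t MR,
           dim_t* n_threads,
           dim_t* ic_ways,
           dim_t* jc_ways
         )
    {
        dim_t mr_blks = ( m + MR - 1 ) / MR; /* MR-high row panels of C.    */
        dim_t nr_blks = ( n + NR - 1 ) / NR; /* NR-wide column panels of C. */

        /* Initial factorization of the thread count, as before. */
        bli_thread_partition_2x2( *n_threads, m, n, ic_ways, jc_ways );

        if ( ( mr_blks < *ic_ways ) && ( nr_blks < *jc_ways ) )
        {
            *ic_ways = mr_blks;
            *jc_ways = nr_blks;
        }
        else if ( mr_blks < *ic_ways )
        {
            *ic_ways = mr_blks;
            /* Hand the freed-up ways to jc, but never beyond nr_blks. */
            dim_t rem = ( *n_threads ) / ( *ic_ways );
            *jc_ways = ( rem < nr_blks ) ? rem : nr_blks;
        }
        else if ( nr_blks < *jc_ways )
        {
            *jc_ways = nr_blks;
            dim_t rem = ( *n_threads ) / ( *jc_ways );
            *ic_ways = ( rem < mr_blks ) ? rem : mr_blks;
        }

        /* The team shrinks if either dimension was capped. */
        *n_threads = ( *ic_ways ) * ( *jc_ways );
    }

When neither dimension is capped, the patch still falls through to the existing panel-work heuristics (lpgemm_pnl_wrk_heur_adjust_ic_jc_ways or lpgemm_adjust_ic_jc_ways); the sketch omits that branch.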
@@ -171,7 +172,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) lpgemm_alloc_mem_panel ( - mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + mem_scale_c_size_req, BLIS_BUFFER_FOR_GEN_USE, &mem_scale_c, rntm ); @@ -305,7 +306,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) // Only per thread C matrix is stored in temp buffer, so both // per thread jc and ic start should be normalized to zero. - if ( c_downscale == TRUE ) + if ( c_downscale < S16 ) { c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) ); } @@ -361,15 +362,15 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) { if (bli_mem_is_alloc(&mem_b)) { - bli_membrk_release(rntm, &mem_b); + bli_pba_release(rntm, &mem_b); } } } - if ( c_downscale == TRUE ) + if ( c_downscale < S16 ) { if ( bli_mem_is_alloc( &mem_scale_c ) ) { - bli_membrk_release( rntm, &mem_scale_c ); + bli_pba_release( rntm, &mem_scale_c ); } } } diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c index 224e0791ff..14dff21af4 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h index 232b02238d..58a5255637 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c index feedda0212..29239803d6 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -122,7 +122,8 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) bool is_first_k = FALSE; lpgemm_post_op_attr post_ops_attr; - if ( c_downscale == TRUE ) + post_ops_attr.c_stor_type = c_downscale; + if ( c_downscale < S32 ) { post_ops_attr.buf_downscale = c; } @@ -162,12 +163,12 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) ); } - if ( c_downscale == FALSE ) + if ( c_downscale == S32 ) { c_use_jc = c + jc; } // Temp accumulaton buffer for C allocation. - else if ( c_downscale == TRUE ) + else if ( c_downscale < S32 ) { // Buffer memory is only required if output needs to be // persisted across iterations of the pc/KC loop. 
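The hunks that follow in lpgemm_u8s8s32.c, like the matching ones in the s16 drivers above, make two mechanical substitutions: the temporary C accumulation buffer is now requested as BLIS_BUFFER_FOR_GEN_USE instead of BLIS_BUFFER_FOR_C_PANEL, and the bli_membrk_* calls are renamed to their bli_pba_* (packed block allocator) counterparts. The acquire/grow/release pattern itself is unchanged; a condensed restatement of lpgemm_alloc_mem_panel from the lpgemm_utils*.h hunks, assuming the bli_pba_* signatures are simply the renamed BLIS API:

    #include "blis.h"

    static void alloc_mem_panel_sketch /* illustrative name */
         (
           siz_t     size_req,
           packbuf_t buf_type,  /* e.g. BLIS_BUFFER_FOR_GEN_USE */
           mem_t*    mem,
           rntm_t*   rntm
         )
    {
        if ( bli_mem_is_unalloc( mem ) )
        {
            /* Nothing cached yet: acquire a block of the requested size. */
            bli_pba_acquire_m( rntm, size_req, buf_type, mem );
        }
        else if ( bli_mem_size( mem ) < size_req )
        {
            /* Cached block is too small: release it and re-acquire. */
            bli_pba_release( rntm, mem );
            bli_pba_acquire_m( rntm, size_req, buf_type, mem );
        }
    }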
@@ -180,7 +181,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) lpgemm_alloc_mem_panel ( - mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + mem_scale_c_size_req, BLIS_BUFFER_FOR_GEN_USE, &mem_scale_c, rntm ); @@ -313,7 +314,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) // Only per thread C matrix is stored in temp buffer, so both // per thread jc and ic start should be normalized to zero. - if ( c_downscale == TRUE ) + if ( c_downscale < S32 ) { c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) ); } @@ -405,7 +406,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) { if ( bli_mem_is_alloc( &mem_b ) ) { - bli_membrk_release( rntm, &mem_b ); + bli_pba_release( rntm, &mem_b ); } } } @@ -413,14 +414,14 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) { if ( bli_mem_is_alloc( &mem_a ) ) { - bli_membrk_release( rntm, &mem_a ); + bli_pba_release( rntm, &mem_a ); } } - if ( c_downscale == TRUE ) + if ( c_downscale < S32 ) { if ( bli_mem_is_alloc( &mem_scale_c ) ) { - bli_membrk_release( rntm, &mem_scale_c ); + bli_pba_release( rntm, &mem_scale_c ); } } } diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h b/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h index 93acad6ac9..9c4e85cc05 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h @@ -123,7 +123,7 @@ BLIS_INLINE void lpgemm_alloc_mem_panel { if ( bli_mem_is_unalloc( mem ) ) { - bli_membrk_acquire_m + bli_pba_acquire_m ( rntm_l, size_req, @@ -136,8 +136,8 @@ BLIS_INLINE void lpgemm_alloc_mem_panel siz_t mem_size = bli_mem_size( mem ); if ( mem_size < size_req ) { - bli_membrk_release( rntm_l, mem ); - bli_membrk_acquire_m + bli_pba_release( rntm_l, mem ); + bli_pba_acquire_m ( rntm_l, size_req, diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16.h b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_pack_bf16.h similarity index 83% rename from addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16.h rename to addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_pack_bf16.h index db5d31e513..1ceb833180 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16.h +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_pack_bf16.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -47,13 +47,14 @@ BLIS_INLINE dim_t get_packb_bf16bf16f32of32_min_NR() return 16; } -typedef void (*packb_bf16) +typedef void (*pack_bf16) ( bfloat16*, const bfloat16*, const dim_t, const dim_t, const dim_t, + const dim_t, dim_t*, dim_t* ); @@ -62,11 +63,24 @@ void packb_nr64_bf16bf16f32of32 ( bfloat16* pack_b_buffer_bf16bf16f32of32, const bfloat16* b, - const dim_t ldb, + const dim_t rs_b, + const dim_t cs_b, const dim_t NC, const dim_t KC, - dim_t* rs_b, - dim_t* cs_b + dim_t* rs_p, + dim_t* cs_p ); + +void packa_mr16_bf16bf16f32of32 + ( + bfloat16* pack_a_buffer, + const bfloat16* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t MC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p + ); #endif //BLIS_GEMM_BF16_PACKB diff --git a/addon/aocl_gemm/kernels/lpgemm_kernels.h b/addon/aocl_gemm/kernels/lpgemm_kernels.h index add69df94f..83132e8fbf 100644 --- a/addon/aocl_gemm/kernels/lpgemm_kernels.h +++ b/addon/aocl_gemm/kernels/lpgemm_kernels.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h index a8f64c3fe0..1b3997ca3e 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h index 9b1c55046e..d0d507cbfb 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h index 1d69148e3c..2849cc8c33 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/addon/gemmd/thread/bao_l3_decor_openmp.c b/addon/gemmd/thread/bao_l3_decor_openmp.c index 1aca8de275..d77b0d8a7a 100644 --- a/addon/gemmd/thread/bao_l3_decor_openmp.c +++ b/addon/gemmd/thread/bao_l3_decor_openmp.c @@ -93,7 +93,7 @@ void bao_l3_thread_decorator // Query the thread's id from OpenMP. const dim_t tid = omp_get_thread_num(); - // Check for a somewhat obscure OpenMP thread-mistmatch issue. + // Check for a somewhat obscure OpenMP thread-mismatch issue. // NOTE: This calls the same function used for the conventional/large // code path. bli_l3_thread_decorator_thread_check( n_threads, tid, gl_comm, rntm_p ); diff --git a/aocl_dtl/CMakeLists.txt b/aocl_dtl/CMakeLists.txt index 3985350ab2..5b69f0e116 100644 --- a/aocl_dtl/CMakeLists.txt +++ b/aocl_dtl/CMakeLists.txt @@ -1,10 +1,59 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. ## +##Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. ## -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/aocldtl.c - ${CMAKE_CURRENT_SOURCE_DIR}/aocldtl_blis.c - ${CMAKE_CURRENT_SOURCE_DIR}/aoclfal.c - ${CMAKE_CURRENT_SOURCE_DIR}/aoclflist.c - ${CMAKE_CURRENT_SOURCE_DIR}/aoclos.c - ) +# Collect all subdirectory paths that have at least one file with suffix in AOCLDTL_SRC_SUFS list. +get_filepaths_with_suffixes(LOCAL_SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR} "${AOCLDTL_SRC_SUFS}") + +# Create an object library using the source file list above. +add_library(AOCL_DTL + OBJECT + ${LOCAL_SOURCE_FILES} + ) + +# Include the corresponding make_defs.cmake that holds the required compiler options. +include(${CMAKE_SOURCE_DIR}/config/${BLIS_CONFIG_FAMILY}/make_defs.cmake) +# Use PRIVATE keyword for option setting since we do not want the properties to propagate in other targets. +# mimicing get-aocldtl-cflags-for +target_compile_options(AOCL_DTL + PRIVATE + # load-var-for,COPTFLAGS + ${COPTFLAGS} + # get-noopt-cflags-for + ${CDBGFLAGS} + # get-noopt-cflags-for + ${CWARNFLAGS} + # get-noopt-cflags-for + ${CMISCFLAGS} + # get-noopt-cflags-for + ${CLANGFLAGS} + # in get-aocldtl-cflags-for + ${BUILD_SYMFLAGS} + ) +target_compile_definitions(AOCL_DTL + PRIVATE + # in get-noopt-cflags-for + ${VERS_DEF} + # in get-aocldtl-cflags-for + ${BUILD_CPPFLAGS} + # in get-aocldtl-cflags-for + ${CPPROCFLAGS} + ) +target_include_directories(AOCL_DTL + BEFORE + PRIVATE + # in get-noopt-cflags-for + ${CINFLAGS} + ) +if(THREADING_MODEL STREQUAL "openmp") + # Equivalent to CTHREADFLAGS in get-noopt-cflags-for + target_link_libraries(AOCL_DTL PRIVATE OpenMP::OpenMP_C) +elseif(THREADING_MODEL STREQUAL "pthreads") + # in get-noopt-cflags-for + target_compile_options(AOCL_DTL PRIVATE ${CTHREADFLAGS}) +endif() +if(BUILD_SHARED_LIBS) + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(AOCL_DTL PROPERTIES POSITION_INDEPENDENT_CODE ON) +endif() +add_dependencies(AOCL_DTL flat-header) +# Put all those targets under object-libs-targets folder name so that they appear all together in IDE. +set_target_properties(AOCL_DTL PROPERTIES FOLDER object-libs-targets) diff --git a/aocl_dtl/aocldtl.c b/aocl_dtl/aocldtl.c index a9b3db1786..3624f8c004 100644 --- a/aocl_dtl/aocldtl.c +++ b/aocl_dtl/aocldtl.c @@ -5,7 +5,7 @@ * These functions are invoked though macros by * end user. * - * Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. * *=======================================================================*/ #include "blis.h" @@ -539,11 +539,11 @@ uint64 AOCL_DTL_get_time_spent(void) #ifdef AOCL_DTL_AUTO_TRACE_ENABLE /* - Disable intrumentation for these functions as they will also be - called from compiler generated instumation code to trace + Disable instrumentation for these functions as they will also be + called from compiler generated instrumentation code to trace function execution. - It needs to be part of declration in the C file so can't be + It needs to be part of declaration in the C file so can't be moved to header file. WARNING: These functions are automatically invoked. however any function diff --git a/aocl_dtl/aocldtl.h b/aocl_dtl/aocldtl.h index 7f9934ed24..7800bb432d 100644 --- a/aocl_dtl/aocldtl.h +++ b/aocl_dtl/aocldtl.h @@ -5,7 +5,7 @@ * It provides defination for all macros to be * used by user to add debug/trace information. * - * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ diff --git a/aocl_dtl/aocldtl_blis.c b/aocl_dtl/aocldtl_blis.c index c4de2bfcda..90be337f26 100755 --- a/aocl_dtl/aocldtl_blis.c +++ b/aocl_dtl/aocldtl_blis.c @@ -3,7 +3,7 @@ * * Description : BLIS library specific debug helpes. * - * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ @@ -92,6 +92,7 @@ void AOCL_DTL_log_gemm_sizes(int8 loglevel, } void AOCL_DTL_log_gemm_stats(int8 loglevel, + char dt_type, const f77_int m, const f77_int n, const f77_int k) @@ -99,14 +100,52 @@ void AOCL_DTL_log_gemm_stats(int8 loglevel, char buffer[256]; double flops = 2.0 * m * n * k; + if (dt_type == 'c' || dt_type == 'C' || dt_type == 'z' || dt_type == 'Z') + { + flops = 4.0 * flops; + } // Execution time is in micro seconds. Double execution_time = AOCL_DTL_get_time_spent(); - sprintf(buffer, " nt=%ld %.3f ms %0.3f GFLOPS", - AOCL_get_requested_threads_count(), - execution_time/1000.0, - flops/(execution_time * 1e3)); + if (execution_time != 0.0) + sprintf(buffer, " nt=%ld %.3f ms %0.3f GFLOPS", + AOCL_get_requested_threads_count(), + execution_time/1000.0, + flops/(execution_time * 1e3)); + else + sprintf(buffer, " nt=%ld %.3f ms", + AOCL_get_requested_threads_count(), + execution_time/1000.0); + + DTL_Trace(loglevel, TRACE_TYPE_RAW, NULL, NULL, 0, buffer); +} + +void AOCL_DTL_log_gemmt_stats(int8 loglevel, + char dt_type, + const f77_int n, + const f77_int k) +{ + char buffer[256]; + + double flops = n * n * k; + if (dt_type == 'c' || dt_type == 'C' || dt_type == 'z' || dt_type == 'Z') + { + flops = 4.0 * flops; + } + + // Execution time is in micro seconds. 
+ Double execution_time = AOCL_DTL_get_time_spent(); + + if (execution_time != 0.0) + sprintf(buffer, " nt=%ld %.3f ms %0.3f GFLOPS", + AOCL_get_requested_threads_count(), + execution_time/1000.0, + flops/(execution_time * 1e3)); + else + sprintf(buffer, " nt=%ld %.3f ms", + AOCL_get_requested_threads_count(), + execution_time/1000.0); DTL_Trace(loglevel, TRACE_TYPE_RAW, NULL, NULL, 0, buffer); } @@ -131,17 +170,57 @@ void AOCL_DTL_log_trsm_sizes(int8 loglevel, double alpha_real = 0.0; double alpha_imag = 0.0; + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); //{S, D, C, Z} side, uplo, transa, diaga, m, n, lda, ldb, alpha_real, alpha_imag - sprintf(buffer, "%c %c %c %c %c %ld %ld %ld %ld %lf %lf\n", dt_type, + sprintf(buffer, "%c %c %c %c %c %ld %ld %ld %ld %lf %lf", dt_type, side, uploa, transa, diaga, (dim_t)m, (dim_t)n, (dim_t)lda, (dim_t)ldb, alpha_real, alpha_imag); + AOCL_DTL_START_PERF_TIMER(); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); } +void AOCL_DTL_log_trsm_stats(int8 loglevel, + char dt_type, + f77_char side, + const f77_int m, + const f77_int n) +{ + char buffer[256]; + + double flops = 0.0; + if (side == 'L' || side =='l') + { + flops = 1.0 * m * n * m; + } + else + { + flops = 1.0 * m * n * n; + } + if (dt_type == 'c' || dt_type == 'C' || dt_type == 'z' || dt_type == 'Z') + { + flops = 4.0 * flops; + } + + // Execution time is in micro seconds. + Double execution_time = AOCL_DTL_get_time_spent(); + + if (execution_time != 0.0) + sprintf(buffer, " nt=%ld %.3f ms %0.3f GFLOPS", + AOCL_get_requested_threads_count(), + execution_time/1000.0, + flops/(execution_time * 1e3)); + else + sprintf(buffer, " nt=%ld %.3f ms", + AOCL_get_requested_threads_count(), + execution_time/1000.0); + + DTL_Trace(loglevel, TRACE_TYPE_RAW, NULL, NULL, 0, buffer); +} + void AOCL_DTL_log_gemmt_sizes(int8 loglevel, char dt_type, char uplo, @@ -165,18 +244,20 @@ void AOCL_DTL_log_gemmt_sizes(int8 loglevel, double beta_real = 0.0; double beta_imag = 0.0; + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); DTL_get_complex_parts(dt_type, beta, &beta_real, &beta_imag); // {S,D,C,Z} {triangC : l or u} {n k lda ldb ldc transa transb alpha_real alpha_imaginary // beta_real, beta_imaginary} - sprintf(buffer, "%c %c %ld %ld %lu %lu %lu %c %c %lf %lf %lf %lf\n", + sprintf(buffer, "%c %c %ld %ld %lu %lu %lu %c %c %lf %lf %lf %lf", dt_type, uplo, (dim_t)n, (dim_t)k, (dim_t)lda, (dim_t)ldb, (dim_t)ldc, transa, transb, alpha_real, alpha_imag, beta_real, beta_imag); + AOCL_DTL_START_PERF_TIMER(); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); } @@ -639,12 +720,41 @@ void AOCL_DTL_log_nrm2_sizes(int8 loglevel, { char buffer[256]; // {S, D, C, Z} {n, incx} - sprintf(buffer, "%c %ld %ld\n", + sprintf(buffer, "%c %ld %ld", dt_type, (dim_t)n, (dim_t)incx); + AOCL_DTL_START_PERF_TIMER(); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); } +void AOCL_DTL_log_nrm2_stats(int8 loglevel, + char dt_type, + const f77_int n) +{ + char buffer[256]; + + double flops = 2.0 * n; + if (dt_type == 'c' || dt_type == 'C' || dt_type == 'z' || dt_type == 'Z') + { + flops = 2.0 * flops; + } + + // Execution time is in micro seconds. 
+ Double execution_time = AOCL_DTL_get_time_spent(); + + if (execution_time != 0.0) + sprintf(buffer, " nt=%ld %.3f ms %0.3f GFLOPS", + AOCL_get_requested_threads_count(), + execution_time/1000.0, + flops/(execution_time * 1e3)); + else + sprintf(buffer, " nt=%ld %.3f ms", + AOCL_get_requested_threads_count(), + execution_time/1000.0); + + DTL_Trace(loglevel, TRACE_TYPE_RAW, NULL, NULL, 0, buffer); +} + //Level-2 void AOCL_DTL_log_syr2_sizes(int8 loglevel, char dt_type, diff --git a/aocl_dtl/aocldtl_blis.h b/aocl_dtl/aocldtl_blis.h index 7b352f9d43..275ad0a484 100755 --- a/aocl_dtl/aocldtl_blis.h +++ b/aocl_dtl/aocldtl_blis.h @@ -3,7 +3,7 @@ * * Description : BLIS library specific debug helpes. * - * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ @@ -33,10 +33,17 @@ void AOCL_DTL_log_gemm_sizes(int8 loglevel, int line); void AOCL_DTL_log_gemm_stats(int8 loglevel, + char dt_type, const f77_int m, const f77_int n, const f77_int k); +void AOCL_DTL_log_trsm_stats(int8 loglevel, + char dt_type, + f77_char side, + const f77_int m, + const f77_int n); + void AOCL_DTL_log_trsm_sizes(int8 loglevel, char dt, f77_char side, @@ -68,6 +75,11 @@ void AOCL_DTL_log_gemmt_sizes(int8 loglevel, const char* function_name, int line); +void AOCL_DTL_log_gemmt_stats(int8 loglevel, + char dt_type, + const f77_int n, + const f77_int k); + void AOCL_DTL_log_hemm_sizes(int8 loglevel, char dt_type, const f77_char side, @@ -243,6 +255,10 @@ void AOCL_DTL_log_nrm2_sizes( int8 loglevel, const char* function_name, int line); +void AOCL_DTL_log_nrm2_stats(int8 loglevel, + char dt_type, + const f77_int n); + void AOCL_DTL_log_amax_sizes ( int8 loglevel, char dt_type, const f77_int n, @@ -389,15 +405,23 @@ void AOCL_DTL_log_trmm_sizes(int8 loglevel, AOCL_DTL_log_gemm_sizes(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc, \ __FILE__, __FUNCTION__, __LINE__); -#define AOCL_DTL_LOG_GEMM_STATS(loglevel, m, n, k) \ +#define AOCL_DTL_LOG_GEMM_STATS(loglevel, dt_type, m, n, k) \ + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_gemm_stats(loglevel, dt_type, m, n, k); + +#define AOCL_DTL_LOG_GEMMT_STATS(loglevel, dt_type, n, k) \ if (gbIsLoggingEnabled) \ - AOCL_DTL_log_gemm_stats(loglevel, m, n, k); + AOCL_DTL_log_gemmt_stats(loglevel, dt_type, n, k); #define AOCL_DTL_LOG_TRSM_INPUTS(loglevel, dt, side, uploa, transa, diaga, m, n, alpha, lda, ldb) \ if (gbIsLoggingEnabled) \ AOCL_DTL_log_trsm_sizes(loglevel, dt, side, uploa, transa, diaga, m, n, alpha, lda, ldb, \ __FILE__, __FUNCTION__, __LINE__); +#define AOCL_DTL_LOG_TRSM_STATS(loglevel, dt_type, side, m, n) \ + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_trsm_stats(loglevel, dt_type, side, m, n); + #define AOCL_DTL_LOG_GEMMT_INPUTS(loglevel, dt, uplo, transa, transb, n, k, alpha, lda, ldb, beta, ldc) \ if (gbIsLoggingEnabled) \ AOCL_DTL_log_gemmt_sizes(loglevel, dt, uplo, transa, transb, n, k, alpha, lda, ldb, beta, ldc, \ @@ -460,6 +484,10 @@ void AOCL_DTL_log_trmm_sizes(int8 loglevel, if (gbIsLoggingEnabled) \ AOCL_DTL_log_nrm2_sizes(loglevel, dt_type, n, incx, __FILE__,__FUNCTION__,__LINE__); +#define AOCL_DTL_LOG_NRM2_STATS(loglevel, dt_type, n) \ + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_nrm2_stats(loglevel, dt_type, n); + #define AOCL_DTL_LOG_HEMV_INPUTS(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy) \ if (gbIsLoggingEnabled) \ AOCL_DTL_log_hemv_sizes(loglevel, dt_type, 
uploa, m, alpha, lda, incx, beta, incy, \ @@ -531,12 +559,16 @@ void AOCL_DTL_log_trmm_sizes(int8 loglevel, #define AOCL_DTL_LOG_GEMM_INPUTS(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc) -#define AOCL_DTL_LOG_GEMM_STATS(loglevel, m, n, k) +#define AOCL_DTL_LOG_GEMM_STATS(loglevel, dt_type, m, n, k) #define AOCL_DTL_LOG_TRSM_INPUTS(loglevel, dt, side, uploa, transa, diaga, m, n, alpha, lda, ldb) +#define AOCL_DTL_LOG_TRSM_STATS(loglevel, dt_type, side, m, n) + #define AOCL_DTL_LOG_GEMMT_INPUTS(loglevel, dt, uplo, transa, transb, n, k, alpha, lda, ldb, beta, ldc) +#define AOCL_DTL_LOG_GEMMT_STATS(loglevel, dt_type, n, k) + #define AOCL_DTL_LOG_HEMM_INPUTS(loglevel, dt_type, side, uplo, m, n, alpha, lda, ldb, beta, ldc) #define AOCL_DTL_LOG_HERK_INPUTS(loglevel, dt_type, uploc, transa, m, k, alpha, lda, beta, ldc) @@ -561,6 +593,8 @@ void AOCL_DTL_log_trmm_sizes(int8 loglevel, #define AOCL_DTL_LOG_NRM2_INPUTS(loglevel, dt_type, n, incx) +#define AOCL_DTL_LOG_NRM2_STATS(loglevel, dt_type, n) + #define AOCL_DTL_LOG_HEMV_INPUTS(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy) #define AOCL_DTL_LOG_HER2_INPUTS(loglevel, dt_type, uploa, m, alpha, incx, incy, lda) diff --git a/aocl_dtl/aocldtlcf.h b/aocl_dtl/aocldtlcf.h index 408f38c516..4aa1293fcf 100644 --- a/aocl_dtl/aocldtlcf.h +++ b/aocl_dtl/aocldtlcf.h @@ -5,7 +5,7 @@ * libaray, all debug features (except auto trace) * can be enabled/disabled in this file. * - * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ diff --git a/aocl_dtl/aoclfal.c b/aocl_dtl/aoclfal.c index 1eadf99b49..e96a42cf7c 100644 --- a/aocl_dtl/aoclfal.c +++ b/aocl_dtl/aoclfal.c @@ -3,7 +3,7 @@ * * Description : Platform/os independed file handling API's * - * Copyright (C) 2020, Advanced Micro Devices, Inc + * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ diff --git a/aocl_dtl/aoclfal.h b/aocl_dtl/aoclfal.h index 401ed4c355..c37b699be9 100644 --- a/aocl_dtl/aoclfal.h +++ b/aocl_dtl/aoclfal.h @@ -4,7 +4,7 @@ * Description : Interfaces for platform/os independed file * handling API's * - * Copyright (C) 2020, Advanced Micro Devices, Inc + * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ diff --git a/aocl_dtl/aoclflist.c b/aocl_dtl/aoclflist.c index 5d44fdba87..5265cd97c5 100644 --- a/aocl_dtl/aoclflist.c +++ b/aocl_dtl/aoclflist.c @@ -5,10 +5,11 @@ * each thread. This is used to log the data * to correct file as per the current thread id. * - * Copyright (C) 2020, Advanced Micro Devices, Inc + * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
* *==================================================================*/ +#include "blis.h" #include "aocltpdef.h" #include "aocldtl.h" #include "aoclfal.h" @@ -63,7 +64,11 @@ AOCL_FLIST_Node * AOCL_FLIST_GetNode(AOCL_FLIST_Node *plist, AOCL_TID tid) { if (temp->fp == NULL) { +#ifdef BLIS_ENABLE_PTHREADS + AOCL_DEBUGPRINT("Could not get saved time stamp for thread = %ld", tid); +#else AOCL_DEBUGPRINT("Could not get saved time stamp for thread = %d", tid); +#endif } return temp; } @@ -92,7 +97,11 @@ AOCL_FAL_FILE *AOCL_FLIST_GetFile(AOCL_FLIST_Node *plist, AOCL_TID tid) { if (temp->fp == NULL) { +#ifdef BLIS_ENABLE_PTHREADS + AOCL_DEBUGPRINT("File associated with this thread id %ld does not exists or closed", tid); +#else AOCL_DEBUGPRINT("File associated with this thread id %d does not exists or closed", tid); +#endif } return temp->fp; } @@ -118,8 +127,11 @@ AOCL_FAL_FILE *AOCL_FLIST_AddFile(const int8 *pchFilePrefix, AOCL_FLIST_Node **p } /* We don't have exiting file, lets try to open new one */ +#ifdef BLIS_ENABLE_PTHREADS + sprintf(pchFileName, "P%d_T%lu_%s", AOCL_getpid(), tid, pchFilePrefix); +#else sprintf(pchFileName, "P%d_T%u_%s", AOCL_getpid(), tid, pchFilePrefix); - +#endif file = AOCL_FAL_Open(pchFileName, "wb"); if (file == NULL) { diff --git a/aocl_dtl/aoclflist.h b/aocl_dtl/aoclflist.h index a4e45ca328..caf11057f2 100644 --- a/aocl_dtl/aoclflist.h +++ b/aocl_dtl/aoclflist.h @@ -5,7 +5,7 @@ * each thread. This is used to log the deta * to correct file as per the current thread id. * - * Copyright (C) 2020, Advanced Micro Devices, Inc + * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ diff --git a/aocl_dtl/aoclos.c b/aocl_dtl/aoclos.c index 896b1c89b3..92d278cb2a 100644 --- a/aocl_dtl/aoclos.c +++ b/aocl_dtl/aoclos.c @@ -3,9 +3,10 @@ * * Description : Abstraction for os services used by DTL. * - * Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ +#include "blis.h" #include "aocltpdef.h" #include "aocldtl.h" #include "aoclfal.h" @@ -20,18 +21,18 @@ #endif // BLIS TODO: This is workaround to check if BLIS is built with -// openmp support. Ideally we dont' want any library +// openmp support. Ideally we don't want any library // specific code in dtl. #include #if defined(__linux__) /* - Disable intrumentation for these functions as they will also be - called from compiler generated instumation code to trace + Disable instrumentation for these functions as they will also be + called from compiler generated instrumentation code to trace function execution. - It needs to be part of declration in the C file so can't be + It needs to be part of declaration in the C file so can't be moved to header file. */ @@ -47,7 +48,10 @@ AOCL_TID AOCL_gettid(void) return omp_get_thread_num(); #else #ifdef BLIS_ENABLE_PTHREADS - return pthread_self(); + // pthread_self is not suitable for this purpose and may be replaced + // in a later release with something else. It returns a value of type + // pthread_t, which on linux is an unsigned long int. 
+ return (AOCL_TID) pthread_self(); #else return 0; #endif @@ -89,7 +93,11 @@ AOCL_TID AOCL_gettid(void) return omp_get_thread_num(); #else #ifdef BLIS_ENABLE_PTHREADS - return pthread_self(); + // pthread_self is not suitable for this purpose and may be replaced + // in a later release with something else. It returns a value of type + // pthread_t, whose type may depend upon the operating system. On + // freeBSD it is a pointer to an empty struct. + return (AOCL_TID) pthread_self(); #else return 0; #endif diff --git a/aocl_dtl/aoclos.h b/aocl_dtl/aoclos.h index 57e0c24902..b22a329f5d 100644 --- a/aocl_dtl/aoclos.h +++ b/aocl_dtl/aoclos.h @@ -3,7 +3,7 @@ * * Description : Abstraction for os services used by DTL. * - * Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ @@ -11,7 +11,7 @@ #define _AOCL_OS_H_ #include "aocltpdef.h" -#include "malloc.h" +#include "stdlib.h" /* The OS Services function declaration */ diff --git a/aocl_dtl/aocltpdef.h b/aocl_dtl/aocltpdef.h index d842fffbac..8551dbe2cd 100644 --- a/aocl_dtl/aocltpdef.h +++ b/aocl_dtl/aocltpdef.h @@ -4,7 +4,7 @@ * * Description : Abstraction for various datatypes used by DTL. * - * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ #ifndef AOCL_TYPEDEF_H_ @@ -35,8 +35,11 @@ typedef signed long int int32; typedef short int int16; typedef Void *AOCL_HANDLE; +#ifdef BLIS_ENABLE_PTHREADS +typedef long int AOCL_TID; +#else typedef pid_t AOCL_TID; - +#endif #endif /*AOCL_TYPEDEF_H_ */ /* --------------- End of aocltpdef.h ----------------- */ diff --git a/aocl_dtl/etrace_decoder.py b/aocl_dtl/etrace_decoder.py index 1a24f00cc3..5465076ad8 100755 --- a/aocl_dtl/etrace_decoder.py +++ b/aocl_dtl/etrace_decoder.py @@ -7,7 +7,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/aocl_dtl/test_dtl.c b/aocl_dtl/test_dtl.c index 08ff3296c3..05ab292d8e 100644 --- a/aocl_dtl/test_dtl.c +++ b/aocl_dtl/test_dtl.c @@ -3,7 +3,7 @@ * * Description : Unit test cases for dtl. * - * Copyright (C) 2020, Advanced Micro Devices, Inc + * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
* *==================================================================*/ diff --git a/bench/CMakeLists.txt b/bench/CMakeLists.txt index 00d01fdd21..4c6fed1140 100644 --- a/bench/CMakeLists.txt +++ b/bench/CMakeLists.txt @@ -61,6 +61,13 @@ if(ENABLE_OPENMP) endif() target_link_libraries(BenchGer optimized "${LIB_NAME}.lib") +add_executable(BenchNrm2 bench_nrm2.c) +target_link_libraries(BenchNrm2 debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchNrm2 OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchNrm2 optimized "${LIB_NAME}.lib") + add_executable(BenchScalv bench_scalv.c) target_link_libraries(BenchScalv debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) diff --git a/bench/Makefile b/bench/Makefile index 751f7129a5..cc1b7297dc 100755 --- a/bench/Makefile +++ b/bench/Makefile @@ -6,7 +6,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2017 - 2022, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2017 - 2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -193,7 +193,8 @@ blis: \ bench_amaxv_blis.x \ bench_copyv_blis.x \ bench_swapv_blis.x \ - bench_axpbyv_blis.x + bench_axpbyv_blis.x \ + bench_gemm_pack_compute_blis.x openblas: \ bench_gemm_openblas.x \ @@ -240,7 +241,8 @@ mkl: \ bench_amaxv_mkl.x \ bench_copyv_mkl.x \ bench_swapv_mkl.x \ - bench_axpbyv_mkl.x + bench_axpbyv_mkl.x \ + bench_gemm_pack_compute_mkl.x # --Object file rules -- diff --git a/bench/bench_amaxv.c b/bench/bench_amaxv.c index eb37319b6f..c4df0cd4d7 100644 --- a/bench/bench_amaxv.c +++ b/bench/bench_amaxv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/bench/bench_aocl_gemm/bench_input.txt b/bench/bench_aocl_gemm/bench_input.txt index 9034a0d550..fbde59de5a 100644 --- a/bench/bench_aocl_gemm/bench_input.txt +++ b/bench/bench_aocl_gemm/bench_input.txt @@ -1,979 +1,3782 @@ -u r p 480 20 2050 2050 20 20 -u r p 481 20 2050 2050 20 20 -u r p 482 20 2050 2050 20 20 -u r p 483 20 2050 2050 20 20 -u r R 484 20 2050 2050 20 20 -u r R 485 20 2050 2050 20 20 -u r R 480 39 2050 2050 39 39 -u r R 481 39 2050 2050 39 39 -u r R 482 39 2050 2050 39 39 -u r R 483 39 2050 2050 39 39 -u r R 484 39 2050 2050 39 39 -u r p 485 39 2050 2050 39 39 -u r p 480 50 2050 2050 50 50 -u r p 481 50 2050 2050 50 50 -u r p 482 50 2050 2050 50 50 -u r p 483 50 2050 2050 50 50 -u r p 484 50 2050 2050 50 50 -u r p 485 50 2050 2050 50 50 -u r R 480 1108 2050 2050 1108 1108 -u r R 481 1108 2050 2050 1108 1108 -u r R 482 1108 2050 2050 1108 1108 -u r R 483 1108 2050 2050 1108 1108 -u r R 484 1108 2050 2050 1108 1108 -u r R 485 1108 2050 2050 1108 1108 -u r R 480 1127 2050 2050 1127 1127 -u r R 481 1127 2050 2050 1127 1127 -u r R 482 1127 2050 2050 1127 1127 -u r R 483 1127 2050 2050 1127 1127 -u r p 484 1127 2050 2050 1127 1127 -u r p 485 1127 2050 2050 1127 1127 -u r p 480 1138 2050 2050 1138 1138 -u r p 481 1138 2050 2050 1138 1138 -u r p 482 1138 2050 2050 1138 1138 -u r p 483 1138 2050 2050 1138 1138 -u r p 484 1138 2050 2050 1138 1138 -u r p 485 1138 2050 2050 1138 1138 -u r p 1 1 3 3 1 1 -u r p 1 9 3 3 9 9 -u r p 1 2048 3 3 2048 2048 -u r p 1 2048 5192 5192 2048 2048 -u r p 9 1 3 3 1 1 -u r p 576 1 3500 3500 1 1 -u r p 1 1 1 1 1 1 -u r p 102 1088 1024 1024 1088 1088 -u r p 102 2048 1024 1024 2048 2048 -u r p 485 656 1024 1024 656 656 -u r p 483 656 1024 1024 656 656 -u r p 81 128 3 3 128 128 -u r p 1022 512 515 515 512 512 -u r p 74 512 515 515 512 512 -u r p 253 2048 515 515 2048 2048 -u r p 8192 1040 515 515 1040 1040 -u r p 10 1029 515 515 1029 1029 -u r p 24 1040 2050 2050 1040 1040 -u r p 1024 1029 2050 2050 1029 1029 -u r p 480 660 2050 2050 660 660 -u r p 481 660 2050 2050 660 660 -u r p 482 660 2050 2050 660 660 -u r p 483 660 2050 2050 660 660 -u r p 484 660 2050 2050 660 660 -u r p 485 660 2050 2050 660 660 -u r p 480 679 2050 2050 679 679 -u r p 481 679 2050 2050 679 679 -u r p 482 679 2050 2050 679 679 -u r p 483 679 2050 2050 679 679 -u r p 484 679 2050 2050 679 679 -u r p 485 679 2050 2050 679 679 -u r p 480 690 2050 2050 690 690 -u r p 481 690 2050 2050 690 690 -u r p 482 690 2050 2050 690 690 -u r p 483 690 2050 2050 690 690 -u r p 484 690 2050 2050 690 690 -u r p 485 690 2050 2050 690 690 -u r p 480 660 2048 2048 660 660 -u r p 481 660 2048 2048 660 660 -u r p 482 660 2048 2048 660 660 -u r p 483 660 2048 2048 660 660 -u r p 484 660 2048 2048 660 660 -u r p 485 660 2048 2048 660 660 -u r p 480 679 2048 2048 679 679 -u r p 481 679 2048 2048 679 679 -u r p 482 679 2048 2048 679 679 -u r p 483 679 2048 2048 679 679 -u r p 484 679 2048 2048 679 679 -u r p 485 679 2048 2048 679 679 -u r p 480 690 2048 2048 690 690 -u r p 481 690 2048 2048 690 690 -u r p 482 690 2048 2048 690 690 -u r p 483 690 2048 2048 690 690 -u r p 484 690 2048 2048 690 690 -u r p 485 690 2048 2048 690 690 -u r p 480 656 1024 1024 656 656 -u r p 480 128 3 3 128 128 -u r p 1024 512 515 515 512 512 -u r p 1024 2048 1024 1024 2048 2048 -u r p 1024 2048 515 515 2048 2048 -u r p 1024 1040 515 515 1040 1040 -u r p 5 1029 515 515 1029 1029 -u r p 
1024 1029 515 515 1029 1029 -u r p 1024 1040 2050 2050 1040 1040 -u r p 1029 1029 2050 2050 1029 1029 -u r R 480 646 2050 2050 646 646 -u r R 481 646 2050 2050 646 646 -u r R 482 646 2050 2050 646 646 -u r R 483 646 2050 2050 646 646 -u r R 484 646 2050 2050 646 646 -u r R 485 646 2050 2050 646 646 -u r R 481 656 2050 2050 656 656 -u r R 482 656 2050 2050 656 656 -u r R 483 656 2050 2050 656 656 -u r R 484 656 2050 2050 656 656 -u r p 485 656 2050 2050 656 656 -u r p 480 672 2050 2050 672 672 -u r p 481 672 2050 2050 672 672 -u r p 482 672 2050 2050 672 672 -u r p 483 672 2050 2050 672 672 -u r p 484 672 2050 2050 672 672 -u r p 485 672 2050 2050 672 672 -u r p 480 688 2050 2050 688 688 -u r p 481 688 2050 2050 688 688 -u r r 482 688 2050 2050 688 688 -u r r 483 688 2050 2050 688 688 -u r r 484 688 2050 2050 688 688 -u r r 485 688 2050 2050 688 688 -u r r 1024 512 64 64 512 512 -u r r 16 256 512 512 256 256 -u r r 480 640 512 512 640 640 -u r r 64 768 512 512 768 768 -u r r 128 128 128 128 128 128 -u r r 1024 64 512 512 64 64 -u r r 1024 256 32 32 256 256 -u r r 1024 512 64 64 512 512 -u r r 480 640 512 512 640 640 -u r p 1024 32 256 256 32 32 -u r P 1024 64 512 512 64 64 -u r P 64 800 320 320 800 800 -u r P 64 768 512 512 768 768 -u r P 16 256 512 512 256 256 -u r P 128 128 128 128 128 128 -u r P 256 512 256 256 512 512 -u r P 1024 1024 1024 1024 1024 1024 -u r P 480 640 1024 1024 640 640 -u r P 480 640 256 256 640 640 -u r P 8 64 32 32 64 64 -u r P 9 64 32 32 64 64 -u r P 10 128 64 64 128 128 -u r P 8 8 8 8 8 8 -u r P 12 12 12 12 12 12 -u r P 25 25 25 25 25 25 -u r P 25 25 20 20 25 25 -u r r 4096 256 5 5 256 256 -u r r 3000 256 128 128 256 256 -u r r 4096 1024 512 512 1024 1024 -u r r 144 256 5 5 256 256 -u r r 144 256 128 128 256 256 -u r r 144 1024 512 512 1024 1024 -u r r 480 688 256 256 688 688 -u r r 480 640 512 512 640 640 -u r r 480 640 1024 1024 640 640 -u r r 64 800 320 320 800 800 -u r r 64 768 512 512 768 768 -u r r 16 256 512 512 256 256 -u r r 128 128 128 128 128 128 -u r r 256 512 256 256 512 512 -u r r 1024 1024 1024 1024 1024 1024 -u r r 1024 32 256 256 32 32 -u r r 1024 64 512 512 64 64 -u r r 1024 256 32 32 256 256 -u r r 1024 512 64 64 512 512 -u r r 512 32 256 256 32 32 -u r r 512 768 512 512 768 768 -u r r 512 256 32 32 256 256 -u r r 512 512 64 64 512 512 -u r r 512 256 768 768 256 256 -u r r 768 768 1024 1024 768 768 -u r r 768 768 768 768 768 768 -u r r 2048 2048 2048 2048 2048 2048 -u r r 4096 4096 4096 4096 4096 4096 -f c p 2482 1127 2050 2482 2050 2482 -f c p 2483 1127 2050 2483 2050 2483 -f c p 2484 1127 2050 2484 2050 2484 -f c p 2485 1127 2050 2485 2050 2485 -f c p 480 1138 2050 480 2050 480 -f c p 481 1138 2050 481 2050 481 -f c p 482 1138 2050 482 2050 482 -f c p 483 1138 2050 483 2050 483 -f c p 484 1138 2050 484 2050 484 -f c p 485 1138 2050 485 2050 485 -f c p 1 1 3 1 3 1 -f c p 1 9 3 1 3 1 -f c p 1 2048 3 1 3 1 -f c p 1 2048 5192 1 5192 1 -f c p 9 1 3 9 3 9 -f c p 576 1 3500 576 3500 576 -f c p 1 1 1 1 1 1 -f c p 102 1088 1024 102 1024 102 -b r r 480 20 2050 2050 20 20 -b r r 481 20 2050 2050 20 20 -b r r 482 20 2050 2050 20 20 -b r p 483 20 2050 2050 20 20 -b r R 484 20 2050 2050 20 20 -b r R 485 20 2050 2050 20 20 -b r R 480 39 2050 2050 39 39 -b r R 481 39 2050 2050 39 39 -b r R 482 39 2050 2050 39 39 -b r R 483 39 2050 2050 39 39 -b r R 484 39 2050 2050 39 39 -b r p 485 39 2050 2050 39 39 -b r p 480 50 2050 2050 50 50 -b r p 481 50 2050 2050 50 50 -b r p 482 50 2050 2050 50 50 -b r p 483 50 2050 2050 50 50 -b r p 484 50 2050 2050 50 50 -b r p 485 
50 2050 2050 50 50 -b r R 480 1108 2050 2050 1108 1108 -b r R 481 1108 2050 2050 1108 1108 -b r R 482 1108 2050 2050 1108 1108 -b r R 483 1108 2050 2050 1108 1108 -b r R 484 1108 2050 2050 1108 1108 -b r R 485 1108 2050 2050 1108 1108 -b r R 480 1127 2050 2050 1127 1127 -b r R 481 1127 2050 2050 1127 1127 -b r R 482 1127 2050 2050 1127 1127 -b r R 483 1127 2050 2050 1127 1127 -b r p 484 1127 2050 2050 1127 1127 -b r p 485 1127 2050 2050 1127 1127 -b r p 480 1138 2050 2050 1138 1138 -b r p 481 1138 2050 2050 1138 1138 -b r p 482 1138 2050 2050 1138 1138 -b r p 483 1138 2050 2050 1138 1138 -b r p 484 1138 2050 2050 1138 1138 -b r p 485 1138 2050 2050 1138 1138 -b r p 1 1 3 3 1 1 -b r p 1 9 3 3 9 9 -b r p 1 2048 3 3 2048 2048 -b r p 1 2048 5192 5192 2048 2048 -b r p 9 1 3 3 1 1 -b r p 576 1 3500 3500 1 1 -b r p 1 1 1 1 1 1 -b r p 102 1088 1024 1024 1088 1088 -b r p 102 2048 1024 1024 2048 2048 -b r p 485 656 1024 1024 656 656 -b r p 483 656 1024 1024 656 656 -b r p 81 128 3 3 128 128 -b r p 1022 512 515 515 512 512 -b r p 74 512 515 515 512 512 -b r p 253 2048 515 515 2048 2048 -b r p 8192 1040 515 515 1040 1040 -b r p 10 1029 515 515 1029 1029 -b r p 24 1040 2050 2050 1040 1040 -b r p 1024 1029 2050 2050 1029 1029 -b r p 480 660 2050 2050 660 660 -b r p 481 660 2050 2050 660 660 -b r p 482 660 2050 2050 660 660 -b r p 483 660 2050 2050 660 660 -b r p 484 660 2050 2050 660 660 -b r p 485 660 2050 2050 660 660 -b r p 480 679 2050 2050 679 679 -b r p 481 679 2050 2050 679 679 -b r p 482 679 2050 2050 679 679 -b r p 483 679 2050 2050 679 679 -b r p 484 679 2050 2050 679 679 -b r p 485 679 2050 2050 679 679 -b r p 480 690 2050 2050 690 690 -b r p 481 690 2050 2050 690 690 -b r p 482 690 2050 2050 690 690 -b r p 483 690 2050 2050 690 690 -b r p 484 690 2050 2050 690 690 -b r p 485 690 2050 2050 690 690 -b r p 480 660 2048 2048 660 660 -b r p 481 660 2048 2048 660 660 -b r p 482 660 2048 2048 660 660 -b r p 483 660 2048 2048 660 660 -b r p 484 660 2048 2048 660 660 -b r p 485 660 2048 2048 660 660 -b r p 480 679 2048 2048 679 679 -b r p 481 679 2048 2048 679 679 -b r p 482 679 2048 2048 679 679 -b r p 483 679 2048 2048 679 679 -b r p 484 679 2048 2048 679 679 -b r p 485 679 2048 2048 679 679 -b r p 480 690 2048 2048 690 690 -b r p 481 690 2048 2048 690 690 -b r p 482 690 2048 2048 690 690 -b r p 483 690 2048 2048 690 690 -b r p 484 690 2048 2048 690 690 -b r p 485 690 2048 2048 690 690 -b r p 480 656 1024 1024 656 656 -b r p 480 128 3 3 128 128 -b r p 1024 512 515 515 512 512 -b r p 1024 2048 1024 1024 2048 2048 -b r p 1024 2048 515 515 2048 2048 -b r p 1024 1040 515 515 1040 1040 -b r p 5 1029 515 515 1029 1029 -b r p 1024 1029 515 515 1029 1029 -b r p 1024 1040 2050 2050 1040 1040 -b r p 1029 1029 2050 2050 1029 1029 -b r R 480 646 2050 2050 646 646 -b r R 481 646 2050 2050 646 646 -b r R 482 646 2050 2050 646 646 -b r R 483 646 2050 2050 646 646 -b r R 484 646 2050 2050 646 646 -b r R 485 646 2050 2050 646 646 -b r R 481 656 2050 2050 656 656 -b r R 482 656 2050 2050 656 656 -b r R 483 656 2050 2050 656 656 -b r R 484 656 2050 2050 656 656 -b r p 485 656 2050 2050 656 656 -b r p 480 672 2050 2050 672 672 -b r p 481 672 2050 2050 672 672 -b r p 482 672 2050 2050 672 672 -b r p 483 672 2050 2050 672 672 -b r p 484 672 2050 2050 672 672 -b r p 485 672 2050 2050 672 672 -b r p 480 688 2050 2050 688 688 -b r p 481 688 2050 2050 688 688 -b r r 482 688 2050 2050 688 688 -b r r 483 688 2050 2050 688 688 -b r r 484 688 2050 2050 688 688 -b r r 485 688 2050 2050 688 688 -b r r 1024 512 64 64 512 512 -b r 
r 16 256 512 512 256 256 -b r r 480 640 512 512 640 640 -b r r 64 768 512 512 768 768 -b r r 128 128 128 128 128 128 -b r r 1024 64 512 512 64 64 -b r r 1024 256 32 32 256 256 -b r r 1024 512 64 64 512 512 -b r r 480 640 512 512 640 640 -b r p 1024 32 256 256 32 32 -b r P 1024 64 512 512 64 64 -b r P 64 800 320 320 800 800 -b r P 64 768 512 512 768 768 -b r P 16 256 512 512 256 256 -b r P 128 128 128 128 128 128 -b r P 256 512 256 256 512 512 -b r P 1024 1024 1024 1024 1024 1024 -b r P 480 640 1024 1024 640 640 -b r P 480 640 256 256 640 640 -b r P 8 64 32 32 64 64 -b r P 9 64 32 32 64 64 -b r P 10 128 64 64 128 128 -b r P 8 8 8 8 8 8 -b r P 12 12 12 12 12 12 -b r P 25 25 25 25 25 25 -b r P 25 25 20 20 25 25 -b c p 485 39 2050 485 2050 485 -b c p 480 50 2050 480 2050 480 -b c p 481 50 2050 481 2050 481 -b c p 482 50 2050 482 2050 482 -b c p 483 50 2050 483 2050 483 -b c p 484 50 2050 484 2050 484 -b c p 485 50 2050 485 2050 485 -b c p 484 1127 2050 484 2050 484 -b c p 485 1127 2050 485 2050 485 -b c p 480 1138 2050 480 2050 480 -b c p 481 1138 2050 481 2050 481 -b c p 482 1138 2050 482 2050 482 -b c p 483 1138 2050 483 2050 483 -b c p 484 1138 2050 484 2050 484 -b c p 485 1138 2050 485 2050 485 -b c p 1 1 3 1 3 1 -b c p 1 9 3 1 3 1 -b c p 1 2048 3 1 3 1 -b c p 1 2048 5192 1 5192 1 -b c p 9 1 3 9 3 9 -b c p 576 1 3500 576 3500 576 -b c p 1 1 1 1 1 1 -b c p 102 1088 1024 102 1024 102 -b c p 102 2048 1024 102 1024 102 -b c p 485 656 1024 485 1024 485 -b c p 483 656 1024 483 1024 483 -b c p 81 128 3 81 3 81 -b c p 1022 512 515 1022 515 1022 -b c p 74 512 515 74 515 74 -b c p 253 2048 515 253 515 253 -b c p 8192 1040 515 8192 515 8192 -b c p 10 1029 515 10 515 10 -b c p 24 1040 2050 24 2050 24 -b c p 1024 1029 2050 1024 2050 1024 -b c p 480 660 2050 480 2050 480 -b c p 481 660 2050 481 2050 481 -b c p 482 660 2050 482 2050 482 -b c p 483 660 2050 483 2050 483 -b c p 484 660 2050 484 2050 484 -b c p 485 660 2050 485 2050 485 -b c p 480 679 2050 480 2050 480 -b c p 481 679 2050 481 2050 481 -b c p 482 679 2050 482 2050 482 -b c p 483 679 2050 483 2050 483 -b c p 484 679 2050 484 2050 484 -b c p 485 679 2050 485 2050 485 -b c p 480 690 2050 480 2050 480 -b c p 481 690 2050 481 2050 481 -b c p 482 690 2050 482 2050 482 -b c p 483 690 2050 483 2050 483 -b c p 484 690 2050 484 2050 484 -b c p 485 690 2050 485 2050 485 -b c p 480 660 2048 480 2048 480 -b c p 481 660 2048 481 2048 481 -b c p 482 660 2048 482 2048 482 -b c p 483 660 2048 483 2048 483 -b c p 484 660 2048 484 2048 484 -b c p 485 660 2048 485 2048 485 -b c p 480 679 2048 480 2048 480 -b c p 481 679 2048 481 2048 481 -b c p 482 679 2048 482 2048 482 -b c p 483 679 2048 483 2048 483 -b c p 484 679 2048 484 2048 484 -b c p 485 679 2048 485 2048 485 -b c p 480 690 2048 480 2048 480 -b c p 481 690 2048 481 2048 481 -b c p 482 690 2048 482 2048 482 -b c p 483 690 2048 483 2048 483 -b c p 484 690 2048 484 2048 484 -b c p 485 690 2048 485 2048 485 -b c p 480 656 1024 480 1024 480 -b c p 480 128 3 480 3 480 -b c p 1024 512 515 1024 515 1024 -b c p 1024 2048 1024 1024 1024 1024 -b c p 1024 2048 515 1024 515 1024 -b c p 1024 1040 515 1024 515 1024 -b c p 5 1029 515 5 515 5 -b c p 1024 1029 515 1024 515 1024 -b c p 1024 1040 2050 1024 2050 1024 -b c p 1029 1029 2050 1029 2050 1029 -b c p 485 656 2050 485 2050 485 -b c p 480 672 2050 480 2050 480 -b c p 481 672 2050 481 2050 481 -b c p 482 672 2050 482 2050 482 -b c p 483 672 2050 483 2050 483 -b c p 484 672 2050 484 2050 484 -b c p 485 672 2050 485 2050 485 -b c p 480 688 2050 480 2050 480 -b c p 481 
688 2050 481 2050 481 -b c p 1024 32 256 1024 256 1024 -b c P 1024 64 512 1024 512 1024 -b c P 64 800 320 64 320 64 -b c P 64 768 512 64 512 64 -b c P 16 256 512 16 512 16 -b c P 128 128 128 128 128 128 -b c P 256 512 256 256 256 256 -b c P 1024 1024 1024 1024 1024 1024 -b c P 480 640 1024 480 1024 480 -b c P 480 640 256 480 256 480 -b c P 8 64 32 8 32 8 -b c P 9 64 32 9 32 9 -b c P 10 128 64 10 64 10 -b c P 8 8 8 8 8 8 -b c P 12 12 12 12 12 12 -b c P 25 25 25 25 25 25 -b c P 25 25 20 25 20 25 -s r r 480 20 2050 2050 20 20 -s r r 481 20 2050 2050 20 20 -s r r 482 20 2050 2050 20 20 -s r p 483 20 2050 2050 20 20 -s r R 484 20 2050 2050 20 20 -s r R 485 20 2050 2050 20 20 -s r R 480 39 2050 2050 39 39 -s r R 481 39 2050 2050 39 39 -s r R 482 39 2050 2050 39 39 -s r R 483 39 2050 2050 39 39 -s r R 484 39 2050 2050 39 39 -s r p 485 39 2050 2050 39 39 -s r p 480 50 2050 2050 50 50 -s r p 481 50 2050 2050 50 50 -s r p 482 50 2050 2050 50 50 -s r p 483 50 2050 2050 50 50 -s r p 484 50 2050 2050 50 50 -s r p 485 50 2050 2050 50 50 -s r R 480 1108 2050 2050 1108 1108 -s r R 481 1108 2050 2050 1108 1108 -s r R 482 1108 2050 2050 1108 1108 -s r R 483 1108 2050 2050 1108 1108 -s r R 484 1108 2050 2050 1108 1108 -s r R 485 1108 2050 2050 1108 1108 -s r R 480 1127 2050 2050 1127 1127 -s r R 481 1127 2050 2050 1127 1127 -s r R 482 1127 2050 2050 1127 1127 -s r R 483 1127 2050 2050 1127 1127 -s r p 484 1127 2050 2050 1127 1127 -s r p 485 1127 2050 2050 1127 1127 -s r p 480 1138 2050 2050 1138 1138 -s r p 481 1138 2050 2050 1138 1138 -s r p 482 1138 2050 2050 1138 1138 -s r p 483 1138 2050 2050 1138 1138 -s r p 484 1138 2050 2050 1138 1138 -s r p 485 1138 2050 2050 1138 1138 -s r p 1 1 3 3 1 1 -s r p 1 9 3 3 9 9 -s r p 1 2048 3 3 2048 2048 -s r p 1 2048 5192 5192 2048 2048 -s r p 9 1 3 3 1 1 -s r p 576 1 3500 3500 1 1 -s r p 1 1 1 1 1 1 -s r p 102 1088 1024 1024 1088 1088 -s r p 102 2048 1024 1024 2048 2048 -s r p 485 656 1024 1024 656 656 -s r p 483 656 1024 1024 656 656 -s r p 81 128 3 3 128 128 -s r p 1022 512 515 515 512 512 -s r p 74 512 515 515 512 512 -s r p 253 2048 515 515 2048 2048 -s r p 8192 1040 515 515 1040 1040 -s r p 10 1029 515 515 1029 1029 -s r p 24 1040 2050 2050 1040 1040 -s r p 1024 1029 2050 2050 1029 1029 -s r p 480 660 2050 2050 660 660 -s r p 481 660 2050 2050 660 660 -s r p 482 660 2050 2050 660 660 -s r p 483 660 2050 2050 660 660 -s r p 484 660 2050 2050 660 660 -s r p 485 660 2050 2050 660 660 -s r p 480 679 2050 2050 679 679 -s r p 481 679 2050 2050 679 679 -s r p 482 679 2050 2050 679 679 -s r p 483 679 2050 2050 679 679 -s r p 484 679 2050 2050 679 679 -s r p 485 679 2050 2050 679 679 -s r p 480 690 2050 2050 690 690 -s r p 481 690 2050 2050 690 690 -s r p 482 690 2050 2050 690 690 -s r p 483 690 2050 2050 690 690 -s r p 484 690 2050 2050 690 690 -s r p 485 690 2050 2050 690 690 -s r p 480 660 2048 2048 660 660 -s r p 481 660 2048 2048 660 660 -s r p 482 660 2048 2048 660 660 -s r p 483 660 2048 2048 660 660 -s r p 484 660 2048 2048 660 660 -s r p 485 660 2048 2048 660 660 -s r p 480 679 2048 2048 679 679 -s r p 481 679 2048 2048 679 679 -s r p 482 679 2048 2048 679 679 -s r p 483 679 2048 2048 679 679 -s r p 484 679 2048 2048 679 679 -s r p 485 679 2048 2048 679 679 -s r p 480 690 2048 2048 690 690 -s r p 481 690 2048 2048 690 690 -s r p 482 690 2048 2048 690 690 -s r p 483 690 2048 2048 690 690 -s r p 484 690 2048 2048 690 690 -s r p 485 690 2048 2048 690 690 -s r p 480 656 1024 1024 656 656 -s r p 480 128 3 3 128 128 -s r p 1024 512 515 515 512 512 -s r p 1024 2048 1024 
1024 2048 2048 -s r p 1024 2048 515 515 2048 2048 -s r p 1024 1040 515 515 1040 1040 -s r p 5 1029 515 515 1029 1029 -s r p 1024 1029 515 515 1029 1029 -s r p 1024 1040 2050 2050 1040 1040 -s r p 1029 1029 2050 2050 1029 1029 -s r R 480 646 2050 2050 646 646 -s r R 481 646 2050 2050 646 646 -s r R 482 646 2050 2050 646 646 -s r R 483 646 2050 2050 646 646 -s r R 484 646 2050 2050 646 646 -s r R 485 646 2050 2050 646 646 -s r R 481 656 2050 2050 656 656 -s r R 482 656 2050 2050 656 656 -s r R 483 656 2050 2050 656 656 -s r R 484 656 2050 2050 656 656 -s r p 485 656 2050 2050 656 656 -s r p 480 672 2050 2050 672 672 -s r p 481 672 2050 2050 672 672 -s r p 482 672 2050 2050 672 672 -s r p 483 672 2050 2050 672 672 -s r p 484 672 2050 2050 672 672 -s r p 485 672 2050 2050 672 672 -s r p 480 688 2050 2050 688 688 -s r p 481 688 2050 2050 688 688 -s r r 482 688 2050 2050 688 688 -s r r 483 688 2050 2050 688 688 -s r r 484 688 2050 2050 688 688 -s r r 485 688 2050 2050 688 688 -s r r 1024 512 64 64 512 512 -s r r 16 256 512 512 256 256 -s r r 480 640 512 512 640 640 -s r r 64 768 512 512 768 768 -s r r 128 128 128 128 128 128 -s r r 1024 64 512 512 64 64 -s r r 1024 256 32 32 256 256 -s r r 1024 512 64 64 512 512 -s r r 480 640 512 512 640 640 -s r p 1024 32 256 256 32 32 -s r P 1024 64 512 512 64 64 -s r P 64 800 320 320 800 800 -s r P 64 768 512 512 768 768 -s r P 16 256 512 512 256 256 -s r P 128 128 128 128 128 128 -s r P 256 512 256 256 512 512 -s r P 1024 1024 1024 1024 1024 1024 -s r P 480 640 1024 1024 640 640 -s r P 480 640 256 256 640 640 -s r P 8 64 32 32 64 64 -s r P 9 64 32 32 64 64 -s r P 10 128 64 64 128 128 -s r P 8 8 8 8 8 8 -s r P 12 12 12 12 12 12 -s r P 25 25 25 25 25 25 -s r P 25 25 20 20 25 25 -i r p 480 20 2050 2050 20 20 -i r p 481 20 2050 2050 20 20 -i r p 482 20 2050 2050 20 20 -i r p 483 20 2050 2050 20 20 -i r R 484 20 2050 2050 20 20 -i r R 485 20 2050 2050 20 20 -i r R 480 39 2050 2050 39 39 -i r R 481 39 2050 2050 39 39 -i r R 482 39 2050 2050 39 39 -i r R 483 39 2050 2050 39 39 -i r R 484 39 2050 2050 39 39 -i r p 485 39 2050 2050 39 39 -i r p 480 50 2050 2050 50 50 -i r p 481 50 2050 2050 50 50 -i r p 482 50 2050 2050 50 50 -i r p 483 50 2050 2050 50 50 -i r p 484 50 2050 2050 50 50 -i r p 485 50 2050 2050 50 50 -i r R 480 1108 2050 2050 1108 1108 -i r R 481 1108 2050 2050 1108 1108 -i r R 482 1108 2050 2050 1108 1108 -i r R 483 1108 2050 2050 1108 1108 -i r R 484 1108 2050 2050 1108 1108 -i r R 485 1108 2050 2050 1108 1108 -i r R 480 1127 2050 2050 1127 1127 -i r R 481 1127 2050 2050 1127 1127 -i r R 482 1127 2050 2050 1127 1127 -i r R 483 1127 2050 2050 1127 1127 -i r p 484 1127 2050 2050 1127 1127 -i r p 485 1127 2050 2050 1127 1127 -i r p 480 1138 2050 2050 1138 1138 -i r p 481 1138 2050 2050 1138 1138 -i r p 482 1138 2050 2050 1138 1138 -i r p 483 1138 2050 2050 1138 1138 -i r p 484 1138 2050 2050 1138 1138 -i r p 485 1138 2050 2050 1138 1138 -i r p 1 1 3 3 1 1 -i r p 1 9 3 3 9 9 -i r p 1 2048 3 3 2048 2048 -i r p 1 2048 5192 5192 2048 2048 -i r p 9 1 3 3 1 1 -i r p 576 1 3500 3500 1 1 -i r p 1 1 1 1 1 1 -i r p 102 1088 1024 1024 1088 1088 -i r p 102 2048 1024 1024 2048 2048 -i r p 485 656 1024 1024 656 656 -i r p 483 656 1024 1024 656 656 -i r p 81 128 3 3 128 128 -i r p 1022 512 515 515 512 512 -i r p 74 512 515 515 512 512 -i r p 253 2048 515 515 2048 2048 -i r p 8192 1040 515 515 1040 1040 -i r p 10 1029 515 515 1029 1029 -i r p 24 1040 2050 2050 1040 1040 -i r p 1024 1029 2050 2050 1029 1029 -i r p 480 660 2050 2050 660 660 -i r p 481 660 2050 2050 660 
660 -i r p 482 660 2050 2050 660 660 -i r p 483 660 2050 2050 660 660 -i r p 484 660 2050 2050 660 660 -i r p 485 660 2050 2050 660 660 -i r p 480 679 2050 2050 679 679 -i r p 481 679 2050 2050 679 679 -i r p 482 679 2050 2050 679 679 -i r p 483 679 2050 2050 679 679 -i r p 484 679 2050 2050 679 679 -i r p 485 679 2050 2050 679 679 -i r p 480 690 2050 2050 690 690 -i r p 481 690 2050 2050 690 690 -i r p 482 690 2050 2050 690 690 -i r p 483 690 2050 2050 690 690 -i r p 484 690 2050 2050 690 690 -i r p 485 690 2050 2050 690 690 -i r p 480 660 2048 2048 660 660 -i r p 481 660 2048 2048 660 660 -i r p 482 660 2048 2048 660 660 -i r p 483 660 2048 2048 660 660 -i r p 484 660 2048 2048 660 660 -i r p 485 660 2048 2048 660 660 -i r p 480 679 2048 2048 679 679 -i r p 481 679 2048 2048 679 679 -i r p 482 679 2048 2048 679 679 -i r p 483 679 2048 2048 679 679 -i r p 484 679 2048 2048 679 679 -i r p 485 679 2048 2048 679 679 -i r p 480 690 2048 2048 690 690 -i r p 481 690 2048 2048 690 690 -i r p 482 690 2048 2048 690 690 -i r p 483 690 2048 2048 690 690 -i r p 484 690 2048 2048 690 690 -i r p 485 690 2048 2048 690 690 -i r p 480 656 1024 1024 656 656 -i r p 480 128 3 3 128 128 -i r p 1024 512 515 515 512 512 -i r p 1024 2048 1024 1024 2048 2048 -i r p 1024 2048 515 515 2048 2048 -i r p 1024 1040 515 515 1040 1040 -i r p 5 1029 515 515 1029 1029 -i r p 1024 1029 515 515 1029 1029 -i r p 1024 1040 2050 2050 1040 1040 -i r p 1029 1029 2050 2050 1029 1029 -i r R 480 646 2050 2050 646 646 -i r R 481 646 2050 2050 646 646 -i r R 482 646 2050 2050 646 646 -i r R 483 646 2050 2050 646 646 -i r R 484 646 2050 2050 646 646 -i r R 485 646 2050 2050 646 646 -i r R 481 656 2050 2050 656 656 -i r R 482 656 2050 2050 656 656 -i r R 483 656 2050 2050 656 656 -i r R 484 656 2050 2050 656 656 -i r p 485 656 2050 2050 656 656 -i r p 480 672 2050 2050 672 672 -i r p 481 672 2050 2050 672 672 -i r p 482 672 2050 2050 672 672 -i r p 483 672 2050 2050 672 672 -i r p 484 672 2050 2050 672 672 -i r p 485 672 2050 2050 672 672 -i r p 480 688 2050 2050 688 688 -i r p 481 688 2050 2050 688 688 -i r r 482 688 2050 2050 688 688 -i r r 483 688 2050 2050 688 688 -i r r 484 688 2050 2050 688 688 -i r r 485 688 2050 2050 688 688 -i r r 1024 512 64 64 512 512 -i r r 16 256 512 512 256 256 -i r r 480 640 512 512 640 640 -i r r 64 768 512 512 768 768 -i r r 128 128 128 128 128 128 -i r r 1024 64 512 512 64 64 -i r r 1024 256 32 32 256 256 -i r r 1024 512 64 64 512 512 -i r r 480 640 512 512 640 640 -i r p 1024 32 256 256 32 32 -i r P 1024 64 512 512 64 64 -i r P 64 800 320 320 800 800 -i r P 64 768 512 512 768 768 -i r P 16 256 512 512 256 256 -i r P 128 128 128 128 128 128 -i r P 256 512 256 256 512 512 -i r P 1024 1024 1024 1024 1024 1024 -i r P 480 640 1024 1024 640 640 -i r P 480 640 256 256 640 640 -i r P 8 64 32 32 64 64 -i r P 9 64 32 32 64 64 -i r P 10 128 64 64 128 128 -i r P 8 8 8 8 8 8 -i r P 12 12 12 12 12 12 -i r P 25 25 25 25 25 25 -i r P 25 25 20 20 25 25 -f r p 480 20 2050 2050 20 20 -f r p 481 20 2050 2050 20 20 -f r p 482 20 2050 2050 20 20 -f r p 483 20 2050 2050 20 20 -f r R 484 20 2050 2050 20 20 -f r R 485 20 2050 2050 20 20 -f r R 480 39 2050 2050 39 39 -f r R 481 39 2050 2050 39 39 -f r R 482 39 2050 2050 39 39 -f r R 483 39 2050 2050 39 39 -f r R 484 39 2050 2050 39 39 -f r p 485 39 2050 2050 39 39 -f r p 480 50 2050 2050 50 50 -f r p 481 50 2050 2050 50 50 -f r p 482 50 2050 2050 50 50 -f r p 483 50 2050 2050 50 50 -f r p 484 50 2050 2050 50 50 -f r p 485 50 2050 2050 50 50 -f r R 480 1108 2050 2050 1108 1108 
-f r R 481 1108 2050 2050 1108 1108 -f r R 482 1108 2050 2050 1108 1108 -f r R 483 1108 2050 2050 1108 1108 -f r R 484 1108 2050 2050 1108 1108 -f r R 485 1108 2050 2050 1108 1108 -f r R 480 1127 2050 2050 1127 1127 -f r R 481 1127 2050 2050 1127 1127 -f r R 482 1127 2050 2050 1127 1127 -f r R 483 1127 2050 2050 1127 1127 -f r p 484 1127 2050 2050 1127 1127 -f r p 485 1127 2050 2050 1127 1127 -f r p 480 1138 2050 2050 1138 1138 -f r p 481 1138 2050 2050 1138 1138 -f r p 482 1138 2050 2050 1138 1138 -f r p 483 1138 2050 2050 1138 1138 -f r p 484 1138 2050 2050 1138 1138 -f r p 485 1138 2050 2050 1138 1138 -f r p 1 1 3 3 1 1 -f r p 1 9 3 3 9 9 -f r p 1 2048 3 3 2048 2048 -f r p 1 2048 5192 5192 2048 2048 -f r p 9 1 3 3 1 1 -f r p 576 1 3500 3500 1 1 -f r p 1 1 1 1 1 1 -f r p 102 1088 1024 1024 1088 1088 -f r p 102 2048 1024 1024 2048 2048 -f r p 485 656 1024 1024 656 656 -f r p 483 656 1024 1024 656 656 -f r p 81 128 3 3 128 128 -f r p 1022 512 515 515 512 512 -f r p 74 512 515 515 512 512 -f r p 253 2048 515 515 2048 2048 -f r p 8192 1040 515 515 1040 1040 -f r p 10 1029 515 515 1029 1029 -f r p 24 1040 2050 2050 1040 1040 -f r p 1024 1029 2050 2050 1029 1029 -f r p 480 660 2050 2050 660 660 -f r p 481 660 2050 2050 660 660 -f r p 482 660 2050 2050 660 660 -f r p 483 660 2050 2050 660 660 -f r p 484 660 2050 2050 660 660 -f r p 485 660 2050 2050 660 660 -f r p 480 679 2050 2050 679 679 -f r p 481 679 2050 2050 679 679 -f r p 482 679 2050 2050 679 679 -f r p 483 679 2050 2050 679 679 -f r p 484 679 2050 2050 679 679 -f r p 485 679 2050 2050 679 679 -f r p 480 690 2050 2050 690 690 -f r p 481 690 2050 2050 690 690 -f r p 482 690 2050 2050 690 690 -f r p 483 690 2050 2050 690 690 -f r p 484 690 2050 2050 690 690 -f r p 485 690 2050 2050 690 690 -f r p 480 660 2048 2048 660 660 -f r p 481 660 2048 2048 660 660 -f r p 482 660 2048 2048 660 660 -f r p 483 660 2048 2048 660 660 -f r p 484 660 2048 2048 660 660 -f r p 485 660 2048 2048 660 660 -f r p 480 679 2048 2048 679 679 -f r p 481 679 2048 2048 679 679 -f r p 482 679 2048 2048 679 679 -f r p 483 679 2048 2048 679 679 -f r p 484 679 2048 2048 679 679 -f r p 485 679 2048 2048 679 679 -f r p 480 690 2048 2048 690 690 -f r p 481 690 2048 2048 690 690 -f r p 482 690 2048 2048 690 690 -f r p 483 690 2048 2048 690 690 -f r p 484 690 2048 2048 690 690 -f r p 485 690 2048 2048 690 690 -f r p 480 656 1024 1024 656 656 -f r p 480 128 3 3 128 128 -f r p 1024 512 515 515 512 512 -f r p 1024 2048 1024 1024 2048 2048 -f r p 1024 2048 515 515 2048 2048 -f r p 1024 1040 515 515 1040 1040 -f r p 5 1029 515 515 1029 1029 -f r p 1024 1029 515 515 1029 1029 -f r p 1024 1040 2050 2050 1040 1040 -f r p 1029 1029 2050 2050 1029 1029 -f r R 480 646 2050 2050 646 646 -f r R 481 646 2050 2050 646 646 -f r R 482 646 2050 2050 646 646 -f r R 483 646 2050 2050 646 646 -f r R 484 646 2050 2050 646 646 -f r R 485 646 2050 2050 646 646 -f r R 481 656 2050 2050 656 656 -f r R 482 656 2050 2050 656 656 -f r R 483 656 2050 2050 656 656 -f r R 484 656 2050 2050 656 656 -f r p 485 656 2050 2050 656 656 -f r p 480 672 2050 2050 672 672 -f r p 481 672 2050 2050 672 672 -f r p 482 672 2050 2050 672 672 -f r p 483 672 2050 2050 672 672 -f r p 484 672 2050 2050 672 672 -f r p 485 672 2050 2050 672 672 -f r p 480 688 2050 2050 688 688 -f r p 481 688 2050 2050 688 688 -f r r 482 688 2050 2050 688 688 -f r r 483 688 2050 2050 688 688 -f r r 484 688 2050 2050 688 688 -f r r 485 688 2050 2050 688 688 -f r r 1024 512 64 64 512 512 -f r r 16 256 512 512 256 256 -f r r 480 640 512 512 640 640 
-f r r 64 768 512 512 768 768 -f r r 128 128 128 128 128 128 -f r r 1024 64 512 512 64 64 -f r r 1024 256 32 32 256 256 -f r r 1024 512 64 64 512 512 -f r r 480 640 512 512 640 640 -f r p 1024 32 256 256 32 32 -f r P 1024 64 512 512 64 64 -f r P 64 800 320 320 800 800 -f r P 64 768 512 512 768 768 -f r P 16 256 512 512 256 256 -f r P 128 128 128 128 128 128 -f r P 256 512 256 256 512 512 -f r P 1024 1024 1024 1024 1024 1024 -f r P 480 640 1024 1024 640 640 -f r P 480 640 256 256 640 640 -f r P 8 64 32 32 64 64 -f r P 9 64 32 32 64 64 -f r P 10 128 64 64 128 128 -f r P 8 8 8 8 8 8 -f r P 12 12 12 12 12 12 -f r P 25 25 25 25 25 25 -f r P 25 25 20 20 25 25 -i r r 4096 256 5 5 256 256 -i r r 3000 256 128 128 256 256 -i r r 4096 1024 512 512 1024 1024 -i r r 144 256 5 5 256 256 -i r r 144 256 128 128 256 256 -i r r 144 1024 512 512 1024 1024 -i r r 480 688 256 256 688 688 -i r r 480 640 512 512 640 640 -i r r 480 640 1024 1024 640 640 -i r r 64 800 320 320 800 800 -i r r 64 768 512 512 768 768 -i r r 16 256 512 512 256 256 -i r r 128 128 128 128 128 128 -i r r 256 512 256 256 512 512 -i r r 1024 1024 1024 1024 1024 1024 -i r r 1024 32 256 256 32 32 -i r r 1024 64 512 512 64 64 -i r r 1024 256 32 32 256 256 -i r r 1024 512 64 64 512 512 -i r r 512 32 256 256 32 32 -i r r 512 768 512 512 768 768 -i r r 512 256 32 32 256 256 -i r r 512 512 64 64 512 512 -i r r 512 256 768 768 256 256 -i r r 768 768 1024 1024 768 768 -i r r 768 768 768 768 768 768 -i r r 2048 2048 2048 2048 2048 2048 -i r r 4096 4096 4096 4096 4096 4096 -f r r 4096 256 5 5 256 256 -f r r 3000 256 128 128 256 256 -f r r 4096 1024 512 512 1024 1024 -f r r 144 256 5 5 256 256 -f r r 144 256 128 128 256 256 -f r r 144 1024 512 512 1024 1024 -f r r 480 688 256 256 688 688 -f r r 480 640 512 512 640 640 -f r r 480 640 1024 1024 640 640 -f r r 64 800 320 320 800 800 -f r r 64 768 512 512 768 768 -f r r 16 256 512 512 256 256 -f r r 128 128 128 128 128 128 -f r r 256 512 256 256 512 512 -f r r 1024 1024 1024 1024 1024 1024 -f r r 1024 32 256 256 32 32 -f r r 1024 64 512 512 64 64 -f r r 1024 256 32 32 256 256 -f r r 1024 512 64 64 512 512 -f r r 512 32 256 256 32 32 -f r r 512 768 512 512 768 768 -f r r 512 256 32 32 256 256 -f r r 512 512 64 64 512 512 -f r r 512 256 768 768 256 256 -f r r 768 768 1024 1024 768 768 -f r r 768 768 768 768 768 768 -f r r 2048 2048 2048 2048 2048 2048 -f r r 4096 4096 4096 4096 4096 4096 -f r r 2048 1024 1024 1024 1024 1024 -f r r 2048 4096 1024 1024 4096 4096 -f r r 2048 1024 4096 4096 1024 1024 -f r r 2048 1024 2 2 1024 1024 -f r r 128 1024 1024 1024 1024 1024 -f r r 1536 768 768 768 768 768 -f r r 1536 3072 768 768 3072 3072 -f r r 1536 768 3072 3072 768 768 -f r r 1536 768 2 2 768 768 -f r r 128 768 768 768 768 768 -f r r 1024 8 13 13 8 8 -f r r 1024 4 8 8 4 4 -f r r 1024 128 355 355 128 128 -f r r 1024 64 128 128 64 64 -f r r 1024 1 64 64 1 1 -f r r 480 1 256 256 1 1 -f r r 480 256 512 512 256 256 -f r r 480 1024 845 845 1024 1024 -f r r 480 512 1024 1024 512 512 -f r r 10 17191 128 128 17191 17191 -f r r 10 512 256 256 512 512 +u s32 r n n n p 480 20 2050 2050 20 20 none +u s8 r n n n p 480 20 2050 2050 20 20 none +u s32 s8 r n n n p 480 20 2050 2050 20 20 bias,relu,clip +u s8 s8 r n n n p 480 20 2050 2050 20 20 bias,relu,clip +u s32 r n n n p 481 20 2050 2050 20 20 none +u s8 r n n n p 481 20 2050 2050 20 20 none +u s32 s8 r n n n p 481 20 2050 2050 20 20 bias,relu,clip +u s8 s8 r n n n p 481 20 2050 2050 20 20 bias,relu,clip +u s32 r n n n p 482 20 2050 2050 20 20 none +u s8 r n n n p 482 20 2050 
2050 20 20 none +u s32 s8 r n n n p 482 20 2050 2050 20 20 bias,relu,clip +u s8 s8 r n n n p 482 20 2050 2050 20 20 bias,relu,clip +u s32 r n n n p 483 20 2050 2050 20 20 none +u s8 r n n n p 483 20 2050 2050 20 20 none +u s32 s8 r n n n p 483 20 2050 2050 20 20 bias,relu,clip +u s8 s8 r n n n p 483 20 2050 2050 20 20 bias,relu,clip +u s32 r n n n R 484 20 2050 2050 20 20 none +u s8 r n n n R 484 20 2050 2050 20 20 none +u s32 s8 r n n n R 484 20 2050 2050 20 20 bias,relu,clip +u s8 s8 r n n n R 484 20 2050 2050 20 20 bias,relu,clip +u s32 r n n n R 485 20 2050 2050 20 20 none +u s8 r n n n R 485 20 2050 2050 20 20 none +u s32 s8 r n n n R 485 20 2050 2050 20 20 bias,relu,clip +u s8 s8 r n n n R 485 20 2050 2050 20 20 bias,relu,clip +u s32 r n n n R 480 39 2050 2050 39 39 none +u s8 r n n n R 480 39 2050 2050 39 39 none +u s32 s8 r n n n R 480 39 2050 2050 39 39 bias,relu,clip +u s8 s8 r n n n R 480 39 2050 2050 39 39 bias,relu,clip +u s32 r n n n R 481 39 2050 2050 39 39 none +u s8 r n n n R 481 39 2050 2050 39 39 none +u s32 s8 r n n n R 481 39 2050 2050 39 39 bias,relu,clip +u s8 s8 r n n n R 481 39 2050 2050 39 39 bias,relu,clip +u s32 r n n n R 482 39 2050 2050 39 39 none +u s8 r n n n R 482 39 2050 2050 39 39 none +u s32 s8 r n n n R 482 39 2050 2050 39 39 bias,relu,clip +u s8 s8 r n n n R 482 39 2050 2050 39 39 bias,relu,clip +u s32 r n n n R 483 39 2050 2050 39 39 none +u s8 r n n n R 483 39 2050 2050 39 39 none +u s32 s8 r n n n R 483 39 2050 2050 39 39 bias,relu,clip +u s8 s8 r n n n R 483 39 2050 2050 39 39 bias,relu,clip +u s32 r n n n R 484 39 2050 2050 39 39 none +u s8 r n n n R 484 39 2050 2050 39 39 none +u s32 s8 r n n n R 484 39 2050 2050 39 39 bias,relu,clip +u s8 s8 r n n n R 484 39 2050 2050 39 39 bias,relu,clip +u s32 r n n n p 485 39 2050 2050 39 39 none +u s8 r n n n p 485 39 2050 2050 39 39 none +u s32 s8 r n n n p 485 39 2050 2050 39 39 bias,relu,clip +u s8 s8 r n n n p 485 39 2050 2050 39 39 bias,relu,clip +u s32 r n n n p 480 50 2050 2050 50 50 none +u s8 r n n n p 480 50 2050 2050 50 50 none +u s32 s8 r n n n p 480 50 2050 2050 50 50 bias,relu,clip +u s8 s8 r n n n p 480 50 2050 2050 50 50 bias,relu,clip +u s32 r n n n p 481 50 2050 2050 50 50 none +u s8 r n n n p 481 50 2050 2050 50 50 none +u s32 s8 r n n n p 481 50 2050 2050 50 50 bias,relu,clip +u s8 s8 r n n n p 481 50 2050 2050 50 50 bias,relu,clip +u s32 r n n n p 482 50 2050 2050 50 50 none +u s8 r n n n p 482 50 2050 2050 50 50 none +u s32 s8 r n n n p 482 50 2050 2050 50 50 bias,relu,clip +u s8 s8 r n n n p 482 50 2050 2050 50 50 bias,relu,clip +u s32 r n n n p 483 50 2050 2050 50 50 none +u s8 r n n n p 483 50 2050 2050 50 50 none +u s32 s8 r n n n p 483 50 2050 2050 50 50 bias,relu,clip +u s8 s8 r n n n p 483 50 2050 2050 50 50 bias,relu,clip +u s32 r n n n p 484 50 2050 2050 50 50 none +u s8 r n n n p 484 50 2050 2050 50 50 none +u s32 s8 r n n n p 484 50 2050 2050 50 50 bias,relu,clip +u s8 s8 r n n n p 484 50 2050 2050 50 50 bias,relu,clip +u s32 r n n n p 485 50 2050 2050 50 50 none +u s8 r n n n p 485 50 2050 2050 50 50 none +u s32 s8 r n n n p 485 50 2050 2050 50 50 bias,relu,clip +u s8 s8 r n n n p 485 50 2050 2050 50 50 bias,relu,clip +u s32 r n n n R 480 1108 2050 2050 1108 1108 none +u s8 r n n n R 480 1108 2050 2050 1108 1108 none +u s32 s8 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip +u s8 s8 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip +u s32 r n n n R 481 1108 2050 2050 1108 1108 none +u s8 r n n n R 481 1108 2050 2050 1108 1108 none +u s32 s8 r n n n R 481 1108 2050 
2050 1108 1108 bias,relu,clip +u s8 s8 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip +u s32 r n n n R 482 1108 2050 2050 1108 1108 none +u s8 r n n n R 482 1108 2050 2050 1108 1108 none +u s32 s8 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip +u s8 s8 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip +u s32 r n n n R 483 1108 2050 2050 1108 1108 none +u s8 r n n n R 483 1108 2050 2050 1108 1108 none +u s32 s8 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip +u s8 s8 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip +u s32 r n n n R 484 1108 2050 2050 1108 1108 none +u s8 r n n n R 484 1108 2050 2050 1108 1108 none +u s32 s8 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip +u s8 s8 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip +u s32 r n n n R 485 1108 2050 2050 1108 1108 none +u s8 r n n n R 485 1108 2050 2050 1108 1108 none +u s32 s8 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip +u s8 s8 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip +u s32 r n n n R 480 1127 2050 2050 1127 1127 none +u s8 r n n n R 480 1127 2050 2050 1127 1127 none +u s32 s8 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip +u s8 s8 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip +u s32 r n n n R 481 1127 2050 2050 1127 1127 none +u s8 r n n n R 481 1127 2050 2050 1127 1127 none +u s32 s8 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip +u s8 s8 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip +u s32 r n n n R 482 1127 2050 2050 1127 1127 none +u s8 r n n n R 482 1127 2050 2050 1127 1127 none +u s32 s8 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip +u s8 s8 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip +u s32 r n n n R 483 1127 2050 2050 1127 1127 none +u s8 r n n n R 483 1127 2050 2050 1127 1127 none +u s32 s8 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip +u s8 s8 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip +u s32 r n n n p 484 1127 2050 2050 1127 1127 none +u s8 r n n n p 484 1127 2050 2050 1127 1127 none +u s32 s8 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip +u s8 s8 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip +u s32 r n n n p 485 1127 2050 2050 1127 1127 none +u s8 r n n n p 485 1127 2050 2050 1127 1127 none +u s32 s8 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip +u s8 s8 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip +u s32 r n n n p 480 1138 2050 2050 1138 1138 none +u s8 r n n n p 480 1138 2050 2050 1138 1138 none +u s32 s8 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip +u s8 s8 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip +u s32 r n n n p 481 1138 2050 2050 1138 1138 none +u s8 r n n n p 481 1138 2050 2050 1138 1138 none +u s32 s8 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip +u s8 s8 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip +u s32 r n n n p 482 1138 2050 2050 1138 1138 none +u s8 r n n n p 482 1138 2050 2050 1138 1138 none +u s32 s8 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip +u s8 s8 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip +u s32 r n n n p 483 1138 2050 2050 1138 1138 none +u s8 r n n n p 483 1138 2050 2050 1138 1138 none +u s32 s8 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip +u s8 s8 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip +u s32 r n n n p 484 1138 2050 2050 1138 1138 none +u s8 r n n n p 484 1138 2050 2050 1138 1138 none +u s32 s8 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip +u s8 s8 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip +u s32 r n n n p 485 1138 2050 2050 1138 
1138 none +u s8 r n n n p 485 1138 2050 2050 1138 1138 none +u s32 s8 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip +u s8 s8 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip +u s32 r n n n p 1 1 3 3 1 1 none +u s8 r n n n p 1 1 3 3 1 1 none +u s32 s8 r n n n p 1 1 3 3 1 1 bias,relu,clip +u s8 s8 r n n n p 1 1 3 3 1 1 bias,relu,clip +u s32 r n n n p 1 9 3 3 9 9 none +u s8 r n n n p 1 9 3 3 9 9 none +u s32 s8 r n n n p 1 9 3 3 9 9 bias,relu,clip +u s8 s8 r n n n p 1 9 3 3 9 9 bias,relu,clip +u s32 r n n n p 1 2048 3 3 2048 2048 none +u s8 r n n n p 1 2048 3 3 2048 2048 none +u s32 s8 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip +u s8 s8 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip +u s32 r n n n p 1 2048 5192 5192 2048 2048 none +u s8 r n n n p 1 2048 5192 5192 2048 2048 none +u s32 s8 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip +u s8 s8 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip +u s32 r n n n p 9 1 3 3 1 1 none +u s8 r n n n p 9 1 3 3 1 1 none +u s32 s8 r n n n p 9 1 3 3 1 1 bias,relu,clip +u s8 s8 r n n n p 9 1 3 3 1 1 bias,relu,clip +u s32 r n n n p 576 1 3500 3500 1 1 none +u s8 r n n n p 576 1 3500 3500 1 1 none +u s32 s8 r n n n p 576 1 3500 3500 1 1 bias,relu,clip +u s8 s8 r n n n p 576 1 3500 3500 1 1 bias,relu,clip +u s32 r n n n p 1 1 1 1 1 1 none +u s8 r n n n p 1 1 1 1 1 1 none +u s32 s8 r n n n p 1 1 1 1 1 1 bias,relu,clip +u s8 s8 r n n n p 1 1 1 1 1 1 bias,relu,clip +u s32 r n n n p 102 1088 1024 1024 1088 1088 none +u s8 r n n n p 102 1088 1024 1024 1088 1088 none +u s32 s8 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip +u s8 s8 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip +u s32 r n n n p 102 2048 1024 1024 2048 2048 none +u s8 r n n n p 102 2048 1024 1024 2048 2048 none +u s32 s8 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip +u s8 s8 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip +u s32 r n n n p 485 656 1024 1024 656 656 none +u s8 r n n n p 485 656 1024 1024 656 656 none +u s32 s8 r n n n p 485 656 1024 1024 656 656 bias,relu,clip +u s8 s8 r n n n p 485 656 1024 1024 656 656 bias,relu,clip +u s32 r n n n p 483 656 1024 1024 656 656 none +u s8 r n n n p 483 656 1024 1024 656 656 none +u s32 s8 r n n n p 483 656 1024 1024 656 656 bias,relu,clip +u s8 s8 r n n n p 483 656 1024 1024 656 656 bias,relu,clip +u s32 r n n n p 81 128 3 3 128 128 none +u s8 r n n n p 81 128 3 3 128 128 none +u s32 s8 r n n n p 81 128 3 3 128 128 bias,relu,clip +u s8 s8 r n n n p 81 128 3 3 128 128 bias,relu,clip +u s32 r n n n p 1022 512 515 515 512 512 none +u s8 r n n n p 1022 512 515 515 512 512 none +u s32 s8 r n n n p 1022 512 515 515 512 512 bias,relu,clip +u s8 s8 r n n n p 1022 512 515 515 512 512 bias,relu,clip +u s32 r n n n p 74 512 515 515 512 512 none +u s8 r n n n p 74 512 515 515 512 512 none +u s32 s8 r n n n p 74 512 515 515 512 512 bias,relu,clip +u s8 s8 r n n n p 74 512 515 515 512 512 bias,relu,clip +u s32 r n n n p 253 2048 515 515 2048 2048 none +u s8 r n n n p 253 2048 515 515 2048 2048 none +u s32 s8 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip +u s8 s8 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip +u s32 r n n n p 8192 1040 515 515 1040 1040 none +u s8 r n n n p 8192 1040 515 515 1040 1040 none +u s32 s8 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip +u s8 s8 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip +u s32 r n n n p 10 1029 515 515 1029 1029 none +u s8 r n n n p 10 1029 515 515 1029 1029 none +u s32 s8 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip +u s8 s8 r n n n p 10 1029 
515 515 1029 1029 bias,relu,clip +u s32 r n n n p 24 1040 2050 2050 1040 1040 none +u s8 r n n n p 24 1040 2050 2050 1040 1040 none +u s32 s8 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip +u s8 s8 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip +u s32 r n n n p 1024 1029 2050 2050 1029 1029 none +u s8 r n n n p 1024 1029 2050 2050 1029 1029 none +u s32 s8 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip +u s8 s8 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip +u s32 r n n n p 480 660 2050 2050 660 660 none +u s8 r n n n p 480 660 2050 2050 660 660 none +u s32 s8 r n n n p 480 660 2050 2050 660 660 bias,relu,clip +u s8 s8 r n n n p 480 660 2050 2050 660 660 bias,relu,clip +u s32 r n n n p 481 660 2050 2050 660 660 none +u s8 r n n n p 481 660 2050 2050 660 660 none +u s32 s8 r n n n p 481 660 2050 2050 660 660 bias,relu,clip +u s8 s8 r n n n p 481 660 2050 2050 660 660 bias,relu,clip +u s32 r n n n p 482 660 2050 2050 660 660 none +u s8 r n n n p 482 660 2050 2050 660 660 none +u s32 s8 r n n n p 482 660 2050 2050 660 660 bias,relu,clip +u s8 s8 r n n n p 482 660 2050 2050 660 660 bias,relu,clip +u s32 r n n n p 483 660 2050 2050 660 660 none +u s8 r n n n p 483 660 2050 2050 660 660 none +u s32 s8 r n n n p 483 660 2050 2050 660 660 bias,relu,clip +u s8 s8 r n n n p 483 660 2050 2050 660 660 bias,relu,clip +u s32 r n n n p 484 660 2050 2050 660 660 none +u s8 r n n n p 484 660 2050 2050 660 660 none +u s32 s8 r n n n p 484 660 2050 2050 660 660 bias,relu,clip +u s8 s8 r n n n p 484 660 2050 2050 660 660 bias,relu,clip +u s32 r n n n p 485 660 2050 2050 660 660 none +u s8 r n n n p 485 660 2050 2050 660 660 none +u s32 s8 r n n n p 485 660 2050 2050 660 660 bias,relu,clip +u s8 s8 r n n n p 485 660 2050 2050 660 660 bias,relu,clip +u s32 r n n n p 480 679 2050 2050 679 679 none +u s8 r n n n p 480 679 2050 2050 679 679 none +u s32 s8 r n n n p 480 679 2050 2050 679 679 bias,relu,clip +u s8 s8 r n n n p 480 679 2050 2050 679 679 bias,relu,clip +u s32 r n n n p 481 679 2050 2050 679 679 none +u s8 r n n n p 481 679 2050 2050 679 679 none +u s32 s8 r n n n p 481 679 2050 2050 679 679 bias,relu,clip +u s8 s8 r n n n p 481 679 2050 2050 679 679 bias,relu,clip +u s32 r n n n p 482 679 2050 2050 679 679 none +u s8 r n n n p 482 679 2050 2050 679 679 none +u s32 s8 r n n n p 482 679 2050 2050 679 679 bias,relu,clip +u s8 s8 r n n n p 482 679 2050 2050 679 679 bias,relu,clip +u s32 r n n n p 483 679 2050 2050 679 679 none +u s8 r n n n p 483 679 2050 2050 679 679 none +u s32 s8 r n n n p 483 679 2050 2050 679 679 bias,relu,clip +u s8 s8 r n n n p 483 679 2050 2050 679 679 bias,relu,clip +u s32 r n n n p 484 679 2050 2050 679 679 none +u s8 r n n n p 484 679 2050 2050 679 679 none +u s32 s8 r n n n p 484 679 2050 2050 679 679 bias,relu,clip +u s8 s8 r n n n p 484 679 2050 2050 679 679 bias,relu,clip +u s32 r n n n p 485 679 2050 2050 679 679 none +u s8 r n n n p 485 679 2050 2050 679 679 none +u s32 s8 r n n n p 485 679 2050 2050 679 679 bias,relu,clip +u s8 s8 r n n n p 485 679 2050 2050 679 679 bias,relu,clip +u s32 r n n n p 480 690 2050 2050 690 690 none +u s8 r n n n p 480 690 2050 2050 690 690 none +u s32 s8 r n n n p 480 690 2050 2050 690 690 bias,relu,clip +u s8 s8 r n n n p 480 690 2050 2050 690 690 bias,relu,clip +u s32 r n n n p 481 690 2050 2050 690 690 none +u s8 r n n n p 481 690 2050 2050 690 690 none +u s32 s8 r n n n p 481 690 2050 2050 690 690 bias,relu,clip +u s8 s8 r n n n p 481 690 2050 2050 690 690 bias,relu,clip +u s32 r n n n p 482 690 2050 2050 690 
690 none +u s8 r n n n p 482 690 2050 2050 690 690 none +u s32 s8 r n n n p 482 690 2050 2050 690 690 bias,relu,clip +u s8 s8 r n n n p 482 690 2050 2050 690 690 bias,relu,clip +u s32 r n n n p 483 690 2050 2050 690 690 none +u s8 r n n n p 483 690 2050 2050 690 690 none +u s32 s8 r n n n p 483 690 2050 2050 690 690 bias,relu,clip +u s8 s8 r n n n p 483 690 2050 2050 690 690 bias,relu,clip +u s32 r n n n p 484 690 2050 2050 690 690 none +u s8 r n n n p 484 690 2050 2050 690 690 none +u s32 s8 r n n n p 484 690 2050 2050 690 690 bias,relu,clip +u s8 s8 r n n n p 484 690 2050 2050 690 690 bias,relu,clip +u s32 r n n n p 485 690 2050 2050 690 690 none +u s8 r n n n p 485 690 2050 2050 690 690 none +u s32 s8 r n n n p 485 690 2050 2050 690 690 bias,relu,clip +u s8 s8 r n n n p 485 690 2050 2050 690 690 bias,relu,clip +u s32 r n n n p 480 660 2048 2048 660 660 none +u s8 r n n n p 480 660 2048 2048 660 660 none +u s32 s8 r n n n p 480 660 2048 2048 660 660 bias,relu,clip +u s8 s8 r n n n p 480 660 2048 2048 660 660 bias,relu,clip +u s32 r n n n p 481 660 2048 2048 660 660 none +u s8 r n n n p 481 660 2048 2048 660 660 none +u s32 s8 r n n n p 481 660 2048 2048 660 660 bias,relu,clip +u s8 s8 r n n n p 481 660 2048 2048 660 660 bias,relu,clip +u s32 r n n n p 482 660 2048 2048 660 660 none +u s8 r n n n p 482 660 2048 2048 660 660 none +u s32 s8 r n n n p 482 660 2048 2048 660 660 bias,relu,clip +u s8 s8 r n n n p 482 660 2048 2048 660 660 bias,relu,clip +u s32 r n n n p 483 660 2048 2048 660 660 none +u s8 r n n n p 483 660 2048 2048 660 660 none +u s32 s8 r n n n p 483 660 2048 2048 660 660 bias,relu,clip +u s8 s8 r n n n p 483 660 2048 2048 660 660 bias,relu,clip +u s32 r n n n p 484 660 2048 2048 660 660 none +u s8 r n n n p 484 660 2048 2048 660 660 none +u s32 s8 r n n n p 484 660 2048 2048 660 660 bias,relu,clip +u s8 s8 r n n n p 484 660 2048 2048 660 660 bias,relu,clip +u s32 r n n n p 485 660 2048 2048 660 660 none +u s8 r n n n p 485 660 2048 2048 660 660 none +u s32 s8 r n n n p 485 660 2048 2048 660 660 bias,relu,clip +u s8 s8 r n n n p 485 660 2048 2048 660 660 bias,relu,clip +u s32 r n n n p 480 679 2048 2048 679 679 none +u s8 r n n n p 480 679 2048 2048 679 679 none +u s32 s8 r n n n p 480 679 2048 2048 679 679 bias,relu,clip +u s8 s8 r n n n p 480 679 2048 2048 679 679 bias,relu,clip +u s32 r n n n p 481 679 2048 2048 679 679 none +u s8 r n n n p 481 679 2048 2048 679 679 none +u s32 s8 r n n n p 481 679 2048 2048 679 679 bias,relu,clip +u s8 s8 r n n n p 481 679 2048 2048 679 679 bias,relu,clip +u s32 r n n n p 482 679 2048 2048 679 679 none +u s8 r n n n p 482 679 2048 2048 679 679 none +u s32 s8 r n n n p 482 679 2048 2048 679 679 bias,relu,clip +u s8 s8 r n n n p 482 679 2048 2048 679 679 bias,relu,clip +u s32 r n n n p 483 679 2048 2048 679 679 none +u s8 r n n n p 483 679 2048 2048 679 679 none +u s32 s8 r n n n p 483 679 2048 2048 679 679 bias,relu,clip +u s8 s8 r n n n p 483 679 2048 2048 679 679 bias,relu,clip +u s32 r n n n p 484 679 2048 2048 679 679 none +u s8 r n n n p 484 679 2048 2048 679 679 none +u s32 s8 r n n n p 484 679 2048 2048 679 679 bias,relu,clip +u s8 s8 r n n n p 484 679 2048 2048 679 679 bias,relu,clip +u s32 r n n n p 485 679 2048 2048 679 679 none +u s8 r n n n p 485 679 2048 2048 679 679 none +u s32 s8 r n n n p 485 679 2048 2048 679 679 bias,relu,clip +u s8 s8 r n n n p 485 679 2048 2048 679 679 bias,relu,clip +u s32 r n n n p 480 690 2048 2048 690 690 none +u s8 r n n n p 480 690 2048 2048 690 690 none +u s32 s8 r n n n p 480 690 2048 2048 690 
690 bias,relu,clip +u s8 s8 r n n n p 480 690 2048 2048 690 690 bias,relu,clip +u s32 r n n n p 481 690 2048 2048 690 690 none +u s8 r n n n p 481 690 2048 2048 690 690 none +u s32 s8 r n n n p 481 690 2048 2048 690 690 bias,relu,clip +u s8 s8 r n n n p 481 690 2048 2048 690 690 bias,relu,clip +u s32 r n n n p 482 690 2048 2048 690 690 none +u s8 r n n n p 482 690 2048 2048 690 690 none +u s32 s8 r n n n p 482 690 2048 2048 690 690 bias,relu,clip +u s8 s8 r n n n p 482 690 2048 2048 690 690 bias,relu,clip +u s32 r n n n p 483 690 2048 2048 690 690 none +u s8 r n n n p 483 690 2048 2048 690 690 none +u s32 s8 r n n n p 483 690 2048 2048 690 690 bias,relu,clip +u s8 s8 r n n n p 483 690 2048 2048 690 690 bias,relu,clip +u s32 r n n n p 484 690 2048 2048 690 690 none +u s8 r n n n p 484 690 2048 2048 690 690 none +u s32 s8 r n n n p 484 690 2048 2048 690 690 bias,relu,clip +u s8 s8 r n n n p 484 690 2048 2048 690 690 bias,relu,clip +u s32 r n n n p 485 690 2048 2048 690 690 none +u s8 r n n n p 485 690 2048 2048 690 690 none +u s32 s8 r n n n p 485 690 2048 2048 690 690 bias,relu,clip +u s8 s8 r n n n p 485 690 2048 2048 690 690 bias,relu,clip +u s32 r n n n p 480 656 1024 1024 656 656 none +u s8 r n n n p 480 656 1024 1024 656 656 none +u s32 s8 r n n n p 480 656 1024 1024 656 656 bias,relu,clip +u s8 s8 r n n n p 480 656 1024 1024 656 656 bias,relu,clip +u s32 r n n n p 480 128 3 3 128 128 none +u s8 r n n n p 480 128 3 3 128 128 none +u s32 s8 r n n n p 480 128 3 3 128 128 bias,relu,clip +u s8 s8 r n n n p 480 128 3 3 128 128 bias,relu,clip +u s32 r n n n p 1024 512 515 515 512 512 none +u s8 r n n n p 1024 512 515 515 512 512 none +u s32 s8 r n n n p 1024 512 515 515 512 512 bias,relu,clip +u s8 s8 r n n n p 1024 512 515 515 512 512 bias,relu,clip +u s32 r n n n p 1024 2048 1024 1024 2048 2048 none +u s8 r n n n p 1024 2048 1024 1024 2048 2048 none +u s32 s8 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip +u s8 s8 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip +u s32 r n n n p 1024 2048 515 515 2048 2048 none +u s8 r n n n p 1024 2048 515 515 2048 2048 none +u s32 s8 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip +u s8 s8 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip +u s32 r n n n p 1024 1040 515 515 1040 1040 none +u s8 r n n n p 1024 1040 515 515 1040 1040 none +u s32 s8 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip +u s8 s8 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip +u s32 r n n n p 5 1029 515 515 1029 1029 none +u s8 r n n n p 5 1029 515 515 1029 1029 none +u s32 s8 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip +u s8 s8 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip +u s32 r n n n p 1024 1029 515 515 1029 1029 none +u s8 r n n n p 1024 1029 515 515 1029 1029 none +u s32 s8 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip +u s8 s8 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip +u s32 r n n n p 1024 1040 2050 2050 1040 1040 none +u s8 r n n n p 1024 1040 2050 2050 1040 1040 none +u s32 s8 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip +u s8 s8 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip +u s32 r n n n p 1029 1029 2050 2050 1029 1029 none +u s8 r n n n p 1029 1029 2050 2050 1029 1029 none +u s32 s8 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip +u s8 s8 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip +u s32 r n n n R 480 646 2050 2050 646 646 none +u s8 r n n n R 480 646 2050 2050 646 646 none +u s32 s8 r n n n R 480 646 2050 2050 646 646 bias,relu,clip +u s8 s8 r n n n R 480 646 2050 2050 
646 646 bias,relu,clip +u s32 r n n n R 481 646 2050 2050 646 646 none +u s8 r n n n R 481 646 2050 2050 646 646 none +u s32 s8 r n n n R 481 646 2050 2050 646 646 bias,relu,clip +u s8 s8 r n n n R 481 646 2050 2050 646 646 bias,relu,clip +u s32 r n n n R 482 646 2050 2050 646 646 none +u s8 r n n n R 482 646 2050 2050 646 646 none +u s32 s8 r n n n R 482 646 2050 2050 646 646 bias,relu,clip +u s8 s8 r n n n R 482 646 2050 2050 646 646 bias,relu,clip +u s32 r n n n R 483 646 2050 2050 646 646 none +u s8 r n n n R 483 646 2050 2050 646 646 none +u s32 s8 r n n n R 483 646 2050 2050 646 646 bias,relu,clip +u s8 s8 r n n n R 483 646 2050 2050 646 646 bias,relu,clip +u s32 r n n n R 484 646 2050 2050 646 646 none +u s8 r n n n R 484 646 2050 2050 646 646 none +u s32 s8 r n n n R 484 646 2050 2050 646 646 bias,relu,clip +u s8 s8 r n n n R 484 646 2050 2050 646 646 bias,relu,clip +u s32 r n n n R 485 646 2050 2050 646 646 none +u s8 r n n n R 485 646 2050 2050 646 646 none +u s32 s8 r n n n R 485 646 2050 2050 646 646 bias,relu,clip +u s8 s8 r n n n R 485 646 2050 2050 646 646 bias,relu,clip +u s32 r n n n R 481 656 2050 2050 656 656 none +u s8 r n n n R 481 656 2050 2050 656 656 none +u s32 s8 r n n n R 481 656 2050 2050 656 656 bias,relu,clip +u s8 s8 r n n n R 481 656 2050 2050 656 656 bias,relu,clip +u s32 r n n n R 482 656 2050 2050 656 656 none +u s8 r n n n R 482 656 2050 2050 656 656 none +u s32 s8 r n n n R 482 656 2050 2050 656 656 bias,relu,clip +u s8 s8 r n n n R 482 656 2050 2050 656 656 bias,relu,clip +u s32 r n n n R 483 656 2050 2050 656 656 none +u s8 r n n n R 483 656 2050 2050 656 656 none +u s32 s8 r n n n R 483 656 2050 2050 656 656 bias,relu,clip +u s8 s8 r n n n R 483 656 2050 2050 656 656 bias,relu,clip +u s32 r n n n R 484 656 2050 2050 656 656 none +u s8 r n n n R 484 656 2050 2050 656 656 none +u s32 s8 r n n n R 484 656 2050 2050 656 656 bias,relu,clip +u s8 s8 r n n n R 484 656 2050 2050 656 656 bias,relu,clip +u s32 r n n n p 485 656 2050 2050 656 656 none +u s8 r n n n p 485 656 2050 2050 656 656 none +u s32 s8 r n n n p 485 656 2050 2050 656 656 bias,relu,clip +u s8 s8 r n n n p 485 656 2050 2050 656 656 bias,relu,clip +u s32 r n n n p 480 672 2050 2050 672 672 none +u s8 r n n n p 480 672 2050 2050 672 672 none +u s32 s8 r n n n p 480 672 2050 2050 672 672 bias,relu,clip +u s8 s8 r n n n p 480 672 2050 2050 672 672 bias,relu,clip +u s32 r n n n p 481 672 2050 2050 672 672 none +u s8 r n n n p 481 672 2050 2050 672 672 none +u s32 s8 r n n n p 481 672 2050 2050 672 672 bias,relu,clip +u s8 s8 r n n n p 481 672 2050 2050 672 672 bias,relu,clip +u s32 r n n n p 482 672 2050 2050 672 672 none +u s8 r n n n p 482 672 2050 2050 672 672 none +u s32 s8 r n n n p 482 672 2050 2050 672 672 bias,relu,clip +u s8 s8 r n n n p 482 672 2050 2050 672 672 bias,relu,clip +u s32 r n n n p 483 672 2050 2050 672 672 none +u s8 r n n n p 483 672 2050 2050 672 672 none +u s32 s8 r n n n p 483 672 2050 2050 672 672 bias,relu,clip +u s8 s8 r n n n p 483 672 2050 2050 672 672 bias,relu,clip +u s32 r n n n p 484 672 2050 2050 672 672 none +u s8 r n n n p 484 672 2050 2050 672 672 none +u s32 s8 r n n n p 484 672 2050 2050 672 672 bias,relu,clip +u s8 s8 r n n n p 484 672 2050 2050 672 672 bias,relu,clip +u s32 r n n n p 485 672 2050 2050 672 672 none +u s8 r n n n p 485 672 2050 2050 672 672 none +u s32 s8 r n n n p 485 672 2050 2050 672 672 bias,relu,clip +u s8 s8 r n n n p 485 672 2050 2050 672 672 bias,relu,clip +u s32 r n n n p 480 688 2050 2050 688 688 none +u s8 r n n n p 480 688 
2050 2050 688 688 none +u s32 s8 r n n n p 480 688 2050 2050 688 688 bias,relu,clip +u s8 s8 r n n n p 480 688 2050 2050 688 688 bias,relu,clip +u s32 r n n n p 481 688 2050 2050 688 688 none +u s8 r n n n p 481 688 2050 2050 688 688 none +u s32 s8 r n n n p 481 688 2050 2050 688 688 bias,relu,clip +u s8 s8 r n n n p 481 688 2050 2050 688 688 bias,relu,clip +u s32 r n n n r 482 688 2050 2050 688 688 none +u s8 r n n n r 482 688 2050 2050 688 688 none +u s32 s8 r n n n r 482 688 2050 2050 688 688 bias,relu,clip +u s8 s8 r n n n r 482 688 2050 2050 688 688 bias,relu,clip +u s32 r n n n r 483 688 2050 2050 688 688 none +u s8 r n n n r 483 688 2050 2050 688 688 none +u s32 s8 r n n n r 483 688 2050 2050 688 688 bias,relu,clip +u s8 s8 r n n n r 483 688 2050 2050 688 688 bias,relu,clip +u s32 r n n n r 484 688 2050 2050 688 688 none +u s8 r n n n r 484 688 2050 2050 688 688 none +u s32 s8 r n n n r 484 688 2050 2050 688 688 bias,relu,clip +u s8 s8 r n n n r 484 688 2050 2050 688 688 bias,relu,clip +u s32 r n n n r 485 688 2050 2050 688 688 none +u s8 r n n n r 485 688 2050 2050 688 688 none +u s32 s8 r n n n r 485 688 2050 2050 688 688 bias,relu,clip +u s8 s8 r n n n r 485 688 2050 2050 688 688 bias,relu,clip +u s32 r n n n r 1024 512 64 64 512 512 none +u s8 r n n n r 1024 512 64 64 512 512 none +u s32 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +u s8 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +u s32 r n n n r 16 256 512 512 256 256 none +u s8 r n n n r 16 256 512 512 256 256 none +u s32 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip +u s8 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip +u s32 r n n n r 480 640 512 512 640 640 none +u s8 r n n n r 480 640 512 512 640 640 none +u s32 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip +u s8 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip +u s32 r n n n r 64 768 512 512 768 768 none +u s8 r n n n r 64 768 512 512 768 768 none +u s32 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip +u s8 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip +u s32 r n n n r 128 128 128 128 128 128 none +u s8 r n n n r 128 128 128 128 128 128 none +u s32 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip +u s8 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip +u s32 r n n n r 1024 64 512 512 64 64 none +u s8 r n n n r 1024 64 512 512 64 64 none +u s32 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip +u s8 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip +u s32 r n n n r 1024 256 32 32 256 256 none +u s8 r n n n r 1024 256 32 32 256 256 none +u s32 s8 r n n n r 1024 256 32 32 256 256 bias,relu,clip +u s8 s8 r n n n r 1024 256 32 32 256 256 bias,relu,clip +u s32 r n n n r 1024 512 64 64 512 512 none +u s8 r n n n r 1024 512 64 64 512 512 none +u s32 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +u s8 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +u s32 r n n n r 480 640 512 512 640 640 none +u s8 r n n n r 480 640 512 512 640 640 none +u s32 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip +u s8 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip +u s32 r n n n p 1024 32 256 256 32 32 none +u s8 r n n n p 1024 32 256 256 32 32 none +u s32 s8 r n n n p 1024 32 256 256 32 32 bias,relu,clip +u s8 s8 r n n n p 1024 32 256 256 32 32 bias,relu,clip +u s32 r n n n P 1024 64 512 512 64 64 none +u s8 r n n n P 1024 64 512 512 64 64 none +u s32 s8 r n n n P 1024 64 512 512 64 64 bias,relu,clip +u s8 s8 r n n n P 1024 64 512 512 64 64 bias,relu,clip +u s32 r n n n P 64 800 320 320 800 800 none +u s8 r n n n P 64 800 320 320 800 800 
none +u s32 s8 r n n n P 64 800 320 320 800 800 bias,relu,clip +u s8 s8 r n n n P 64 800 320 320 800 800 bias,relu,clip +u s32 r n n n P 64 768 512 512 768 768 none +u s8 r n n n P 64 768 512 512 768 768 none +u s32 s8 r n n n P 64 768 512 512 768 768 bias,relu,clip +u s8 s8 r n n n P 64 768 512 512 768 768 bias,relu,clip +u s32 r n n n P 16 256 512 512 256 256 none +u s8 r n n n P 16 256 512 512 256 256 none +u s32 s8 r n n n P 16 256 512 512 256 256 bias,relu,clip +u s8 s8 r n n n P 16 256 512 512 256 256 bias,relu,clip +u s32 r n n n P 128 128 128 128 128 128 none +u s8 r n n n P 128 128 128 128 128 128 none +u s32 s8 r n n n P 128 128 128 128 128 128 bias,relu,clip +u s8 s8 r n n n P 128 128 128 128 128 128 bias,relu,clip +u s32 r n n n P 256 512 256 256 512 512 none +u s8 r n n n P 256 512 256 256 512 512 none +u s32 s8 r n n n P 256 512 256 256 512 512 bias,relu,clip +u s8 s8 r n n n P 256 512 256 256 512 512 bias,relu,clip +u s32 r n n n P 1024 1024 1024 1024 1024 1024 none +u s8 r n n n P 1024 1024 1024 1024 1024 1024 none +u s32 s8 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip +u s8 s8 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip +u s32 r n n n P 480 640 1024 1024 640 640 none +u s8 r n n n P 480 640 1024 1024 640 640 none +u s32 s8 r n n n P 480 640 1024 1024 640 640 bias,relu,clip +u s8 s8 r n n n P 480 640 1024 1024 640 640 bias,relu,clip +u s32 r n n n P 480 640 256 256 640 640 none +u s8 r n n n P 480 640 256 256 640 640 none +u s32 s8 r n n n P 480 640 256 256 640 640 bias,relu,clip +u s8 s8 r n n n P 480 640 256 256 640 640 bias,relu,clip +u s32 r n n n P 8 64 32 32 64 64 none +u s8 r n n n P 8 64 32 32 64 64 none +u s32 s8 r n n n P 8 64 32 32 64 64 bias,relu,clip +u s8 s8 r n n n P 8 64 32 32 64 64 bias,relu,clip +u s32 r n n n P 9 64 32 32 64 64 none +u s8 r n n n P 9 64 32 32 64 64 none +u s32 s8 r n n n P 9 64 32 32 64 64 bias,relu,clip +u s8 s8 r n n n P 9 64 32 32 64 64 bias,relu,clip +u s32 r n n n P 10 128 64 64 128 128 none +u s8 r n n n P 10 128 64 64 128 128 none +u s32 s8 r n n n P 10 128 64 64 128 128 bias,relu,clip +u s8 s8 r n n n P 10 128 64 64 128 128 bias,relu,clip +u s32 r n n n P 8 8 8 8 8 8 none +u s8 r n n n P 8 8 8 8 8 8 none +u s32 s8 r n n n P 8 8 8 8 8 8 bias,relu,clip +u s8 s8 r n n n P 8 8 8 8 8 8 bias,relu,clip +u s32 r n n n P 12 12 12 12 12 12 none +u s8 r n n n P 12 12 12 12 12 12 none +u s32 s8 r n n n P 12 12 12 12 12 12 bias,relu,clip +u s8 s8 r n n n P 12 12 12 12 12 12 bias,relu,clip +u s32 r n n n P 25 25 25 25 25 25 none +u s8 r n n n P 25 25 25 25 25 25 none +u s32 s8 r n n n P 25 25 25 25 25 25 bias,relu,clip +u s8 s8 r n n n P 25 25 25 25 25 25 bias,relu,clip +u s32 r n n n P 25 25 20 20 25 25 none +u s8 r n n n P 25 25 20 20 25 25 none +u s32 s8 r n n n P 25 25 20 20 25 25 bias,relu,clip +u s8 s8 r n n n P 25 25 20 20 25 25 bias,relu,clip +u s32 r n n n r 4096 256 5 5 256 256 none +u s8 r n n n r 4096 256 5 5 256 256 none +u s32 s8 r n n n r 4096 256 5 5 256 256 bias,relu,clip +u s8 s8 r n n n r 4096 256 5 5 256 256 bias,relu,clip +u s32 r n n n r 3000 256 128 128 256 256 none +u s8 r n n n r 3000 256 128 128 256 256 none +u s32 s8 r n n n r 3000 256 128 128 256 256 bias,relu,clip +u s8 s8 r n n n r 3000 256 128 128 256 256 bias,relu,clip +u s32 r n n n r 4096 1024 512 512 1024 1024 none +u s8 r n n n r 4096 1024 512 512 1024 1024 none +u s32 s8 r n n n r 4096 1024 512 512 1024 1024 bias,relu,clip +u s8 s8 r n n n r 4096 1024 512 512 1024 1024 bias,relu,clip +u s32 r n n n r 144 256 5 5 256 256 none +u s8 r n n 
n r 144 256 5 5 256 256 none +u s32 s8 r n n n r 144 256 5 5 256 256 bias,relu,clip +u s8 s8 r n n n r 144 256 5 5 256 256 bias,relu,clip +u s32 r n n n r 144 256 128 128 256 256 none +u s8 r n n n r 144 256 128 128 256 256 none +u s32 s8 r n n n r 144 256 128 128 256 256 bias,relu,clip +u s8 s8 r n n n r 144 256 128 128 256 256 bias,relu,clip +u s32 r n n n r 144 1024 512 512 1024 1024 none +u s8 r n n n r 144 1024 512 512 1024 1024 none +u s32 s8 r n n n r 144 1024 512 512 1024 1024 bias,relu,clip +u s8 s8 r n n n r 144 1024 512 512 1024 1024 bias,relu,clip +u s32 r n n n r 480 688 256 256 688 688 none +u s8 r n n n r 480 688 256 256 688 688 none +u s32 s8 r n n n r 480 688 256 256 688 688 bias,relu,clip +u s8 s8 r n n n r 480 688 256 256 688 688 bias,relu,clip +u s32 r n n n r 480 640 512 512 640 640 none +u s8 r n n n r 480 640 512 512 640 640 none +u s32 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip +u s8 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip +u s32 r n n n r 480 640 1024 1024 640 640 none +u s8 r n n n r 480 640 1024 1024 640 640 none +u s32 s8 r n n n r 480 640 1024 1024 640 640 bias,relu,clip +u s8 s8 r n n n r 480 640 1024 1024 640 640 bias,relu,clip +u s32 r n n n r 64 800 320 320 800 800 none +u s8 r n n n r 64 800 320 320 800 800 none +u s32 s8 r n n n r 64 800 320 320 800 800 bias,relu,clip +u s8 s8 r n n n r 64 800 320 320 800 800 bias,relu,clip +u s32 r n n n r 64 768 512 512 768 768 none +u s8 r n n n r 64 768 512 512 768 768 none +u s32 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip +u s8 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip +u s32 r n n n r 16 256 512 512 256 256 none +u s8 r n n n r 16 256 512 512 256 256 none +u s32 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip +u s8 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip +u s32 r n n n r 128 128 128 128 128 128 none +u s8 r n n n r 128 128 128 128 128 128 none +u s32 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip +u s8 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip +u s32 r n n n r 256 512 256 256 512 512 none +u s8 r n n n r 256 512 256 256 512 512 none +u s32 s8 r n n n r 256 512 256 256 512 512 bias,relu,clip +u s8 s8 r n n n r 256 512 256 256 512 512 bias,relu,clip +u s32 r n n n r 1024 1024 1024 1024 1024 1024 none +u s8 r n n n r 1024 1024 1024 1024 1024 1024 none +u s32 s8 r n n n r 1024 1024 1024 1024 1024 1024 bias,relu,clip +u s8 s8 r n n n r 1024 1024 1024 1024 1024 1024 bias,relu,clip +u s32 r n n n r 1024 32 256 256 32 32 none +u s8 r n n n r 1024 32 256 256 32 32 none +u s32 s8 r n n n r 1024 32 256 256 32 32 bias,relu,clip +u s8 s8 r n n n r 1024 32 256 256 32 32 bias,relu,clip +u s32 r n n n r 1024 64 512 512 64 64 none +u s8 r n n n r 1024 64 512 512 64 64 none +u s32 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip +u s8 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip +u s32 r n n n r 1024 256 32 32 256 256 none +u s8 r n n n r 1024 256 32 32 256 256 none +u s32 s8 r n n n r 1024 256 32 32 256 256 bias,relu,clip +u s8 s8 r n n n r 1024 256 32 32 256 256 bias,relu,clip +u s32 r n n n r 1024 512 64 64 512 512 none +u s8 r n n n r 1024 512 64 64 512 512 none +u s32 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +u s8 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +u s32 r n n n r 512 32 256 256 32 32 none +u s8 r n n n r 512 32 256 256 32 32 none +u s32 s8 r n n n r 512 32 256 256 32 32 bias,relu,clip +u s8 s8 r n n n r 512 32 256 256 32 32 bias,relu,clip +u s32 r n n n r 512 768 512 512 768 768 none +u s8 r n n n r 512 768 512 512 768 768 none 
+u s32 s8 r n n n r 512 768 512 512 768 768 bias,relu,clip +u s8 s8 r n n n r 512 768 512 512 768 768 bias,relu,clip +u s32 r n n n r 512 256 32 32 256 256 none +u s8 r n n n r 512 256 32 32 256 256 none +u s32 s8 r n n n r 512 256 32 32 256 256 bias,relu,clip +u s8 s8 r n n n r 512 256 32 32 256 256 bias,relu,clip +u s32 r n n n r 512 512 64 64 512 512 none +u s8 r n n n r 512 512 64 64 512 512 none +u s32 s8 r n n n r 512 512 64 64 512 512 bias,relu,clip +u s8 s8 r n n n r 512 512 64 64 512 512 bias,relu,clip +u s32 r n n n r 512 256 768 768 256 256 none +u s8 r n n n r 512 256 768 768 256 256 none +u s32 s8 r n n n r 512 256 768 768 256 256 bias,relu,clip +u s8 s8 r n n n r 512 256 768 768 256 256 bias,relu,clip +u s32 r n n n r 768 768 1024 1024 768 768 none +u s8 r n n n r 768 768 1024 1024 768 768 none +u s32 s8 r n n n r 768 768 1024 1024 768 768 bias,relu,clip +u s8 s8 r n n n r 768 768 1024 1024 768 768 bias,relu,clip +u s32 r n n n r 768 768 768 768 768 768 none +u s8 r n n n r 768 768 768 768 768 768 none +u s32 s8 r n n n r 768 768 768 768 768 768 bias,relu,clip +u s8 s8 r n n n r 768 768 768 768 768 768 bias,relu,clip +u s32 r n n n r 2048 2048 2048 2048 2048 2048 none +u s8 r n n n r 2048 2048 2048 2048 2048 2048 none +u s32 s8 r n n n r 2048 2048 2048 2048 2048 2048 bias,relu,clip +u s8 s8 r n n n r 2048 2048 2048 2048 2048 2048 bias,relu,clip +u s32 r n n n r 4096 4096 4096 4096 4096 4096 none +u s8 r n n n r 4096 4096 4096 4096 4096 4096 none +u s32 s8 r n n n r 4096 4096 4096 4096 4096 4096 bias,relu,clip +u s8 s8 r n n n r 4096 4096 4096 4096 4096 4096 bias,relu,clip +f f32 c n n n p 2482 1127 2050 2482 2050 2482 none +f f32 f32 c n n n p 2482 1127 2050 2482 2050 2482 bias,relu,clip +f f32 c n n n p 2483 1127 2050 2483 2050 2483 none +f f32 f32 c n n n p 2483 1127 2050 2483 2050 2483 bias,relu,clip +f f32 c n n n p 2484 1127 2050 2484 2050 2484 none +f f32 f32 c n n n p 2484 1127 2050 2484 2050 2484 bias,relu,clip +f f32 c n n n p 2485 1127 2050 2485 2050 2485 none +f f32 f32 c n n n p 2485 1127 2050 2485 2050 2485 bias,relu,clip +f f32 c n n n p 480 1138 2050 480 2050 480 none +f f32 f32 c n n n p 480 1138 2050 480 2050 480 bias,relu,clip +f f32 c n n n p 481 1138 2050 481 2050 481 none +f f32 f32 c n n n p 481 1138 2050 481 2050 481 bias,relu,clip +f f32 c n n n p 482 1138 2050 482 2050 482 none +f f32 f32 c n n n p 482 1138 2050 482 2050 482 bias,relu,clip +f f32 c n n n p 483 1138 2050 483 2050 483 none +f f32 f32 c n n n p 483 1138 2050 483 2050 483 bias,relu,clip +f f32 c n n n p 484 1138 2050 484 2050 484 none +f f32 f32 c n n n p 484 1138 2050 484 2050 484 bias,relu,clip +f f32 c n n n p 485 1138 2050 485 2050 485 none +f f32 f32 c n n n p 485 1138 2050 485 2050 485 bias,relu,clip +f f32 c n n n p 1 1 3 1 3 1 none +f f32 f32 c n n n p 1 1 3 1 3 1 bias,relu,clip +f f32 c n n n p 1 9 3 1 3 1 none +f f32 f32 c n n n p 1 9 3 1 3 1 bias,relu,clip +f f32 c n n n p 1 2048 3 1 3 1 none +f f32 f32 c n n n p 1 2048 3 1 3 1 bias,relu,clip +f f32 c n n n p 1 2048 5192 1 5192 1 none +f f32 f32 c n n n p 1 2048 5192 1 5192 1 bias,relu,clip +f f32 c n n n p 9 1 3 9 3 9 none +f f32 f32 c n n n p 9 1 3 9 3 9 bias,relu,clip +f f32 c n n n p 576 1 3500 576 3500 576 none +f f32 f32 c n n n p 576 1 3500 576 3500 576 bias,relu,clip +f f32 c n n n p 1 1 1 1 1 1 none +f f32 f32 c n n n p 1 1 1 1 1 1 bias,relu,clip +f f32 c n n n p 102 1088 1024 102 1024 102 none +f f32 f32 c n n n p 102 1088 1024 102 1024 102 bias,relu,clip +b f32 r n n n r 480 20 2050 2050 20 20 none +b bf16 r n n n r 
480 20 2050 2050 20 20 none +b f32 bf16 r n n n r 480 20 2050 2050 20 20 bias,relu,clip +b bf16 bf16 r n n n r 480 20 2050 2050 20 20 bias,relu,clip +b f32 r n n n r 481 20 2050 2050 20 20 none +b bf16 r n n n r 481 20 2050 2050 20 20 none +b f32 bf16 r n n n r 481 20 2050 2050 20 20 bias,relu,clip +b bf16 bf16 r n n n r 481 20 2050 2050 20 20 bias,relu,clip +b f32 r n n n r 482 20 2050 2050 20 20 none +b bf16 r n n n r 482 20 2050 2050 20 20 none +b f32 bf16 r n n n r 482 20 2050 2050 20 20 bias,relu,clip +b bf16 bf16 r n n n r 482 20 2050 2050 20 20 bias,relu,clip +b f32 r n n n p 483 20 2050 2050 20 20 none +b bf16 r n n n p 483 20 2050 2050 20 20 none +b f32 bf16 r n n n p 483 20 2050 2050 20 20 bias,relu,clip +b bf16 bf16 r n n n p 483 20 2050 2050 20 20 bias,relu,clip +b f32 r n n n R 484 20 2050 2050 20 20 none +b bf16 r n n n R 484 20 2050 2050 20 20 none +b f32 bf16 r n n n R 484 20 2050 2050 20 20 bias,relu,clip +b bf16 bf16 r n n n R 484 20 2050 2050 20 20 bias,relu,clip +b f32 r n n n R 485 20 2050 2050 20 20 none +b bf16 r n n n R 485 20 2050 2050 20 20 none +b f32 bf16 r n n n R 485 20 2050 2050 20 20 bias,relu,clip +b bf16 bf16 r n n n R 485 20 2050 2050 20 20 bias,relu,clip +b f32 r n n n R 480 39 2050 2050 39 39 none +b bf16 r n n n R 480 39 2050 2050 39 39 none +b f32 bf16 r n n n R 480 39 2050 2050 39 39 bias,relu,clip +b bf16 bf16 r n n n R 480 39 2050 2050 39 39 bias,relu,clip +b f32 r n n n R 481 39 2050 2050 39 39 none +b bf16 r n n n R 481 39 2050 2050 39 39 none +b f32 bf16 r n n n R 481 39 2050 2050 39 39 bias,relu,clip +b bf16 bf16 r n n n R 481 39 2050 2050 39 39 bias,relu,clip +b f32 r n n n R 482 39 2050 2050 39 39 none +b bf16 r n n n R 482 39 2050 2050 39 39 none +b f32 bf16 r n n n R 482 39 2050 2050 39 39 bias,relu,clip +b bf16 bf16 r n n n R 482 39 2050 2050 39 39 bias,relu,clip +b f32 r n n n R 483 39 2050 2050 39 39 none +b bf16 r n n n R 483 39 2050 2050 39 39 none +b f32 bf16 r n n n R 483 39 2050 2050 39 39 bias,relu,clip +b bf16 bf16 r n n n R 483 39 2050 2050 39 39 bias,relu,clip +b f32 r n n n R 484 39 2050 2050 39 39 none +b bf16 r n n n R 484 39 2050 2050 39 39 none +b f32 bf16 r n n n R 484 39 2050 2050 39 39 bias,relu,clip +b bf16 bf16 r n n n R 484 39 2050 2050 39 39 bias,relu,clip +b f32 r n n n p 485 39 2050 2050 39 39 none +b bf16 r n n n p 485 39 2050 2050 39 39 none +b f32 bf16 r n n n p 485 39 2050 2050 39 39 bias,relu,clip +b bf16 bf16 r n n n p 485 39 2050 2050 39 39 bias,relu,clip +b f32 r n n n p 480 50 2050 2050 50 50 none +b bf16 r n n n p 480 50 2050 2050 50 50 none +b f32 bf16 r n n n p 480 50 2050 2050 50 50 bias,relu,clip +b bf16 bf16 r n n n p 480 50 2050 2050 50 50 bias,relu,clip +b f32 r n n n p 481 50 2050 2050 50 50 none +b bf16 r n n n p 481 50 2050 2050 50 50 none +b f32 bf16 r n n n p 481 50 2050 2050 50 50 bias,relu,clip +b bf16 bf16 r n n n p 481 50 2050 2050 50 50 bias,relu,clip +b f32 r n n n p 482 50 2050 2050 50 50 none +b bf16 r n n n p 482 50 2050 2050 50 50 none +b f32 bf16 r n n n p 482 50 2050 2050 50 50 bias,relu,clip +b bf16 bf16 r n n n p 482 50 2050 2050 50 50 bias,relu,clip +b f32 r n n n p 483 50 2050 2050 50 50 none +b bf16 r n n n p 483 50 2050 2050 50 50 none +b f32 bf16 r n n n p 483 50 2050 2050 50 50 bias,relu,clip +b bf16 bf16 r n n n p 483 50 2050 2050 50 50 bias,relu,clip +b f32 r n n n p 484 50 2050 2050 50 50 none +b bf16 r n n n p 484 50 2050 2050 50 50 none +b f32 bf16 r n n n p 484 50 2050 2050 50 50 bias,relu,clip +b bf16 bf16 r n n n p 484 50 2050 2050 50 50 bias,relu,clip +b f32 r n n 
n p 485 50 2050 2050 50 50 none +b bf16 r n n n p 485 50 2050 2050 50 50 none +b f32 bf16 r n n n p 485 50 2050 2050 50 50 bias,relu,clip +b bf16 bf16 r n n n p 485 50 2050 2050 50 50 bias,relu,clip +b f32 r n n n R 480 1108 2050 2050 1108 1108 none +b bf16 r n n n R 480 1108 2050 2050 1108 1108 none +b f32 bf16 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip +b bf16 bf16 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip +b f32 r n n n R 481 1108 2050 2050 1108 1108 none +b bf16 r n n n R 481 1108 2050 2050 1108 1108 none +b f32 bf16 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip +b bf16 bf16 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip +b f32 r n n n R 482 1108 2050 2050 1108 1108 none +b bf16 r n n n R 482 1108 2050 2050 1108 1108 none +b f32 bf16 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip +b bf16 bf16 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip +b f32 r n n n R 483 1108 2050 2050 1108 1108 none +b bf16 r n n n R 483 1108 2050 2050 1108 1108 none +b f32 bf16 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip +b bf16 bf16 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip +b f32 r n n n R 484 1108 2050 2050 1108 1108 none +b bf16 r n n n R 484 1108 2050 2050 1108 1108 none +b f32 bf16 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip +b bf16 bf16 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip +b f32 r n n n R 485 1108 2050 2050 1108 1108 none +b bf16 r n n n R 485 1108 2050 2050 1108 1108 none +b f32 bf16 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip +b bf16 bf16 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip +b f32 r n n n R 480 1127 2050 2050 1127 1127 none +b bf16 r n n n R 480 1127 2050 2050 1127 1127 none +b f32 bf16 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip +b bf16 bf16 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip +b f32 r n n n R 481 1127 2050 2050 1127 1127 none +b bf16 r n n n R 481 1127 2050 2050 1127 1127 none +b f32 bf16 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip +b bf16 bf16 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip +b f32 r n n n R 482 1127 2050 2050 1127 1127 none +b bf16 r n n n R 482 1127 2050 2050 1127 1127 none +b f32 bf16 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip +b bf16 bf16 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip +b f32 r n n n R 483 1127 2050 2050 1127 1127 none +b bf16 r n n n R 483 1127 2050 2050 1127 1127 none +b f32 bf16 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip +b bf16 bf16 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip +b f32 r n n n p 484 1127 2050 2050 1127 1127 none +b bf16 r n n n p 484 1127 2050 2050 1127 1127 none +b f32 bf16 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip +b bf16 bf16 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip +b f32 r n n n p 485 1127 2050 2050 1127 1127 none +b bf16 r n n n p 485 1127 2050 2050 1127 1127 none +b f32 bf16 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip +b bf16 bf16 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip +b f32 r n n n p 480 1138 2050 2050 1138 1138 none +b bf16 r n n n p 480 1138 2050 2050 1138 1138 none +b f32 bf16 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip +b bf16 bf16 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip +b f32 r n n n p 481 1138 2050 2050 1138 1138 none +b bf16 r n n n p 481 1138 2050 2050 1138 1138 none +b f32 bf16 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip +b bf16 bf16 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip +b f32 r n n n p 482 1138 2050 2050 1138 1138 none 
+b bf16 r n n n p 482 1138 2050 2050 1138 1138 none +b f32 bf16 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip +b bf16 bf16 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip +b f32 r n n n p 483 1138 2050 2050 1138 1138 none +b bf16 r n n n p 483 1138 2050 2050 1138 1138 none +b f32 bf16 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip +b bf16 bf16 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip +b f32 r n n n p 484 1138 2050 2050 1138 1138 none +b bf16 r n n n p 484 1138 2050 2050 1138 1138 none +b f32 bf16 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip +b bf16 bf16 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip +b f32 r n n n p 485 1138 2050 2050 1138 1138 none +b bf16 r n n n p 485 1138 2050 2050 1138 1138 none +b f32 bf16 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip +b bf16 bf16 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip +b f32 r n n n p 1 1 3 3 1 1 none +b bf16 r n n n p 1 1 3 3 1 1 none +b f32 bf16 r n n n p 1 1 3 3 1 1 bias,relu,clip +b bf16 bf16 r n n n p 1 1 3 3 1 1 bias,relu,clip +b f32 r n n n p 1 9 3 3 9 9 none +b bf16 r n n n p 1 9 3 3 9 9 none +b f32 bf16 r n n n p 1 9 3 3 9 9 bias,relu,clip +b bf16 bf16 r n n n p 1 9 3 3 9 9 bias,relu,clip +b f32 r n n n p 1 2048 3 3 2048 2048 none +b bf16 r n n n p 1 2048 3 3 2048 2048 none +b f32 bf16 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip +b bf16 bf16 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip +b f32 r n n n p 1 2048 5192 5192 2048 2048 none +b bf16 r n n n p 1 2048 5192 5192 2048 2048 none +b f32 bf16 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip +b bf16 bf16 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip +b f32 r n n n p 9 1 3 3 1 1 none +b bf16 r n n n p 9 1 3 3 1 1 none +b f32 bf16 r n n n p 9 1 3 3 1 1 bias,relu,clip +b bf16 bf16 r n n n p 9 1 3 3 1 1 bias,relu,clip +b f32 r n n n p 576 1 3500 3500 1 1 none +b bf16 r n n n p 576 1 3500 3500 1 1 none +b f32 bf16 r n n n p 576 1 3500 3500 1 1 bias,relu,clip +b bf16 bf16 r n n n p 576 1 3500 3500 1 1 bias,relu,clip +b f32 r n n n p 1 1 1 1 1 1 none +b bf16 r n n n p 1 1 1 1 1 1 none +b f32 bf16 r n n n p 1 1 1 1 1 1 bias,relu,clip +b bf16 bf16 r n n n p 1 1 1 1 1 1 bias,relu,clip +b f32 r n n n p 102 1088 1024 1024 1088 1088 none +b bf16 r n n n p 102 1088 1024 1024 1088 1088 none +b f32 bf16 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip +b bf16 bf16 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip +b f32 r n n n p 102 2048 1024 1024 2048 2048 none +b bf16 r n n n p 102 2048 1024 1024 2048 2048 none +b f32 bf16 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip +b bf16 bf16 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip +b f32 r n n n p 485 656 1024 1024 656 656 none +b bf16 r n n n p 485 656 1024 1024 656 656 none +b f32 bf16 r n n n p 485 656 1024 1024 656 656 bias,relu,clip +b bf16 bf16 r n n n p 485 656 1024 1024 656 656 bias,relu,clip +b f32 r n n n p 483 656 1024 1024 656 656 none +b bf16 r n n n p 483 656 1024 1024 656 656 none +b f32 bf16 r n n n p 483 656 1024 1024 656 656 bias,relu,clip +b bf16 bf16 r n n n p 483 656 1024 1024 656 656 bias,relu,clip +b f32 r n n n p 81 128 3 3 128 128 none +b bf16 r n n n p 81 128 3 3 128 128 none +b f32 bf16 r n n n p 81 128 3 3 128 128 bias,relu,clip +b bf16 bf16 r n n n p 81 128 3 3 128 128 bias,relu,clip +b f32 r n n n p 1022 512 515 515 512 512 none +b bf16 r n n n p 1022 512 515 515 512 512 none +b f32 bf16 r n n n p 1022 512 515 515 512 512 bias,relu,clip +b bf16 bf16 r n n n p 1022 512 515 515 512 512 bias,relu,clip +b f32 r n n n p 74 
512 515 515 512 512 none +b bf16 r n n n p 74 512 515 515 512 512 none +b f32 bf16 r n n n p 74 512 515 515 512 512 bias,relu,clip +b bf16 bf16 r n n n p 74 512 515 515 512 512 bias,relu,clip +b f32 r n n n p 253 2048 515 515 2048 2048 none +b bf16 r n n n p 253 2048 515 515 2048 2048 none +b f32 bf16 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip +b bf16 bf16 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip +b f32 r n n n p 8192 1040 515 515 1040 1040 none +b bf16 r n n n p 8192 1040 515 515 1040 1040 none +b f32 bf16 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip +b bf16 bf16 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip +b f32 r n n n p 10 1029 515 515 1029 1029 none +b bf16 r n n n p 10 1029 515 515 1029 1029 none +b f32 bf16 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip +b bf16 bf16 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip +b f32 r n n n p 24 1040 2050 2050 1040 1040 none +b bf16 r n n n p 24 1040 2050 2050 1040 1040 none +b f32 bf16 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip +b bf16 bf16 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip +b f32 r n n n p 1024 1029 2050 2050 1029 1029 none +b bf16 r n n n p 1024 1029 2050 2050 1029 1029 none +b f32 bf16 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip +b bf16 bf16 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip +b f32 r n n n p 480 660 2050 2050 660 660 none +b bf16 r n n n p 480 660 2050 2050 660 660 none +b f32 bf16 r n n n p 480 660 2050 2050 660 660 bias,relu,clip +b bf16 bf16 r n n n p 480 660 2050 2050 660 660 bias,relu,clip +b f32 r n n n p 481 660 2050 2050 660 660 none +b bf16 r n n n p 481 660 2050 2050 660 660 none +b f32 bf16 r n n n p 481 660 2050 2050 660 660 bias,relu,clip +b bf16 bf16 r n n n p 481 660 2050 2050 660 660 bias,relu,clip +b f32 r n n n p 482 660 2050 2050 660 660 none +b bf16 r n n n p 482 660 2050 2050 660 660 none +b f32 bf16 r n n n p 482 660 2050 2050 660 660 bias,relu,clip +b bf16 bf16 r n n n p 482 660 2050 2050 660 660 bias,relu,clip +b f32 r n n n p 483 660 2050 2050 660 660 none +b bf16 r n n n p 483 660 2050 2050 660 660 none +b f32 bf16 r n n n p 483 660 2050 2050 660 660 bias,relu,clip +b bf16 bf16 r n n n p 483 660 2050 2050 660 660 bias,relu,clip +b f32 r n n n p 484 660 2050 2050 660 660 none +b bf16 r n n n p 484 660 2050 2050 660 660 none +b f32 bf16 r n n n p 484 660 2050 2050 660 660 bias,relu,clip +b bf16 bf16 r n n n p 484 660 2050 2050 660 660 bias,relu,clip +b f32 r n n n p 485 660 2050 2050 660 660 none +b bf16 r n n n p 485 660 2050 2050 660 660 none +b f32 bf16 r n n n p 485 660 2050 2050 660 660 bias,relu,clip +b bf16 bf16 r n n n p 485 660 2050 2050 660 660 bias,relu,clip +b f32 r n n n p 480 679 2050 2050 679 679 none +b bf16 r n n n p 480 679 2050 2050 679 679 none +b f32 bf16 r n n n p 480 679 2050 2050 679 679 bias,relu,clip +b bf16 bf16 r n n n p 480 679 2050 2050 679 679 bias,relu,clip +b f32 r n n n p 481 679 2050 2050 679 679 none +b bf16 r n n n p 481 679 2050 2050 679 679 none +b f32 bf16 r n n n p 481 679 2050 2050 679 679 bias,relu,clip +b bf16 bf16 r n n n p 481 679 2050 2050 679 679 bias,relu,clip +b f32 r n n n p 482 679 2050 2050 679 679 none +b bf16 r n n n p 482 679 2050 2050 679 679 none +b f32 bf16 r n n n p 482 679 2050 2050 679 679 bias,relu,clip +b bf16 bf16 r n n n p 482 679 2050 2050 679 679 bias,relu,clip +b f32 r n n n p 483 679 2050 2050 679 679 none +b bf16 r n n n p 483 679 2050 2050 679 679 none +b f32 bf16 r n n n p 483 679 2050 2050 679 679 bias,relu,clip +b bf16 bf16 r n n n p 483 679 
2050 2050 679 679 bias,relu,clip +b f32 r n n n p 484 679 2050 2050 679 679 none +b bf16 r n n n p 484 679 2050 2050 679 679 none +b f32 bf16 r n n n p 484 679 2050 2050 679 679 bias,relu,clip +b bf16 bf16 r n n n p 484 679 2050 2050 679 679 bias,relu,clip +b f32 r n n n p 485 679 2050 2050 679 679 none +b bf16 r n n n p 485 679 2050 2050 679 679 none +b f32 bf16 r n n n p 485 679 2050 2050 679 679 bias,relu,clip +b bf16 bf16 r n n n p 485 679 2050 2050 679 679 bias,relu,clip +b f32 r n n n p 480 690 2050 2050 690 690 none +b bf16 r n n n p 480 690 2050 2050 690 690 none +b f32 bf16 r n n n p 480 690 2050 2050 690 690 bias,relu,clip +b bf16 bf16 r n n n p 480 690 2050 2050 690 690 bias,relu,clip +b f32 r n n n p 481 690 2050 2050 690 690 none +b bf16 r n n n p 481 690 2050 2050 690 690 none +b f32 bf16 r n n n p 481 690 2050 2050 690 690 bias,relu,clip +b bf16 bf16 r n n n p 481 690 2050 2050 690 690 bias,relu,clip +b f32 r n n n p 482 690 2050 2050 690 690 none +b bf16 r n n n p 482 690 2050 2050 690 690 none +b f32 bf16 r n n n p 482 690 2050 2050 690 690 bias,relu,clip +b bf16 bf16 r n n n p 482 690 2050 2050 690 690 bias,relu,clip +b f32 r n n n p 483 690 2050 2050 690 690 none +b bf16 r n n n p 483 690 2050 2050 690 690 none +b f32 bf16 r n n n p 483 690 2050 2050 690 690 bias,relu,clip +b bf16 bf16 r n n n p 483 690 2050 2050 690 690 bias,relu,clip +b f32 r n n n p 484 690 2050 2050 690 690 none +b bf16 r n n n p 484 690 2050 2050 690 690 none +b f32 bf16 r n n n p 484 690 2050 2050 690 690 bias,relu,clip +b bf16 bf16 r n n n p 484 690 2050 2050 690 690 bias,relu,clip +b f32 r n n n p 485 690 2050 2050 690 690 none +b bf16 r n n n p 485 690 2050 2050 690 690 none +b f32 bf16 r n n n p 485 690 2050 2050 690 690 bias,relu,clip +b bf16 bf16 r n n n p 485 690 2050 2050 690 690 bias,relu,clip +b f32 r n n n p 480 660 2048 2048 660 660 none +b bf16 r n n n p 480 660 2048 2048 660 660 none +b f32 bf16 r n n n p 480 660 2048 2048 660 660 bias,relu,clip +b bf16 bf16 r n n n p 480 660 2048 2048 660 660 bias,relu,clip +b f32 r n n n p 481 660 2048 2048 660 660 none +b bf16 r n n n p 481 660 2048 2048 660 660 none +b f32 bf16 r n n n p 481 660 2048 2048 660 660 bias,relu,clip +b bf16 bf16 r n n n p 481 660 2048 2048 660 660 bias,relu,clip +b f32 r n n n p 482 660 2048 2048 660 660 none +b bf16 r n n n p 482 660 2048 2048 660 660 none +b f32 bf16 r n n n p 482 660 2048 2048 660 660 bias,relu,clip +b bf16 bf16 r n n n p 482 660 2048 2048 660 660 bias,relu,clip +b f32 r n n n p 483 660 2048 2048 660 660 none +b bf16 r n n n p 483 660 2048 2048 660 660 none +b f32 bf16 r n n n p 483 660 2048 2048 660 660 bias,relu,clip +b bf16 bf16 r n n n p 483 660 2048 2048 660 660 bias,relu,clip +b f32 r n n n p 484 660 2048 2048 660 660 none +b bf16 r n n n p 484 660 2048 2048 660 660 none +b f32 bf16 r n n n p 484 660 2048 2048 660 660 bias,relu,clip +b bf16 bf16 r n n n p 484 660 2048 2048 660 660 bias,relu,clip +b f32 r n n n p 485 660 2048 2048 660 660 none +b bf16 r n n n p 485 660 2048 2048 660 660 none +b f32 bf16 r n n n p 485 660 2048 2048 660 660 bias,relu,clip +b bf16 bf16 r n n n p 485 660 2048 2048 660 660 bias,relu,clip +b f32 r n n n p 480 679 2048 2048 679 679 none +b bf16 r n n n p 480 679 2048 2048 679 679 none +b f32 bf16 r n n n p 480 679 2048 2048 679 679 bias,relu,clip +b bf16 bf16 r n n n p 480 679 2048 2048 679 679 bias,relu,clip +b f32 r n n n p 481 679 2048 2048 679 679 none +b bf16 r n n n p 481 679 2048 2048 679 679 none +b f32 bf16 r n n n p 481 679 2048 2048 679 679 bias,relu,clip 
+b bf16 bf16 r n n n p 481 679 2048 2048 679 679 bias,relu,clip +b f32 r n n n p 482 679 2048 2048 679 679 none +b bf16 r n n n p 482 679 2048 2048 679 679 none +b f32 bf16 r n n n p 482 679 2048 2048 679 679 bias,relu,clip +b bf16 bf16 r n n n p 482 679 2048 2048 679 679 bias,relu,clip +b f32 r n n n p 483 679 2048 2048 679 679 none +b bf16 r n n n p 483 679 2048 2048 679 679 none +b f32 bf16 r n n n p 483 679 2048 2048 679 679 bias,relu,clip +b bf16 bf16 r n n n p 483 679 2048 2048 679 679 bias,relu,clip +b f32 r n n n p 484 679 2048 2048 679 679 none +b bf16 r n n n p 484 679 2048 2048 679 679 none +b f32 bf16 r n n n p 484 679 2048 2048 679 679 bias,relu,clip +b bf16 bf16 r n n n p 484 679 2048 2048 679 679 bias,relu,clip +b f32 r n n n p 485 679 2048 2048 679 679 none +b bf16 r n n n p 485 679 2048 2048 679 679 none +b f32 bf16 r n n n p 485 679 2048 2048 679 679 bias,relu,clip +b bf16 bf16 r n n n p 485 679 2048 2048 679 679 bias,relu,clip +b f32 r n n n p 480 690 2048 2048 690 690 none +b bf16 r n n n p 480 690 2048 2048 690 690 none +b f32 bf16 r n n n p 480 690 2048 2048 690 690 bias,relu,clip +b bf16 bf16 r n n n p 480 690 2048 2048 690 690 bias,relu,clip +b f32 r n n n p 481 690 2048 2048 690 690 none +b bf16 r n n n p 481 690 2048 2048 690 690 none +b f32 bf16 r n n n p 481 690 2048 2048 690 690 bias,relu,clip +b bf16 bf16 r n n n p 481 690 2048 2048 690 690 bias,relu,clip +b f32 r n n n p 482 690 2048 2048 690 690 none +b bf16 r n n n p 482 690 2048 2048 690 690 none +b f32 bf16 r n n n p 482 690 2048 2048 690 690 bias,relu,clip +b bf16 bf16 r n n n p 482 690 2048 2048 690 690 bias,relu,clip +b f32 r n n n p 483 690 2048 2048 690 690 none +b bf16 r n n n p 483 690 2048 2048 690 690 none +b f32 bf16 r n n n p 483 690 2048 2048 690 690 bias,relu,clip +b bf16 bf16 r n n n p 483 690 2048 2048 690 690 bias,relu,clip +b f32 r n n n p 484 690 2048 2048 690 690 none +b bf16 r n n n p 484 690 2048 2048 690 690 none +b f32 bf16 r n n n p 484 690 2048 2048 690 690 bias,relu,clip +b bf16 bf16 r n n n p 484 690 2048 2048 690 690 bias,relu,clip +b f32 r n n n p 485 690 2048 2048 690 690 none +b bf16 r n n n p 485 690 2048 2048 690 690 none +b f32 bf16 r n n n p 485 690 2048 2048 690 690 bias,relu,clip +b bf16 bf16 r n n n p 485 690 2048 2048 690 690 bias,relu,clip +b f32 r n n n p 480 656 1024 1024 656 656 none +b bf16 r n n n p 480 656 1024 1024 656 656 none +b f32 bf16 r n n n p 480 656 1024 1024 656 656 bias,relu,clip +b bf16 bf16 r n n n p 480 656 1024 1024 656 656 bias,relu,clip +b f32 r n n n p 480 128 3 3 128 128 none +b bf16 r n n n p 480 128 3 3 128 128 none +b f32 bf16 r n n n p 480 128 3 3 128 128 bias,relu,clip +b bf16 bf16 r n n n p 480 128 3 3 128 128 bias,relu,clip +b f32 r n n n p 1024 512 515 515 512 512 none +b bf16 r n n n p 1024 512 515 515 512 512 none +b f32 bf16 r n n n p 1024 512 515 515 512 512 bias,relu,clip +b bf16 bf16 r n n n p 1024 512 515 515 512 512 bias,relu,clip +b f32 r n n n p 1024 2048 1024 1024 2048 2048 none +b bf16 r n n n p 1024 2048 1024 1024 2048 2048 none +b f32 bf16 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip +b bf16 bf16 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip +b f32 r n n n p 1024 2048 515 515 2048 2048 none +b bf16 r n n n p 1024 2048 515 515 2048 2048 none +b f32 bf16 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip +b bf16 bf16 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip +b f32 r n n n p 1024 1040 515 515 1040 1040 none +b bf16 r n n n p 1024 1040 515 515 1040 1040 none +b f32 bf16 r n n n p 1024 1040 
515 515 1040 1040 bias,relu,clip +b bf16 bf16 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip +b f32 r n n n p 5 1029 515 515 1029 1029 none +b bf16 r n n n p 5 1029 515 515 1029 1029 none +b f32 bf16 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip +b bf16 bf16 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip +b f32 r n n n p 1024 1029 515 515 1029 1029 none +b bf16 r n n n p 1024 1029 515 515 1029 1029 none +b f32 bf16 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip +b bf16 bf16 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip +b f32 r n n n p 1024 1040 2050 2050 1040 1040 none +b bf16 r n n n p 1024 1040 2050 2050 1040 1040 none +b f32 bf16 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip +b bf16 bf16 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip +b f32 r n n n p 1029 1029 2050 2050 1029 1029 none +b bf16 r n n n p 1029 1029 2050 2050 1029 1029 none +b f32 bf16 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip +b bf16 bf16 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip +b f32 r n n n R 480 646 2050 2050 646 646 none +b bf16 r n n n R 480 646 2050 2050 646 646 none +b f32 bf16 r n n n R 480 646 2050 2050 646 646 bias,relu,clip +b bf16 bf16 r n n n R 480 646 2050 2050 646 646 bias,relu,clip +b f32 r n n n R 481 646 2050 2050 646 646 none +b bf16 r n n n R 481 646 2050 2050 646 646 none +b f32 bf16 r n n n R 481 646 2050 2050 646 646 bias,relu,clip +b bf16 bf16 r n n n R 481 646 2050 2050 646 646 bias,relu,clip +b f32 r n n n R 482 646 2050 2050 646 646 none +b bf16 r n n n R 482 646 2050 2050 646 646 none +b f32 bf16 r n n n R 482 646 2050 2050 646 646 bias,relu,clip +b bf16 bf16 r n n n R 482 646 2050 2050 646 646 bias,relu,clip +b f32 r n n n R 483 646 2050 2050 646 646 none +b bf16 r n n n R 483 646 2050 2050 646 646 none +b f32 bf16 r n n n R 483 646 2050 2050 646 646 bias,relu,clip +b bf16 bf16 r n n n R 483 646 2050 2050 646 646 bias,relu,clip +b f32 r n n n R 484 646 2050 2050 646 646 none +b bf16 r n n n R 484 646 2050 2050 646 646 none +b f32 bf16 r n n n R 484 646 2050 2050 646 646 bias,relu,clip +b bf16 bf16 r n n n R 484 646 2050 2050 646 646 bias,relu,clip +b f32 r n n n R 485 646 2050 2050 646 646 none +b bf16 r n n n R 485 646 2050 2050 646 646 none +b f32 bf16 r n n n R 485 646 2050 2050 646 646 bias,relu,clip +b bf16 bf16 r n n n R 485 646 2050 2050 646 646 bias,relu,clip +b f32 r n n n R 481 656 2050 2050 656 656 none +b bf16 r n n n R 481 656 2050 2050 656 656 none +b f32 bf16 r n n n R 481 656 2050 2050 656 656 bias,relu,clip +b bf16 bf16 r n n n R 481 656 2050 2050 656 656 bias,relu,clip +b f32 r n n n R 482 656 2050 2050 656 656 none +b bf16 r n n n R 482 656 2050 2050 656 656 none +b f32 bf16 r n n n R 482 656 2050 2050 656 656 bias,relu,clip +b bf16 bf16 r n n n R 482 656 2050 2050 656 656 bias,relu,clip +b f32 r n n n R 483 656 2050 2050 656 656 none +b bf16 r n n n R 483 656 2050 2050 656 656 none +b f32 bf16 r n n n R 483 656 2050 2050 656 656 bias,relu,clip +b bf16 bf16 r n n n R 483 656 2050 2050 656 656 bias,relu,clip +b f32 r n n n R 484 656 2050 2050 656 656 none +b bf16 r n n n R 484 656 2050 2050 656 656 none +b f32 bf16 r n n n R 484 656 2050 2050 656 656 bias,relu,clip +b bf16 bf16 r n n n R 484 656 2050 2050 656 656 bias,relu,clip +b f32 r n n n p 485 656 2050 2050 656 656 none +b bf16 r n n n p 485 656 2050 2050 656 656 none +b f32 bf16 r n n n p 485 656 2050 2050 656 656 bias,relu,clip +b bf16 bf16 r n n n p 485 656 2050 2050 656 656 bias,relu,clip +b f32 r n n n p 480 672 2050 2050 672 672 none +b bf16 r n 
n n p 480 672 2050 2050 672 672 none +b f32 bf16 r n n n p 480 672 2050 2050 672 672 bias,relu,clip +b bf16 bf16 r n n n p 480 672 2050 2050 672 672 bias,relu,clip +b f32 r n n n p 481 672 2050 2050 672 672 none +b bf16 r n n n p 481 672 2050 2050 672 672 none +b f32 bf16 r n n n p 481 672 2050 2050 672 672 bias,relu,clip +b bf16 bf16 r n n n p 481 672 2050 2050 672 672 bias,relu,clip +b f32 r n n n p 482 672 2050 2050 672 672 none +b bf16 r n n n p 482 672 2050 2050 672 672 none +b f32 bf16 r n n n p 482 672 2050 2050 672 672 bias,relu,clip +b bf16 bf16 r n n n p 482 672 2050 2050 672 672 bias,relu,clip +b f32 r n n n p 483 672 2050 2050 672 672 none +b bf16 r n n n p 483 672 2050 2050 672 672 none +b f32 bf16 r n n n p 483 672 2050 2050 672 672 bias,relu,clip +b bf16 bf16 r n n n p 483 672 2050 2050 672 672 bias,relu,clip +b f32 r n n n p 484 672 2050 2050 672 672 none +b bf16 r n n n p 484 672 2050 2050 672 672 none +b f32 bf16 r n n n p 484 672 2050 2050 672 672 bias,relu,clip +b bf16 bf16 r n n n p 484 672 2050 2050 672 672 bias,relu,clip +b f32 r n n n p 485 672 2050 2050 672 672 none +b bf16 r n n n p 485 672 2050 2050 672 672 none +b f32 bf16 r n n n p 485 672 2050 2050 672 672 bias,relu,clip +b bf16 bf16 r n n n p 485 672 2050 2050 672 672 bias,relu,clip +b f32 r n n n p 480 688 2050 2050 688 688 none +b bf16 r n n n p 480 688 2050 2050 688 688 none +b f32 bf16 r n n n p 480 688 2050 2050 688 688 bias,relu,clip +b bf16 bf16 r n n n p 480 688 2050 2050 688 688 bias,relu,clip +b f32 r n n n p 481 688 2050 2050 688 688 none +b bf16 r n n n p 481 688 2050 2050 688 688 none +b f32 bf16 r n n n p 481 688 2050 2050 688 688 bias,relu,clip +b bf16 bf16 r n n n p 481 688 2050 2050 688 688 bias,relu,clip +b f32 r n n n r 482 688 2050 2050 688 688 none +b bf16 r n n n r 482 688 2050 2050 688 688 none +b f32 bf16 r n n n r 482 688 2050 2050 688 688 bias,relu,clip +b bf16 bf16 r n n n r 482 688 2050 2050 688 688 bias,relu,clip +b f32 r n n n r 483 688 2050 2050 688 688 none +b bf16 r n n n r 483 688 2050 2050 688 688 none +b f32 bf16 r n n n r 483 688 2050 2050 688 688 bias,relu,clip +b bf16 bf16 r n n n r 483 688 2050 2050 688 688 bias,relu,clip +b f32 r n n n r 484 688 2050 2050 688 688 none +b bf16 r n n n r 484 688 2050 2050 688 688 none +b f32 bf16 r n n n r 484 688 2050 2050 688 688 bias,relu,clip +b bf16 bf16 r n n n r 484 688 2050 2050 688 688 bias,relu,clip +b f32 r n n n r 485 688 2050 2050 688 688 none +b bf16 r n n n r 485 688 2050 2050 688 688 none +b f32 bf16 r n n n r 485 688 2050 2050 688 688 bias,relu,clip +b bf16 bf16 r n n n r 485 688 2050 2050 688 688 bias,relu,clip +b f32 r n n n r 1024 512 64 64 512 512 none +b bf16 r n n n r 1024 512 64 64 512 512 none +b f32 bf16 r n n n r 1024 512 64 64 512 512 bias,relu,clip +b bf16 bf16 r n n n r 1024 512 64 64 512 512 bias,relu,clip +b f32 r n n n r 16 256 512 512 256 256 none +b bf16 r n n n r 16 256 512 512 256 256 none +b f32 bf16 r n n n r 16 256 512 512 256 256 bias,relu,clip +b bf16 bf16 r n n n r 16 256 512 512 256 256 bias,relu,clip +b f32 r n n n r 480 640 512 512 640 640 none +b bf16 r n n n r 480 640 512 512 640 640 none +b f32 bf16 r n n n r 480 640 512 512 640 640 bias,relu,clip +b bf16 bf16 r n n n r 480 640 512 512 640 640 bias,relu,clip +b f32 r n n n r 64 768 512 512 768 768 none +b bf16 r n n n r 64 768 512 512 768 768 none +b f32 bf16 r n n n r 64 768 512 512 768 768 bias,relu,clip +b bf16 bf16 r n n n r 64 768 512 512 768 768 bias,relu,clip +b f32 r n n n r 128 128 128 128 128 128 none +b bf16 r n n n r 128 128 
128 128 128 128 none +b f32 bf16 r n n n r 128 128 128 128 128 128 bias,relu,clip +b bf16 bf16 r n n n r 128 128 128 128 128 128 bias,relu,clip +b f32 r n n n r 1024 64 512 512 64 64 none +b bf16 r n n n r 1024 64 512 512 64 64 none +b f32 bf16 r n n n r 1024 64 512 512 64 64 bias,relu,clip +b bf16 bf16 r n n n r 1024 64 512 512 64 64 bias,relu,clip +b f32 r n n n r 1024 256 32 32 256 256 none +b bf16 r n n n r 1024 256 32 32 256 256 none +b f32 bf16 r n n n r 1024 256 32 32 256 256 bias,relu,clip +b bf16 bf16 r n n n r 1024 256 32 32 256 256 bias,relu,clip +b f32 r n n n r 1024 512 64 64 512 512 none +b bf16 r n n n r 1024 512 64 64 512 512 none +b f32 bf16 r n n n r 1024 512 64 64 512 512 bias,relu,clip +b bf16 bf16 r n n n r 1024 512 64 64 512 512 bias,relu,clip +b f32 r n n n r 480 640 512 512 640 640 none +b bf16 r n n n r 480 640 512 512 640 640 none +b f32 bf16 r n n n r 480 640 512 512 640 640 bias,relu,clip +b bf16 bf16 r n n n r 480 640 512 512 640 640 bias,relu,clip +b f32 r n n n p 1024 32 256 256 32 32 none +b bf16 r n n n p 1024 32 256 256 32 32 none +b f32 bf16 r n n n p 1024 32 256 256 32 32 bias,relu,clip +b bf16 bf16 r n n n p 1024 32 256 256 32 32 bias,relu,clip +b f32 r n n n P 1024 64 512 512 64 64 none +b bf16 r n n n P 1024 64 512 512 64 64 none +b f32 bf16 r n n n P 1024 64 512 512 64 64 bias,relu,clip +b bf16 bf16 r n n n P 1024 64 512 512 64 64 bias,relu,clip +b f32 r n n n P 64 800 320 320 800 800 none +b bf16 r n n n P 64 800 320 320 800 800 none +b f32 bf16 r n n n P 64 800 320 320 800 800 bias,relu,clip +b bf16 bf16 r n n n P 64 800 320 320 800 800 bias,relu,clip +b f32 r n n n P 64 768 512 512 768 768 none +b bf16 r n n n P 64 768 512 512 768 768 none +b f32 bf16 r n n n P 64 768 512 512 768 768 bias,relu,clip +b bf16 bf16 r n n n P 64 768 512 512 768 768 bias,relu,clip +b f32 r n n n P 16 256 512 512 256 256 none +b bf16 r n n n P 16 256 512 512 256 256 none +b f32 bf16 r n n n P 16 256 512 512 256 256 bias,relu,clip +b bf16 bf16 r n n n P 16 256 512 512 256 256 bias,relu,clip +b f32 r n n n P 128 128 128 128 128 128 none +b bf16 r n n n P 128 128 128 128 128 128 none +b f32 bf16 r n n n P 128 128 128 128 128 128 bias,relu,clip +b bf16 bf16 r n n n P 128 128 128 128 128 128 bias,relu,clip +b f32 r n n n P 256 512 256 256 512 512 none +b bf16 r n n n P 256 512 256 256 512 512 none +b f32 bf16 r n n n P 256 512 256 256 512 512 bias,relu,clip +b bf16 bf16 r n n n P 256 512 256 256 512 512 bias,relu,clip +b f32 r n n n P 1024 1024 1024 1024 1024 1024 none +b bf16 r n n n P 1024 1024 1024 1024 1024 1024 none +b f32 bf16 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip +b bf16 bf16 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip +b f32 r n n n P 480 640 1024 1024 640 640 none +b bf16 r n n n P 480 640 1024 1024 640 640 none +b f32 bf16 r n n n P 480 640 1024 1024 640 640 bias,relu,clip +b bf16 bf16 r n n n P 480 640 1024 1024 640 640 bias,relu,clip +b f32 r n n n P 480 640 256 256 640 640 none +b bf16 r n n n P 480 640 256 256 640 640 none +b f32 bf16 r n n n P 480 640 256 256 640 640 bias,relu,clip +b bf16 bf16 r n n n P 480 640 256 256 640 640 bias,relu,clip +b f32 r n n n P 8 64 32 32 64 64 none +b bf16 r n n n P 8 64 32 32 64 64 none +b f32 bf16 r n n n P 8 64 32 32 64 64 bias,relu,clip +b bf16 bf16 r n n n P 8 64 32 32 64 64 bias,relu,clip +b f32 r n n n P 9 64 32 32 64 64 none +b bf16 r n n n P 9 64 32 32 64 64 none +b f32 bf16 r n n n P 9 64 32 32 64 64 bias,relu,clip +b bf16 bf16 r n n n P 9 64 32 32 64 64 bias,relu,clip +b f32 r n n n P 10 
128 64 64 128 128 none +b bf16 r n n n P 10 128 64 64 128 128 none +b f32 bf16 r n n n P 10 128 64 64 128 128 bias,relu,clip +b bf16 bf16 r n n n P 10 128 64 64 128 128 bias,relu,clip +b f32 r n n n P 8 8 8 8 8 8 none +b bf16 r n n n P 8 8 8 8 8 8 none +b f32 bf16 r n n n P 8 8 8 8 8 8 bias,relu,clip +b bf16 bf16 r n n n P 8 8 8 8 8 8 bias,relu,clip +b f32 r n n n P 12 12 12 12 12 12 none +b bf16 r n n n P 12 12 12 12 12 12 none +b f32 bf16 r n n n P 12 12 12 12 12 12 bias,relu,clip +b bf16 bf16 r n n n P 12 12 12 12 12 12 bias,relu,clip +b f32 r n n n P 25 25 25 25 25 25 none +b bf16 r n n n P 25 25 25 25 25 25 none +b f32 bf16 r n n n P 25 25 25 25 25 25 bias,relu,clip +b bf16 bf16 r n n n P 25 25 25 25 25 25 bias,relu,clip +b f32 r n n n P 25 25 20 20 25 25 none +b bf16 r n n n P 25 25 20 20 25 25 none +b f32 bf16 r n n n P 25 25 20 20 25 25 bias,relu,clip +b bf16 bf16 r n n n P 25 25 20 20 25 25 bias,relu,clip +b f32 c n n n p 485 39 2050 485 2050 485 none +b bf16 c n n n p 485 39 2050 485 2050 485 none +b f32 bf16 c n n n p 485 39 2050 485 2050 485 bias,relu,clip +b bf16 bf16 c n n n p 485 39 2050 485 2050 485 bias,relu,clip +b f32 c n n n p 480 50 2050 480 2050 480 none +b bf16 c n n n p 480 50 2050 480 2050 480 none +b f32 bf16 c n n n p 480 50 2050 480 2050 480 bias,relu,clip +b bf16 bf16 c n n n p 480 50 2050 480 2050 480 bias,relu,clip +b f32 c n n n p 481 50 2050 481 2050 481 none +b bf16 c n n n p 481 50 2050 481 2050 481 none +b f32 bf16 c n n n p 481 50 2050 481 2050 481 bias,relu,clip +b bf16 bf16 c n n n p 481 50 2050 481 2050 481 bias,relu,clip +b f32 c n n n p 482 50 2050 482 2050 482 none +b bf16 c n n n p 482 50 2050 482 2050 482 none +b f32 bf16 c n n n p 482 50 2050 482 2050 482 bias,relu,clip +b bf16 bf16 c n n n p 482 50 2050 482 2050 482 bias,relu,clip +b f32 c n n n p 483 50 2050 483 2050 483 none +b bf16 c n n n p 483 50 2050 483 2050 483 none +b f32 bf16 c n n n p 483 50 2050 483 2050 483 bias,relu,clip +b bf16 bf16 c n n n p 483 50 2050 483 2050 483 bias,relu,clip +b f32 c n n n p 484 50 2050 484 2050 484 none +b bf16 c n n n p 484 50 2050 484 2050 484 none +b f32 bf16 c n n n p 484 50 2050 484 2050 484 bias,relu,clip +b bf16 bf16 c n n n p 484 50 2050 484 2050 484 bias,relu,clip +b f32 c n n n p 485 50 2050 485 2050 485 none +b bf16 c n n n p 485 50 2050 485 2050 485 none +b f32 bf16 c n n n p 485 50 2050 485 2050 485 bias,relu,clip +b bf16 bf16 c n n n p 485 50 2050 485 2050 485 bias,relu,clip +b f32 c n n n p 484 1127 2050 484 2050 484 none +b bf16 c n n n p 484 1127 2050 484 2050 484 none +b f32 bf16 c n n n p 484 1127 2050 484 2050 484 bias,relu,clip +b bf16 bf16 c n n n p 484 1127 2050 484 2050 484 bias,relu,clip +b f32 c n n n p 485 1127 2050 485 2050 485 none +b bf16 c n n n p 485 1127 2050 485 2050 485 none +b f32 bf16 c n n n p 485 1127 2050 485 2050 485 bias,relu,clip +b bf16 bf16 c n n n p 485 1127 2050 485 2050 485 bias,relu,clip +b f32 c n n n p 480 1138 2050 480 2050 480 none +b bf16 c n n n p 480 1138 2050 480 2050 480 none +b f32 bf16 c n n n p 480 1138 2050 480 2050 480 bias,relu,clip +b bf16 bf16 c n n n p 480 1138 2050 480 2050 480 bias,relu,clip +b f32 c n n n p 481 1138 2050 481 2050 481 none +b bf16 c n n n p 481 1138 2050 481 2050 481 none +b f32 bf16 c n n n p 481 1138 2050 481 2050 481 bias,relu,clip +b bf16 bf16 c n n n p 481 1138 2050 481 2050 481 bias,relu,clip +b f32 c n n n p 482 1138 2050 482 2050 482 none +b bf16 c n n n p 482 1138 2050 482 2050 482 none +b f32 bf16 c n n n p 482 1138 2050 482 2050 482 bias,relu,clip +b bf16 
bf16 c n n n p 482 1138 2050 482 2050 482 bias,relu,clip +b f32 c n n n p 483 1138 2050 483 2050 483 none +b bf16 c n n n p 483 1138 2050 483 2050 483 none +b f32 bf16 c n n n p 483 1138 2050 483 2050 483 bias,relu,clip +b bf16 bf16 c n n n p 483 1138 2050 483 2050 483 bias,relu,clip +b f32 c n n n p 484 1138 2050 484 2050 484 none +b bf16 c n n n p 484 1138 2050 484 2050 484 none +b f32 bf16 c n n n p 484 1138 2050 484 2050 484 bias,relu,clip +b bf16 bf16 c n n n p 484 1138 2050 484 2050 484 bias,relu,clip +b f32 c n n n p 485 1138 2050 485 2050 485 none +b bf16 c n n n p 485 1138 2050 485 2050 485 none +b f32 bf16 c n n n p 485 1138 2050 485 2050 485 bias,relu,clip +b bf16 bf16 c n n n p 485 1138 2050 485 2050 485 bias,relu,clip +b f32 c n n n p 1 1 3 1 3 1 none +b bf16 c n n n p 1 1 3 1 3 1 none +b f32 bf16 c n n n p 1 1 3 1 3 1 bias,relu,clip +b bf16 bf16 c n n n p 1 1 3 1 3 1 bias,relu,clip +b f32 c n n n p 1 9 3 1 3 1 none +b bf16 c n n n p 1 9 3 1 3 1 none +b f32 bf16 c n n n p 1 9 3 1 3 1 bias,relu,clip +b bf16 bf16 c n n n p 1 9 3 1 3 1 bias,relu,clip +b f32 c n n n p 1 2048 3 1 3 1 none +b bf16 c n n n p 1 2048 3 1 3 1 none +b f32 bf16 c n n n p 1 2048 3 1 3 1 bias,relu,clip +b bf16 bf16 c n n n p 1 2048 3 1 3 1 bias,relu,clip +b f32 c n n n p 1 2048 5192 1 5192 1 none +b bf16 c n n n p 1 2048 5192 1 5192 1 none +b f32 bf16 c n n n p 1 2048 5192 1 5192 1 bias,relu,clip +b bf16 bf16 c n n n p 1 2048 5192 1 5192 1 bias,relu,clip +b f32 c n n n p 9 1 3 9 3 9 none +b bf16 c n n n p 9 1 3 9 3 9 none +b f32 bf16 c n n n p 9 1 3 9 3 9 bias,relu,clip +b bf16 bf16 c n n n p 9 1 3 9 3 9 bias,relu,clip +b f32 c n n n p 576 1 3500 576 3500 576 none +b bf16 c n n n p 576 1 3500 576 3500 576 none +b f32 bf16 c n n n p 576 1 3500 576 3500 576 bias,relu,clip +b bf16 bf16 c n n n p 576 1 3500 576 3500 576 bias,relu,clip +b f32 c n n n p 1 1 1 1 1 1 none +b bf16 c n n n p 1 1 1 1 1 1 none +b f32 bf16 c n n n p 1 1 1 1 1 1 bias,relu,clip +b bf16 bf16 c n n n p 1 1 1 1 1 1 bias,relu,clip +b f32 c n n n p 102 1088 1024 102 1024 102 none +b bf16 c n n n p 102 1088 1024 102 1024 102 none +b f32 bf16 c n n n p 102 1088 1024 102 1024 102 bias,relu,clip +b bf16 bf16 c n n n p 102 1088 1024 102 1024 102 bias,relu,clip +b f32 c n n n p 102 2048 1024 102 1024 102 none +b bf16 c n n n p 102 2048 1024 102 1024 102 none +b f32 bf16 c n n n p 102 2048 1024 102 1024 102 bias,relu,clip +b bf16 bf16 c n n n p 102 2048 1024 102 1024 102 bias,relu,clip +b f32 c n n n p 485 656 1024 485 1024 485 none +b bf16 c n n n p 485 656 1024 485 1024 485 none +b f32 bf16 c n n n p 485 656 1024 485 1024 485 bias,relu,clip +b bf16 bf16 c n n n p 485 656 1024 485 1024 485 bias,relu,clip +b f32 c n n n p 483 656 1024 483 1024 483 none +b bf16 c n n n p 483 656 1024 483 1024 483 none +b f32 bf16 c n n n p 483 656 1024 483 1024 483 bias,relu,clip +b bf16 bf16 c n n n p 483 656 1024 483 1024 483 bias,relu,clip +b f32 c n n n p 81 128 3 81 3 81 none +b bf16 c n n n p 81 128 3 81 3 81 none +b f32 bf16 c n n n p 81 128 3 81 3 81 bias,relu,clip +b bf16 bf16 c n n n p 81 128 3 81 3 81 bias,relu,clip +b f32 c n n n p 1022 512 515 1022 515 1022 none +b bf16 c n n n p 1022 512 515 1022 515 1022 none +b f32 bf16 c n n n p 1022 512 515 1022 515 1022 bias,relu,clip +b bf16 bf16 c n n n p 1022 512 515 1022 515 1022 bias,relu,clip +b f32 c n n n p 74 512 515 74 515 74 none +b bf16 c n n n p 74 512 515 74 515 74 none +b f32 bf16 c n n n p 74 512 515 74 515 74 bias,relu,clip +b bf16 bf16 c n n n p 74 512 515 74 515 74 bias,relu,clip +b f32 c n n n p 
253 2048 515 253 515 253 none +b bf16 c n n n p 253 2048 515 253 515 253 none +b f32 bf16 c n n n p 253 2048 515 253 515 253 bias,relu,clip +b bf16 bf16 c n n n p 253 2048 515 253 515 253 bias,relu,clip +b f32 c n n n p 8192 1040 515 8192 515 8192 none +b bf16 c n n n p 8192 1040 515 8192 515 8192 none +b f32 bf16 c n n n p 8192 1040 515 8192 515 8192 bias,relu,clip +b bf16 bf16 c n n n p 8192 1040 515 8192 515 8192 bias,relu,clip +b f32 c n n n p 10 1029 515 10 515 10 none +b bf16 c n n n p 10 1029 515 10 515 10 none +b f32 bf16 c n n n p 10 1029 515 10 515 10 bias,relu,clip +b bf16 bf16 c n n n p 10 1029 515 10 515 10 bias,relu,clip +b f32 c n n n p 24 1040 2050 24 2050 24 none +b bf16 c n n n p 24 1040 2050 24 2050 24 none +b f32 bf16 c n n n p 24 1040 2050 24 2050 24 bias,relu,clip +b bf16 bf16 c n n n p 24 1040 2050 24 2050 24 bias,relu,clip +b f32 c n n n p 1024 1029 2050 1024 2050 1024 none +b bf16 c n n n p 1024 1029 2050 1024 2050 1024 none +b f32 bf16 c n n n p 1024 1029 2050 1024 2050 1024 bias,relu,clip +b bf16 bf16 c n n n p 1024 1029 2050 1024 2050 1024 bias,relu,clip +b f32 c n n n p 480 660 2050 480 2050 480 none +b bf16 c n n n p 480 660 2050 480 2050 480 none +b f32 bf16 c n n n p 480 660 2050 480 2050 480 bias,relu,clip +b bf16 bf16 c n n n p 480 660 2050 480 2050 480 bias,relu,clip +b f32 c n n n p 481 660 2050 481 2050 481 none +b bf16 c n n n p 481 660 2050 481 2050 481 none +b f32 bf16 c n n n p 481 660 2050 481 2050 481 bias,relu,clip +b bf16 bf16 c n n n p 481 660 2050 481 2050 481 bias,relu,clip +b f32 c n n n p 482 660 2050 482 2050 482 none +b bf16 c n n n p 482 660 2050 482 2050 482 none +b f32 bf16 c n n n p 482 660 2050 482 2050 482 bias,relu,clip +b bf16 bf16 c n n n p 482 660 2050 482 2050 482 bias,relu,clip +b f32 c n n n p 483 660 2050 483 2050 483 none +b bf16 c n n n p 483 660 2050 483 2050 483 none +b f32 bf16 c n n n p 483 660 2050 483 2050 483 bias,relu,clip +b bf16 bf16 c n n n p 483 660 2050 483 2050 483 bias,relu,clip +b f32 c n n n p 484 660 2050 484 2050 484 none +b bf16 c n n n p 484 660 2050 484 2050 484 none +b f32 bf16 c n n n p 484 660 2050 484 2050 484 bias,relu,clip +b bf16 bf16 c n n n p 484 660 2050 484 2050 484 bias,relu,clip +b f32 c n n n p 485 660 2050 485 2050 485 none +b bf16 c n n n p 485 660 2050 485 2050 485 none +b f32 bf16 c n n n p 485 660 2050 485 2050 485 bias,relu,clip +b bf16 bf16 c n n n p 485 660 2050 485 2050 485 bias,relu,clip +b f32 c n n n p 480 679 2050 480 2050 480 none +b bf16 c n n n p 480 679 2050 480 2050 480 none +b f32 bf16 c n n n p 480 679 2050 480 2050 480 bias,relu,clip +b bf16 bf16 c n n n p 480 679 2050 480 2050 480 bias,relu,clip +b f32 c n n n p 481 679 2050 481 2050 481 none +b bf16 c n n n p 481 679 2050 481 2050 481 none +b f32 bf16 c n n n p 481 679 2050 481 2050 481 bias,relu,clip +b bf16 bf16 c n n n p 481 679 2050 481 2050 481 bias,relu,clip +b f32 c n n n p 482 679 2050 482 2050 482 none +b bf16 c n n n p 482 679 2050 482 2050 482 none +b f32 bf16 c n n n p 482 679 2050 482 2050 482 bias,relu,clip +b bf16 bf16 c n n n p 482 679 2050 482 2050 482 bias,relu,clip +b f32 c n n n p 483 679 2050 483 2050 483 none +b bf16 c n n n p 483 679 2050 483 2050 483 none +b f32 bf16 c n n n p 483 679 2050 483 2050 483 bias,relu,clip +b bf16 bf16 c n n n p 483 679 2050 483 2050 483 bias,relu,clip +b f32 c n n n p 484 679 2050 484 2050 484 none +b bf16 c n n n p 484 679 2050 484 2050 484 none +b f32 bf16 c n n n p 484 679 2050 484 2050 484 bias,relu,clip +b bf16 bf16 c n n n p 484 679 2050 484 2050 484 
bias,relu,clip +b f32 c n n n p 485 679 2050 485 2050 485 none +b bf16 c n n n p 485 679 2050 485 2050 485 none +b f32 bf16 c n n n p 485 679 2050 485 2050 485 bias,relu,clip +b bf16 bf16 c n n n p 485 679 2050 485 2050 485 bias,relu,clip +b f32 c n n n p 480 690 2050 480 2050 480 none +b bf16 c n n n p 480 690 2050 480 2050 480 none +b f32 bf16 c n n n p 480 690 2050 480 2050 480 bias,relu,clip +b bf16 bf16 c n n n p 480 690 2050 480 2050 480 bias,relu,clip +b f32 c n n n p 481 690 2050 481 2050 481 none +b bf16 c n n n p 481 690 2050 481 2050 481 none +b f32 bf16 c n n n p 481 690 2050 481 2050 481 bias,relu,clip +b bf16 bf16 c n n n p 481 690 2050 481 2050 481 bias,relu,clip +b f32 c n n n p 482 690 2050 482 2050 482 none +b bf16 c n n n p 482 690 2050 482 2050 482 none +b f32 bf16 c n n n p 482 690 2050 482 2050 482 bias,relu,clip +b bf16 bf16 c n n n p 482 690 2050 482 2050 482 bias,relu,clip +b f32 c n n n p 483 690 2050 483 2050 483 none +b bf16 c n n n p 483 690 2050 483 2050 483 none +b f32 bf16 c n n n p 483 690 2050 483 2050 483 bias,relu,clip +b bf16 bf16 c n n n p 483 690 2050 483 2050 483 bias,relu,clip +b f32 c n n n p 484 690 2050 484 2050 484 none +b bf16 c n n n p 484 690 2050 484 2050 484 none +b f32 bf16 c n n n p 484 690 2050 484 2050 484 bias,relu,clip +b bf16 bf16 c n n n p 484 690 2050 484 2050 484 bias,relu,clip +b f32 c n n n p 485 690 2050 485 2050 485 none +b bf16 c n n n p 485 690 2050 485 2050 485 none +b f32 bf16 c n n n p 485 690 2050 485 2050 485 bias,relu,clip +b bf16 bf16 c n n n p 485 690 2050 485 2050 485 bias,relu,clip +b f32 c n n n p 480 660 2048 480 2048 480 none +b bf16 c n n n p 480 660 2048 480 2048 480 none +b f32 bf16 c n n n p 480 660 2048 480 2048 480 bias,relu,clip +b bf16 bf16 c n n n p 480 660 2048 480 2048 480 bias,relu,clip +b f32 c n n n p 481 660 2048 481 2048 481 none +b bf16 c n n n p 481 660 2048 481 2048 481 none +b f32 bf16 c n n n p 481 660 2048 481 2048 481 bias,relu,clip +b bf16 bf16 c n n n p 481 660 2048 481 2048 481 bias,relu,clip +b f32 c n n n p 482 660 2048 482 2048 482 none +b bf16 c n n n p 482 660 2048 482 2048 482 none +b f32 bf16 c n n n p 482 660 2048 482 2048 482 bias,relu,clip +b bf16 bf16 c n n n p 482 660 2048 482 2048 482 bias,relu,clip +b f32 c n n n p 483 660 2048 483 2048 483 none +b bf16 c n n n p 483 660 2048 483 2048 483 none +b f32 bf16 c n n n p 483 660 2048 483 2048 483 bias,relu,clip +b bf16 bf16 c n n n p 483 660 2048 483 2048 483 bias,relu,clip +b f32 c n n n p 484 660 2048 484 2048 484 none +b bf16 c n n n p 484 660 2048 484 2048 484 none +b f32 bf16 c n n n p 484 660 2048 484 2048 484 bias,relu,clip +b bf16 bf16 c n n n p 484 660 2048 484 2048 484 bias,relu,clip +b f32 c n n n p 485 660 2048 485 2048 485 none +b bf16 c n n n p 485 660 2048 485 2048 485 none +b f32 bf16 c n n n p 485 660 2048 485 2048 485 bias,relu,clip +b bf16 bf16 c n n n p 485 660 2048 485 2048 485 bias,relu,clip +b f32 c n n n p 480 679 2048 480 2048 480 none +b bf16 c n n n p 480 679 2048 480 2048 480 none +b f32 bf16 c n n n p 480 679 2048 480 2048 480 bias,relu,clip +b bf16 bf16 c n n n p 480 679 2048 480 2048 480 bias,relu,clip +b f32 c n n n p 481 679 2048 481 2048 481 none +b bf16 c n n n p 481 679 2048 481 2048 481 none +b f32 bf16 c n n n p 481 679 2048 481 2048 481 bias,relu,clip +b bf16 bf16 c n n n p 481 679 2048 481 2048 481 bias,relu,clip +b f32 c n n n p 482 679 2048 482 2048 482 none +b bf16 c n n n p 482 679 2048 482 2048 482 none +b f32 bf16 c n n n p 482 679 2048 482 2048 482 bias,relu,clip +b bf16 bf16 c n n 
n p 482 679 2048 482 2048 482 bias,relu,clip +b f32 c n n n p 483 679 2048 483 2048 483 none +b bf16 c n n n p 483 679 2048 483 2048 483 none +b f32 bf16 c n n n p 483 679 2048 483 2048 483 bias,relu,clip +b bf16 bf16 c n n n p 483 679 2048 483 2048 483 bias,relu,clip +b f32 c n n n p 484 679 2048 484 2048 484 none +b bf16 c n n n p 484 679 2048 484 2048 484 none +b f32 bf16 c n n n p 484 679 2048 484 2048 484 bias,relu,clip +b bf16 bf16 c n n n p 484 679 2048 484 2048 484 bias,relu,clip +b f32 c n n n p 485 679 2048 485 2048 485 none +b bf16 c n n n p 485 679 2048 485 2048 485 none +b f32 bf16 c n n n p 485 679 2048 485 2048 485 bias,relu,clip +b bf16 bf16 c n n n p 485 679 2048 485 2048 485 bias,relu,clip +b f32 c n n n p 480 690 2048 480 2048 480 none +b bf16 c n n n p 480 690 2048 480 2048 480 none +b f32 bf16 c n n n p 480 690 2048 480 2048 480 bias,relu,clip +b bf16 bf16 c n n n p 480 690 2048 480 2048 480 bias,relu,clip +b f32 c n n n p 481 690 2048 481 2048 481 none +b bf16 c n n n p 481 690 2048 481 2048 481 none +b f32 bf16 c n n n p 481 690 2048 481 2048 481 bias,relu,clip +b bf16 bf16 c n n n p 481 690 2048 481 2048 481 bias,relu,clip +b f32 c n n n p 482 690 2048 482 2048 482 none +b bf16 c n n n p 482 690 2048 482 2048 482 none +b f32 bf16 c n n n p 482 690 2048 482 2048 482 bias,relu,clip +b bf16 bf16 c n n n p 482 690 2048 482 2048 482 bias,relu,clip +b f32 c n n n p 483 690 2048 483 2048 483 none +b bf16 c n n n p 483 690 2048 483 2048 483 none +b f32 bf16 c n n n p 483 690 2048 483 2048 483 bias,relu,clip +b bf16 bf16 c n n n p 483 690 2048 483 2048 483 bias,relu,clip +b f32 c n n n p 484 690 2048 484 2048 484 none +b bf16 c n n n p 484 690 2048 484 2048 484 none +b f32 bf16 c n n n p 484 690 2048 484 2048 484 bias,relu,clip +b bf16 bf16 c n n n p 484 690 2048 484 2048 484 bias,relu,clip +b f32 c n n n p 485 690 2048 485 2048 485 none +b bf16 c n n n p 485 690 2048 485 2048 485 none +b f32 bf16 c n n n p 485 690 2048 485 2048 485 bias,relu,clip +b bf16 bf16 c n n n p 485 690 2048 485 2048 485 bias,relu,clip +b f32 c n n n p 480 656 1024 480 1024 480 none +b bf16 c n n n p 480 656 1024 480 1024 480 none +b f32 bf16 c n n n p 480 656 1024 480 1024 480 bias,relu,clip +b bf16 bf16 c n n n p 480 656 1024 480 1024 480 bias,relu,clip +b f32 c n n n p 480 128 3 480 3 480 none +b bf16 c n n n p 480 128 3 480 3 480 none +b f32 bf16 c n n n p 480 128 3 480 3 480 bias,relu,clip +b bf16 bf16 c n n n p 480 128 3 480 3 480 bias,relu,clip +b f32 c n n n p 1024 512 515 1024 515 1024 none +b bf16 c n n n p 1024 512 515 1024 515 1024 none +b f32 bf16 c n n n p 1024 512 515 1024 515 1024 bias,relu,clip +b bf16 bf16 c n n n p 1024 512 515 1024 515 1024 bias,relu,clip +b f32 c n n n p 1024 2048 1024 1024 1024 1024 none +b bf16 c n n n p 1024 2048 1024 1024 1024 1024 none +b f32 bf16 c n n n p 1024 2048 1024 1024 1024 1024 bias,relu,clip +b bf16 bf16 c n n n p 1024 2048 1024 1024 1024 1024 bias,relu,clip +b f32 c n n n p 1024 2048 515 1024 515 1024 none +b bf16 c n n n p 1024 2048 515 1024 515 1024 none +b f32 bf16 c n n n p 1024 2048 515 1024 515 1024 bias,relu,clip +b bf16 bf16 c n n n p 1024 2048 515 1024 515 1024 bias,relu,clip +b f32 c p n n n 1024 1040 515 1024 515 1024 none +b bf16 c p n n n 1024 1040 515 1024 515 1024 none +b f32 bf16 c p n n n 1024 1040 515 1024 515 1024 bias,relu,clip +b bf16 bf16 c p n n n 1024 1040 515 1024 515 1024 bias,relu,clip +b f32 c p n n n 5 1029 515 5 515 5 none +b bf16 c p n n n 5 1029 515 5 515 5 none +b f32 bf16 c p n n n 5 1029 515 5 515 5 
bias,relu,clip +b bf16 bf16 c p n n n 5 1029 515 5 515 5 bias,relu,clip +b f32 c p n n n 1024 1029 515 1024 515 1024 none +b bf16 c p n n n 1024 1029 515 1024 515 1024 none +b f32 bf16 c p n n n 1024 1029 515 1024 515 1024 bias,relu,clip +b bf16 bf16 c p n n n 1024 1029 515 1024 515 1024 bias,relu,clip +b f32 c p n n n 1024 1040 2050 1024 2050 1024 none +b bf16 c p n n n 1024 1040 2050 1024 2050 1024 none +b f32 bf16 c p n n n 1024 1040 2050 1024 2050 1024 bias,relu,clip +b bf16 bf16 c p n n n 1024 1040 2050 1024 2050 1024 bias,relu,clip +b f32 c p n n n 1029 1029 2050 1029 2050 1029 none +b bf16 c p n n n 1029 1029 2050 1029 2050 1029 none +b f32 bf16 c p n n n 1029 1029 2050 1029 2050 1029 bias,relu,clip +b bf16 bf16 c p n n n 1029 1029 2050 1029 2050 1029 bias,relu,clip +b f32 c p n n n 485 656 2050 485 2050 485 none +b bf16 c p n n n 485 656 2050 485 2050 485 none +b f32 bf16 c p n n n 485 656 2050 485 2050 485 bias,relu,clip +b bf16 bf16 c p n n n 485 656 2050 485 2050 485 bias,relu,clip +b f32 c p n n n 480 672 2050 480 2050 480 none +b bf16 c p n n n 480 672 2050 480 2050 480 none +b f32 bf16 c p n n n 480 672 2050 480 2050 480 bias,relu,clip +b bf16 bf16 c p n n n 480 672 2050 480 2050 480 bias,relu,clip +b f32 c p n n n 481 672 2050 481 2050 481 none +b bf16 c p n n n 481 672 2050 481 2050 481 none +b f32 bf16 c p n n n 481 672 2050 481 2050 481 bias,relu,clip +b bf16 bf16 c p n n n 481 672 2050 481 2050 481 bias,relu,clip +b f32 c p n n n 482 672 2050 482 2050 482 none +b bf16 c p n n n 482 672 2050 482 2050 482 none +b f32 bf16 c p n n n 482 672 2050 482 2050 482 bias,relu,clip +b bf16 bf16 c p n n n 482 672 2050 482 2050 482 bias,relu,clip +b f32 c p n n n 483 672 2050 483 2050 483 none +b bf16 c p n n n 483 672 2050 483 2050 483 none +b f32 bf16 c p n n n 483 672 2050 483 2050 483 bias,relu,clip +b bf16 bf16 c p n n n 483 672 2050 483 2050 483 bias,relu,clip +b f32 c p n n n 484 672 2050 484 2050 484 none +b bf16 c p n n n 484 672 2050 484 2050 484 none +b f32 bf16 c p n n n 484 672 2050 484 2050 484 bias,relu,clip +b bf16 bf16 c p n n n 484 672 2050 484 2050 484 bias,relu,clip +b f32 c p n n n 485 672 2050 485 2050 485 none +b bf16 c p n n n 485 672 2050 485 2050 485 none +b f32 bf16 c p n n n 485 672 2050 485 2050 485 bias,relu,clip +b bf16 bf16 c p n n n 485 672 2050 485 2050 485 bias,relu,clip +b f32 c p n n n 480 688 2050 480 2050 480 none +b bf16 c p n n n 480 688 2050 480 2050 480 none +b f32 bf16 c p n n n 480 688 2050 480 2050 480 bias,relu,clip +b bf16 bf16 c p n n n 480 688 2050 480 2050 480 bias,relu,clip +b f32 c p n n n 481 688 2050 481 2050 481 none +b bf16 c p n n n 481 688 2050 481 2050 481 none +b f32 bf16 c p n n n 481 688 2050 481 2050 481 bias,relu,clip +b bf16 bf16 c p n n n 481 688 2050 481 2050 481 bias,relu,clip +b f32 c p n n n 1024 32 256 1024 256 1024 none +b bf16 c p n n n 1024 32 256 1024 256 1024 none +b f32 bf16 c p n n n 1024 32 256 1024 256 1024 bias,relu,clip +b bf16 bf16 c p n n n 1024 32 256 1024 256 1024 bias,relu,clip +b f32 c P n n n 1024 64 512 1024 512 1024 none +b bf16 c P n n n 1024 64 512 1024 512 1024 none +b f32 bf16 c P n n n 1024 64 512 1024 512 1024 bias,relu,clip +b bf16 bf16 c P n n n 1024 64 512 1024 512 1024 bias,relu,clip +b f32 c P n n n 64 800 320 64 320 64 none +b bf16 c P n n n 64 800 320 64 320 64 none +b f32 bf16 c P n n n 64 800 320 64 320 64 bias,relu,clip +b bf16 bf16 c P n n n 64 800 320 64 320 64 bias,relu,clip +b f32 c P n n n 64 768 512 64 512 64 none +b bf16 c P n n n 64 768 512 64 512 64 none +b f32 bf16 c P 
n n n 64 768 512 64 512 64 bias,relu,clip +b bf16 bf16 c P n n n 64 768 512 64 512 64 bias,relu,clip +b f32 c P n n n 16 256 512 16 512 16 none +b bf16 c P n n n 16 256 512 16 512 16 none +b f32 bf16 c P n n n 16 256 512 16 512 16 bias,relu,clip +b bf16 bf16 c P n n n 16 256 512 16 512 16 bias,relu,clip +b f32 c P n n n 128 128 128 128 128 128 none +b bf16 c P n n n 128 128 128 128 128 128 none +b f32 bf16 c P n n n 128 128 128 128 128 128 bias,relu,clip +b bf16 bf16 c P n n n 128 128 128 128 128 128 bias,relu,clip +b f32 c P n n n 256 512 256 256 256 256 none +b bf16 c P n n n 256 512 256 256 256 256 none +b f32 bf16 c P n n n 256 512 256 256 256 256 bias,relu,clip +b bf16 bf16 c P n n n 256 512 256 256 256 256 bias,relu,clip +b f32 c P n n n 1024 1024 1024 1024 1024 1024 none +b bf16 c P n n n 1024 1024 1024 1024 1024 1024 none +b f32 bf16 c P n n n 1024 1024 1024 1024 1024 1024 bias,relu,clip +b bf16 bf16 c P n n n 1024 1024 1024 1024 1024 1024 bias,relu,clip +b f32 c P n n n 480 640 1024 480 1024 480 none +b bf16 c P n n n 480 640 1024 480 1024 480 none +b f32 bf16 c P n n n 480 640 1024 480 1024 480 bias,relu,clip +b bf16 bf16 c P n n n 480 640 1024 480 1024 480 bias,relu,clip +b f32 c P n n n 480 640 256 480 256 480 none +b bf16 c P n n n 480 640 256 480 256 480 none +b f32 bf16 c P n n n 480 640 256 480 256 480 bias,relu,clip +b bf16 bf16 c P n n n 480 640 256 480 256 480 bias,relu,clip +b f32 c P n n n 8 64 32 8 32 8 none +b bf16 c P n n n 8 64 32 8 32 8 none +b f32 bf16 c P n n n 8 64 32 8 32 8 bias,relu,clip +b bf16 bf16 c P n n n 8 64 32 8 32 8 bias,relu,clip +b f32 c P n n n 9 64 32 9 32 9 none +b bf16 c P n n n 9 64 32 9 32 9 none +b f32 bf16 c P n n n 9 64 32 9 32 9 bias,relu,clip +b bf16 bf16 c P n n n 9 64 32 9 32 9 bias,relu,clip +b f32 c P n n n 10 128 64 10 64 10 none +b bf16 c P n n n 10 128 64 10 64 10 none +b f32 bf16 c P n n n 10 128 64 10 64 10 bias,relu,clip +b bf16 bf16 c P n n n 10 128 64 10 64 10 bias,relu,clip +b f32 c P n n n 8 8 8 8 8 8 none +b bf16 c P n n n 8 8 8 8 8 8 none +b f32 bf16 c P n n n 8 8 8 8 8 8 bias,relu,clip +b bf16 bf16 c P n n n 8 8 8 8 8 8 bias,relu,clip +b f32 c P n n n 12 12 12 12 12 12 none +b bf16 c P n n n 12 12 12 12 12 12 none +b f32 bf16 c P n n n 12 12 12 12 12 12 bias,relu,clip +b bf16 bf16 c P n n n 12 12 12 12 12 12 bias,relu,clip +b f32 c P n n n 25 25 25 25 25 25 none +b bf16 c P n n n 25 25 25 25 25 25 none +b f32 bf16 c P n n n 25 25 25 25 25 25 bias,relu,clip +b bf16 bf16 c P n n n 25 25 25 25 25 25 bias,relu,clip +b f32 c P n n n 25 25 20 25 20 25 none +b bf16 c P n n n 25 25 20 25 20 25 none +b f32 bf16 c P n n n 25 25 20 25 20 25 bias,relu,clip +b bf16 bf16 c P n n n 25 25 20 25 20 25 bias,relu,clip +s s16 r n n n r 480 20 2050 2050 20 20 none +s s8 r n n n r 480 20 2050 2050 20 20 none +s u8 r n n n r 480 20 2050 2050 20 20 none +s s16 u8 r n n n r 480 20 2050 2050 20 20 bias,relu,clip +s s8 u8 r n n n r 480 20 2050 2050 20 20 bias,relu,clip +s u8 u8 r n n n r 480 20 2050 2050 20 20 bias,relu,clip +s s16 r n n n r 481 20 2050 2050 20 20 none +s s8 r n n n r 481 20 2050 2050 20 20 none +s u8 r n n n r 481 20 2050 2050 20 20 none +s s16 u8 r n n n r 481 20 2050 2050 20 20 bias,relu,clip +s s8 u8 r n n n r 481 20 2050 2050 20 20 bias,relu,clip +s u8 u8 r n n n r 481 20 2050 2050 20 20 bias,relu,clip +s s16 r n n n r 482 20 2050 2050 20 20 none +s s8 r n n n r 482 20 2050 2050 20 20 none +s u8 r n n n r 482 20 2050 2050 20 20 none +s s16 u8 r n n n r 482 20 2050 2050 20 20 bias,relu,clip +s s8 u8 r n n n r 482 20 2050 2050 
20 20 bias,relu,clip +s u8 u8 r n n n r 482 20 2050 2050 20 20 bias,relu,clip +s s16 r n n n p 483 20 2050 2050 20 20 none +s s8 r n n n p 483 20 2050 2050 20 20 none +s u8 r n n n p 483 20 2050 2050 20 20 none +s s16 u8 r n n n p 483 20 2050 2050 20 20 bias,relu,clip +s s8 u8 r n n n p 483 20 2050 2050 20 20 bias,relu,clip +s u8 u8 r n n n p 483 20 2050 2050 20 20 bias,relu,clip +s s16 r n n n R 484 20 2050 2050 20 20 none +s s8 r n n n R 484 20 2050 2050 20 20 none +s u8 r n n n R 484 20 2050 2050 20 20 none +s s16 u8 r n n n R 484 20 2050 2050 20 20 bias,relu,clip +s s8 u8 r n n n R 484 20 2050 2050 20 20 bias,relu,clip +s u8 u8 r n n n R 484 20 2050 2050 20 20 bias,relu,clip +s s16 r n n n R 485 20 2050 2050 20 20 none +s s8 r n n n R 485 20 2050 2050 20 20 none +s u8 r n n n R 485 20 2050 2050 20 20 none +s s16 u8 r n n n R 485 20 2050 2050 20 20 bias,relu,clip +s s8 u8 r n n n R 485 20 2050 2050 20 20 bias,relu,clip +s u8 u8 r n n n R 485 20 2050 2050 20 20 bias,relu,clip +s s16 r n n n R 480 39 2050 2050 39 39 none +s s8 r n n n R 480 39 2050 2050 39 39 none +s u8 r n n n R 480 39 2050 2050 39 39 none +s s16 u8 r n n n R 480 39 2050 2050 39 39 bias,relu,clip +s s8 u8 r n n n R 480 39 2050 2050 39 39 bias,relu,clip +s u8 u8 r n n n R 480 39 2050 2050 39 39 bias,relu,clip +s s16 r n n n R 481 39 2050 2050 39 39 none +s s8 r n n n R 481 39 2050 2050 39 39 none +s u8 r n n n R 481 39 2050 2050 39 39 none +s s16 u8 r n n n R 481 39 2050 2050 39 39 bias,relu,clip +s s8 u8 r n n n R 481 39 2050 2050 39 39 bias,relu,clip +s u8 u8 r n n n R 481 39 2050 2050 39 39 bias,relu,clip +s s16 r n n n R 482 39 2050 2050 39 39 none +s s8 r n n n R 482 39 2050 2050 39 39 none +s u8 r n n n R 482 39 2050 2050 39 39 none +s s16 u8 r n n n R 482 39 2050 2050 39 39 bias,relu,clip +s s8 u8 r n n n R 482 39 2050 2050 39 39 bias,relu,clip +s u8 u8 r n n n R 482 39 2050 2050 39 39 bias,relu,clip +s s16 r n n n R 483 39 2050 2050 39 39 none +s s8 r n n n R 483 39 2050 2050 39 39 none +s u8 r n n n R 483 39 2050 2050 39 39 none +s s16 u8 r n n n R 483 39 2050 2050 39 39 bias,relu,clip +s s8 u8 r n n n R 483 39 2050 2050 39 39 bias,relu,clip +s u8 u8 r n n n R 483 39 2050 2050 39 39 bias,relu,clip +s s16 r n n n R 484 39 2050 2050 39 39 none +s s8 r n n n R 484 39 2050 2050 39 39 none +s u8 r n n n R 484 39 2050 2050 39 39 none +s s16 u8 r n n n R 484 39 2050 2050 39 39 bias,relu,clip +s s8 u8 r n n n R 484 39 2050 2050 39 39 bias,relu,clip +s u8 u8 r n n n R 484 39 2050 2050 39 39 bias,relu,clip +s s16 r n n n p 485 39 2050 2050 39 39 none +s s8 r n n n p 485 39 2050 2050 39 39 none +s u8 r n n n p 485 39 2050 2050 39 39 none +s s16 u8 r n n n p 485 39 2050 2050 39 39 bias,relu,clip +s s8 u8 r n n n p 485 39 2050 2050 39 39 bias,relu,clip +s u8 u8 r n n n p 485 39 2050 2050 39 39 bias,relu,clip +s s16 r n n n p 480 50 2050 2050 50 50 none +s s8 r n n n p 480 50 2050 2050 50 50 none +s u8 r n n n p 480 50 2050 2050 50 50 none +s s16 u8 r n n n p 480 50 2050 2050 50 50 bias,relu,clip +s s8 u8 r n n n p 480 50 2050 2050 50 50 bias,relu,clip +s u8 u8 r n n n p 480 50 2050 2050 50 50 bias,relu,clip +s s16 r n n n p 481 50 2050 2050 50 50 none +s s8 r n n n p 481 50 2050 2050 50 50 none +s u8 r n n n p 481 50 2050 2050 50 50 none +s s16 u8 r n n n p 481 50 2050 2050 50 50 bias,relu,clip +s s8 u8 r n n n p 481 50 2050 2050 50 50 bias,relu,clip +s u8 u8 r n n n p 481 50 2050 2050 50 50 bias,relu,clip +s s16 r n n n p 482 50 2050 2050 50 50 none +s s8 r n n n p 482 50 2050 2050 50 50 none +s u8 r n n n p 482 50 2050 2050 
50 50 none +s s16 u8 r n n n p 482 50 2050 2050 50 50 bias,relu,clip +s s8 u8 r n n n p 482 50 2050 2050 50 50 bias,relu,clip +s u8 u8 r n n n p 482 50 2050 2050 50 50 bias,relu,clip +s s16 r n n n p 483 50 2050 2050 50 50 none +s s8 r n n n p 483 50 2050 2050 50 50 none +s u8 r n n n p 483 50 2050 2050 50 50 none +s s16 u8 r n n n p 483 50 2050 2050 50 50 bias,relu,clip +s s8 u8 r n n n p 483 50 2050 2050 50 50 bias,relu,clip +s u8 u8 r n n n p 483 50 2050 2050 50 50 bias,relu,clip +s s16 r n n n p 484 50 2050 2050 50 50 none +s s8 r n n n p 484 50 2050 2050 50 50 none +s u8 r n n n p 484 50 2050 2050 50 50 none +s s16 u8 r n n n p 484 50 2050 2050 50 50 bias,relu,clip +s s8 u8 r n n n p 484 50 2050 2050 50 50 bias,relu,clip +s u8 u8 r n n n p 484 50 2050 2050 50 50 bias,relu,clip +s s16 r n n n p 485 50 2050 2050 50 50 none +s s8 r n n n p 485 50 2050 2050 50 50 none +s u8 r n n n p 485 50 2050 2050 50 50 none +s s16 u8 r n n n p 485 50 2050 2050 50 50 bias,relu,clip +s s8 u8 r n n n p 485 50 2050 2050 50 50 bias,relu,clip +s u8 u8 r n n n p 485 50 2050 2050 50 50 bias,relu,clip +s s16 r n n n R 480 1108 2050 2050 1108 1108 none +s s8 r n n n R 480 1108 2050 2050 1108 1108 none +s u8 r n n n R 480 1108 2050 2050 1108 1108 none +s s16 u8 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip +s s8 u8 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip +s u8 u8 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip +s s16 r n n n R 481 1108 2050 2050 1108 1108 none +s s8 r n n n R 481 1108 2050 2050 1108 1108 none +s u8 r n n n R 481 1108 2050 2050 1108 1108 none +s s16 u8 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip +s s8 u8 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip +s u8 u8 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip +s s16 r n n n R 482 1108 2050 2050 1108 1108 none +s s8 r n n n R 482 1108 2050 2050 1108 1108 none +s u8 r n n n R 482 1108 2050 2050 1108 1108 none +s s16 u8 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip +s s8 u8 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip +s u8 u8 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip +s s16 r n n n R 483 1108 2050 2050 1108 1108 none +s s8 r n n n R 483 1108 2050 2050 1108 1108 none +s u8 r n n n R 483 1108 2050 2050 1108 1108 none +s s16 u8 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip +s s8 u8 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip +s u8 u8 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip +s s16 r n n n R 484 1108 2050 2050 1108 1108 none +s s8 r n n n R 484 1108 2050 2050 1108 1108 none +s u8 r n n n R 484 1108 2050 2050 1108 1108 none +s s16 u8 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip +s s8 u8 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip +s u8 u8 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip +s s16 r n n n R 485 1108 2050 2050 1108 1108 none +s s8 r n n n R 485 1108 2050 2050 1108 1108 none +s u8 r n n n R 485 1108 2050 2050 1108 1108 none +s s16 u8 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip +s s8 u8 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip +s u8 u8 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip +s s16 r n n n R 480 1127 2050 2050 1127 1127 none +s s8 r n n n R 480 1127 2050 2050 1127 1127 none +s u8 r n n n R 480 1127 2050 2050 1127 1127 none +s s16 u8 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip +s s8 u8 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip +s u8 u8 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip +s s16 r n n n R 481 1127 2050 2050 1127 1127 none +s s8 r n n n R 
481 1127 2050 2050 1127 1127 none +s u8 r n n n R 481 1127 2050 2050 1127 1127 none +s s16 u8 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip +s s8 u8 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip +s u8 u8 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip +s s16 r n n n R 482 1127 2050 2050 1127 1127 none +s s8 r n n n R 482 1127 2050 2050 1127 1127 none +s u8 r n n n R 482 1127 2050 2050 1127 1127 none +s s16 u8 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip +s s8 u8 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip +s u8 u8 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip +s s16 r n n n R 483 1127 2050 2050 1127 1127 none +s s8 r n n n R 483 1127 2050 2050 1127 1127 none +s u8 r n n n R 483 1127 2050 2050 1127 1127 none +s s16 u8 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip +s s8 u8 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip +s u8 u8 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip +s s16 r n n n p 484 1127 2050 2050 1127 1127 none +s s8 r n n n p 484 1127 2050 2050 1127 1127 none +s u8 r n n n p 484 1127 2050 2050 1127 1127 none +s s16 u8 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip +s s8 u8 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip +s u8 u8 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip +s s16 r n n n p 485 1127 2050 2050 1127 1127 none +s s8 r n n n p 485 1127 2050 2050 1127 1127 none +s u8 r n n n p 485 1127 2050 2050 1127 1127 none +s s16 u8 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip +s s8 u8 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip +s u8 u8 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip +s s16 r n n n p 480 1138 2050 2050 1138 1138 none +s s8 r n n n p 480 1138 2050 2050 1138 1138 none +s u8 r n n n p 480 1138 2050 2050 1138 1138 none +s s16 u8 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip +s s8 u8 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip +s u8 u8 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip +s s16 r n n n p 481 1138 2050 2050 1138 1138 none +s s8 r n n n p 481 1138 2050 2050 1138 1138 none +s u8 r n n n p 481 1138 2050 2050 1138 1138 none +s s16 u8 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip +s s8 u8 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip +s u8 u8 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip +s s16 r n n n p 482 1138 2050 2050 1138 1138 none +s s8 r n n n p 482 1138 2050 2050 1138 1138 none +s u8 r n n n p 482 1138 2050 2050 1138 1138 none +s s16 u8 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip +s s8 u8 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip +s u8 u8 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip +s s16 r n n n p 483 1138 2050 2050 1138 1138 none +s s8 r n n n p 483 1138 2050 2050 1138 1138 none +s u8 r n n n p 483 1138 2050 2050 1138 1138 none +s s16 u8 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip +s s8 u8 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip +s u8 u8 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip +s s16 r n n n p 484 1138 2050 2050 1138 1138 none +s s8 r n n n p 484 1138 2050 2050 1138 1138 none +s u8 r n n n p 484 1138 2050 2050 1138 1138 none +s s16 u8 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip +s s8 u8 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip +s u8 u8 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip +s s16 r n n n p 485 1138 2050 2050 1138 1138 none +s s8 r n n n p 485 1138 2050 2050 1138 1138 none +s u8 r n n n p 485 1138 2050 2050 1138 1138 none +s s16 u8 r n n n p 485 1138 2050 2050 1138 1138 
bias,relu,clip +s s8 u8 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip +s u8 u8 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip +s s16 r n n n p 1 1 3 3 1 1 none +s s8 r n n n p 1 1 3 3 1 1 none +s u8 r n n n p 1 1 3 3 1 1 none +s s16 u8 r n n n p 1 1 3 3 1 1 bias,relu,clip +s s8 u8 r n n n p 1 1 3 3 1 1 bias,relu,clip +s u8 u8 r n n n p 1 1 3 3 1 1 bias,relu,clip +s s16 r n n n p 1 9 3 3 9 9 none +s s8 r n n n p 1 9 3 3 9 9 none +s u8 r n n n p 1 9 3 3 9 9 none +s s16 u8 r n n n p 1 9 3 3 9 9 bias,relu,clip +s s8 u8 r n n n p 1 9 3 3 9 9 bias,relu,clip +s u8 u8 r n n n p 1 9 3 3 9 9 bias,relu,clip +s s16 r n n n p 1 2048 3 3 2048 2048 none +s s8 r n n n p 1 2048 3 3 2048 2048 none +s u8 r n n n p 1 2048 3 3 2048 2048 none +s s16 u8 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip +s s8 u8 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip +s u8 u8 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip +s s16 r n n n p 1 2048 5192 5192 2048 2048 none +s s8 r n n n p 1 2048 5192 5192 2048 2048 none +s u8 r n n n p 1 2048 5192 5192 2048 2048 none +s s16 u8 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip +s s8 u8 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip +s u8 u8 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip +s s16 r n n n p 9 1 3 3 1 1 none +s s8 r n n n p 9 1 3 3 1 1 none +s u8 r n n n p 9 1 3 3 1 1 none +s s16 u8 r n n n p 9 1 3 3 1 1 bias,relu,clip +s s8 u8 r n n n p 9 1 3 3 1 1 bias,relu,clip +s u8 u8 r n n n p 9 1 3 3 1 1 bias,relu,clip +s s16 r n n n p 576 1 3500 3500 1 1 none +s s8 r n n n p 576 1 3500 3500 1 1 none +s u8 r n n n p 576 1 3500 3500 1 1 none +s s16 u8 r n n n p 576 1 3500 3500 1 1 bias,relu,clip +s s8 u8 r n n n p 576 1 3500 3500 1 1 bias,relu,clip +s u8 u8 r n n n p 576 1 3500 3500 1 1 bias,relu,clip +s s16 r n n n p 1 1 1 1 1 1 none +s s8 r n n n p 1 1 1 1 1 1 none +s u8 r n n n p 1 1 1 1 1 1 none +s s16 u8 r n n n p 1 1 1 1 1 1 bias,relu,clip +s s8 u8 r n n n p 1 1 1 1 1 1 bias,relu,clip +s u8 u8 r n n n p 1 1 1 1 1 1 bias,relu,clip +s s16 r n n n p 102 1088 1024 1024 1088 1088 none +s s8 r n n n p 102 1088 1024 1024 1088 1088 none +s u8 r n n n p 102 1088 1024 1024 1088 1088 none +s s16 u8 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip +s s8 u8 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip +s u8 u8 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip +s s16 r n n n p 102 2048 1024 1024 2048 2048 none +s s8 r n n n p 102 2048 1024 1024 2048 2048 none +s u8 r n n n p 102 2048 1024 1024 2048 2048 none +s s16 u8 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip +s s8 u8 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip +s u8 u8 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip +s s16 r n n n p 485 656 1024 1024 656 656 none +s s8 r n n n p 485 656 1024 1024 656 656 none +s u8 r n n n p 485 656 1024 1024 656 656 none +s s16 u8 r n n n p 485 656 1024 1024 656 656 bias,relu,clip +s s8 u8 r n n n p 485 656 1024 1024 656 656 bias,relu,clip +s u8 u8 r n n n p 485 656 1024 1024 656 656 bias,relu,clip +s s16 r n n n p 483 656 1024 1024 656 656 none +s s8 r n n n p 483 656 1024 1024 656 656 none +s u8 r n n n p 483 656 1024 1024 656 656 none +s s16 u8 r n n n p 483 656 1024 1024 656 656 bias,relu,clip +s s8 u8 r n n n p 483 656 1024 1024 656 656 bias,relu,clip +s u8 u8 r n n n p 483 656 1024 1024 656 656 bias,relu,clip +s s16 r n n n p 81 128 3 3 128 128 none +s s8 r n n n p 81 128 3 3 128 128 none +s u8 r n n n p 81 128 3 3 128 128 none +s s16 u8 r n n n p 81 128 3 3 128 128 bias,relu,clip +s s8 u8 r n n n p 81 128 3 3 128 128 
bias,relu,clip +s u8 u8 r n n n p 81 128 3 3 128 128 bias,relu,clip +s s16 r n n n p 1022 512 515 515 512 512 none +s s8 r n n n p 1022 512 515 515 512 512 none +s u8 r n n n p 1022 512 515 515 512 512 none +s s16 u8 r n n n p 1022 512 515 515 512 512 bias,relu,clip +s s8 u8 r n n n p 1022 512 515 515 512 512 bias,relu,clip +s u8 u8 r n n n p 1022 512 515 515 512 512 bias,relu,clip +s s16 r n n n p 74 512 515 515 512 512 none +s s8 r n n n p 74 512 515 515 512 512 none +s u8 r n n n p 74 512 515 515 512 512 none +s s16 u8 r n n n p 74 512 515 515 512 512 bias,relu,clip +s s8 u8 r n n n p 74 512 515 515 512 512 bias,relu,clip +s u8 u8 r n n n p 74 512 515 515 512 512 bias,relu,clip +s s16 r n n n p 253 2048 515 515 2048 2048 none +s s8 r n n n p 253 2048 515 515 2048 2048 none +s u8 r n n n p 253 2048 515 515 2048 2048 none +s s16 u8 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip +s s8 u8 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip +s u8 u8 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip +s s16 r n n n p 8192 1040 515 515 1040 1040 none +s s8 r n n n p 8192 1040 515 515 1040 1040 none +s u8 r n n n p 8192 1040 515 515 1040 1040 none +s s16 u8 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip +s s8 u8 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip +s u8 u8 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip +s s16 r n n n p 10 1029 515 515 1029 1029 none +s s8 r n n n p 10 1029 515 515 1029 1029 none +s u8 r n n n p 10 1029 515 515 1029 1029 none +s s16 u8 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip +s s8 u8 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip +s u8 u8 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip +s s16 r n n n p 24 1040 2050 2050 1040 1040 none +s s8 r n n n p 24 1040 2050 2050 1040 1040 none +s u8 r n n n p 24 1040 2050 2050 1040 1040 none +s s16 u8 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip +s s8 u8 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip +s u8 u8 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip +s s16 r n n n p 1024 1029 2050 2050 1029 1029 none +s s8 r n n n p 1024 1029 2050 2050 1029 1029 none +s u8 r n n n p 1024 1029 2050 2050 1029 1029 none +s s16 u8 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip +s s8 u8 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip +s u8 u8 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip +s s16 r n n n p 480 660 2050 2050 660 660 none +s s8 r n n n p 480 660 2050 2050 660 660 none +s u8 r n n n p 480 660 2050 2050 660 660 none +s s16 u8 r n n n p 480 660 2050 2050 660 660 bias,relu,clip +s s8 u8 r n n n p 480 660 2050 2050 660 660 bias,relu,clip +s u8 u8 r n n n p 480 660 2050 2050 660 660 bias,relu,clip +s s16 r n n n p 481 660 2050 2050 660 660 none +s s8 r n n n p 481 660 2050 2050 660 660 none +s u8 r n n n p 481 660 2050 2050 660 660 none +s s16 u8 r n n n p 481 660 2050 2050 660 660 bias,relu,clip +s s8 u8 r n n n p 481 660 2050 2050 660 660 bias,relu,clip +s u8 u8 r n n n p 481 660 2050 2050 660 660 bias,relu,clip +s s16 r n n n p 482 660 2050 2050 660 660 none +s s8 r n n n p 482 660 2050 2050 660 660 none +s u8 r n n n p 482 660 2050 2050 660 660 none +s s16 u8 r n n n p 482 660 2050 2050 660 660 bias,relu,clip +s s8 u8 r n n n p 482 660 2050 2050 660 660 bias,relu,clip +s u8 u8 r n n n p 482 660 2050 2050 660 660 bias,relu,clip +s s16 r n n n p 483 660 2050 2050 660 660 none +s s8 r n n n p 483 660 2050 2050 660 660 none +s u8 r n n n p 483 660 2050 2050 660 660 none +s s16 u8 r n n n p 483 660 2050 2050 660 660 bias,relu,clip +s s8 u8 r n n n p 483 
660 2050 2050 660 660 bias,relu,clip +s u8 u8 r n n n p 483 660 2050 2050 660 660 bias,relu,clip +s s16 r n n n p 484 660 2050 2050 660 660 none +s s8 r n n n p 484 660 2050 2050 660 660 none +s u8 r n n n p 484 660 2050 2050 660 660 none +s s16 u8 r n n n p 484 660 2050 2050 660 660 bias,relu,clip +s s8 u8 r n n n p 484 660 2050 2050 660 660 bias,relu,clip +s u8 u8 r n n n p 484 660 2050 2050 660 660 bias,relu,clip +s s16 r n n n p 485 660 2050 2050 660 660 none +s s8 r n n n p 485 660 2050 2050 660 660 none +s u8 r n n n p 485 660 2050 2050 660 660 none +s s16 u8 r n n n p 485 660 2050 2050 660 660 bias,relu,clip +s s8 u8 r n n n p 485 660 2050 2050 660 660 bias,relu,clip +s u8 u8 r n n n p 485 660 2050 2050 660 660 bias,relu,clip +s s16 r n n n p 480 679 2050 2050 679 679 none +s s8 r n n n p 480 679 2050 2050 679 679 none +s u8 r n n n p 480 679 2050 2050 679 679 none +s s16 u8 r n n n p 480 679 2050 2050 679 679 bias,relu,clip +s s8 u8 r n n n p 480 679 2050 2050 679 679 bias,relu,clip +s u8 u8 r n n n p 480 679 2050 2050 679 679 bias,relu,clip +s s16 r n n n p 481 679 2050 2050 679 679 none +s s8 r n n n p 481 679 2050 2050 679 679 none +s u8 r n n n p 481 679 2050 2050 679 679 none +s s16 u8 r n n n p 481 679 2050 2050 679 679 bias,relu,clip +s s8 u8 r n n n p 481 679 2050 2050 679 679 bias,relu,clip +s u8 u8 r n n n p 481 679 2050 2050 679 679 bias,relu,clip +s s16 r n n n p 482 679 2050 2050 679 679 none +s s8 r n n n p 482 679 2050 2050 679 679 none +s u8 r n n n p 482 679 2050 2050 679 679 none +s s16 u8 r n n n p 482 679 2050 2050 679 679 bias,relu,clip +s s8 u8 r n n n p 482 679 2050 2050 679 679 bias,relu,clip +s u8 u8 r n n n p 482 679 2050 2050 679 679 bias,relu,clip +s s16 r n n n p 483 679 2050 2050 679 679 none +s s8 r n n n p 483 679 2050 2050 679 679 none +s u8 r n n n p 483 679 2050 2050 679 679 none +s s16 u8 r n n n p 483 679 2050 2050 679 679 bias,relu,clip +s s8 u8 r n n n p 483 679 2050 2050 679 679 bias,relu,clip +s u8 u8 r n n n p 483 679 2050 2050 679 679 bias,relu,clip +s s16 r n n n p 484 679 2050 2050 679 679 none +s s8 r n n n p 484 679 2050 2050 679 679 none +s u8 r n n n p 484 679 2050 2050 679 679 none +s s16 u8 r n n n p 484 679 2050 2050 679 679 bias,relu,clip +s s8 u8 r n n n p 484 679 2050 2050 679 679 bias,relu,clip +s u8 u8 r n n n p 484 679 2050 2050 679 679 bias,relu,clip +s s16 r n n n p 485 679 2050 2050 679 679 none +s s8 r n n n p 485 679 2050 2050 679 679 none +s u8 r n n n p 485 679 2050 2050 679 679 none +s s16 u8 r n n n p 485 679 2050 2050 679 679 bias,relu,clip +s s8 u8 r n n n p 485 679 2050 2050 679 679 bias,relu,clip +s u8 u8 r n n n p 485 679 2050 2050 679 679 bias,relu,clip +s s16 r n n n p 480 690 2050 2050 690 690 none +s s8 r n n n p 480 690 2050 2050 690 690 none +s u8 r n n n p 480 690 2050 2050 690 690 none +s s16 u8 r n n n p 480 690 2050 2050 690 690 bias,relu,clip +s s8 u8 r n n n p 480 690 2050 2050 690 690 bias,relu,clip +s u8 u8 r n n n p 480 690 2050 2050 690 690 bias,relu,clip +s s16 r n n n p 481 690 2050 2050 690 690 none +s s8 r n n n p 481 690 2050 2050 690 690 none +s u8 r n n n p 481 690 2050 2050 690 690 none +s s16 u8 r n n n p 481 690 2050 2050 690 690 bias,relu,clip +s s8 u8 r n n n p 481 690 2050 2050 690 690 bias,relu,clip +s u8 u8 r n n n p 481 690 2050 2050 690 690 bias,relu,clip +s s16 r n n n p 482 690 2050 2050 690 690 none +s s8 r n n n p 482 690 2050 2050 690 690 none +s u8 r n n n p 482 690 2050 2050 690 690 none +s s16 u8 r n n n p 482 690 2050 2050 690 690 bias,relu,clip +s s8 u8 r n n n p 482 
690 2050 2050 690 690 bias,relu,clip +s u8 u8 r n n n p 482 690 2050 2050 690 690 bias,relu,clip +s s16 r n n n p 483 690 2050 2050 690 690 none +s s8 r n n n p 483 690 2050 2050 690 690 none +s u8 r n n n p 483 690 2050 2050 690 690 none +s s16 u8 r n n n p 483 690 2050 2050 690 690 bias,relu,clip +s s8 u8 r n n n p 483 690 2050 2050 690 690 bias,relu,clip +s u8 u8 r n n n p 483 690 2050 2050 690 690 bias,relu,clip +s s16 r n n n p 484 690 2050 2050 690 690 none +s s8 r n n n p 484 690 2050 2050 690 690 none +s u8 r n n n p 484 690 2050 2050 690 690 none +s s16 u8 r n n n p 484 690 2050 2050 690 690 bias,relu,clip +s s8 u8 r n n n p 484 690 2050 2050 690 690 bias,relu,clip +s u8 u8 r n n n p 484 690 2050 2050 690 690 bias,relu,clip +s s16 r n n n p 485 690 2050 2050 690 690 none +s s8 r n n n p 485 690 2050 2050 690 690 none +s u8 r n n n p 485 690 2050 2050 690 690 none +s s16 u8 r n n n p 485 690 2050 2050 690 690 bias,relu,clip +s s8 u8 r n n n p 485 690 2050 2050 690 690 bias,relu,clip +s u8 u8 r n n n p 485 690 2050 2050 690 690 bias,relu,clip +s s16 r n n n p 480 660 2048 2048 660 660 none +s s8 r n n n p 480 660 2048 2048 660 660 none +s u8 r n n n p 480 660 2048 2048 660 660 none +s s16 u8 r n n n p 480 660 2048 2048 660 660 bias,relu,clip +s s8 u8 r n n n p 480 660 2048 2048 660 660 bias,relu,clip +s u8 u8 r n n n p 480 660 2048 2048 660 660 bias,relu,clip +s s16 r n n n p 481 660 2048 2048 660 660 none +s s8 r n n n p 481 660 2048 2048 660 660 none +s u8 r n n n p 481 660 2048 2048 660 660 none +s s16 u8 r n n n p 481 660 2048 2048 660 660 bias,relu,clip +s s8 u8 r n n n p 481 660 2048 2048 660 660 bias,relu,clip +s u8 u8 r n n n p 481 660 2048 2048 660 660 bias,relu,clip +s s16 r n n n p 482 660 2048 2048 660 660 none +s s8 r n n n p 482 660 2048 2048 660 660 none +s u8 r n n n p 482 660 2048 2048 660 660 none +s s16 u8 r n n n p 482 660 2048 2048 660 660 bias,relu,clip +s s8 u8 r n n n p 482 660 2048 2048 660 660 bias,relu,clip +s u8 u8 r n n n p 482 660 2048 2048 660 660 bias,relu,clip +s s16 r n n n p 483 660 2048 2048 660 660 none +s s8 r n n n p 483 660 2048 2048 660 660 none +s u8 r n n n p 483 660 2048 2048 660 660 none +s s16 u8 r n n n p 483 660 2048 2048 660 660 bias,relu,clip +s s8 u8 r n n n p 483 660 2048 2048 660 660 bias,relu,clip +s u8 u8 r n n n p 483 660 2048 2048 660 660 bias,relu,clip +s s16 r n n n p 484 660 2048 2048 660 660 none +s s8 r n n n p 484 660 2048 2048 660 660 none +s u8 r n n n p 484 660 2048 2048 660 660 none +s s16 u8 r n n n p 484 660 2048 2048 660 660 bias,relu,clip +s s8 u8 r n n n p 484 660 2048 2048 660 660 bias,relu,clip +s u8 u8 r n n n p 484 660 2048 2048 660 660 bias,relu,clip +s s16 r n n n p 485 660 2048 2048 660 660 none +s s8 r n n n p 485 660 2048 2048 660 660 none +s u8 r n n n p 485 660 2048 2048 660 660 none +s s16 u8 r n n n p 485 660 2048 2048 660 660 bias,relu,clip +s s8 u8 r n n n p 485 660 2048 2048 660 660 bias,relu,clip +s u8 u8 r n n n p 485 660 2048 2048 660 660 bias,relu,clip +s s16 r n n n p 480 679 2048 2048 679 679 none +s s8 r n n n p 480 679 2048 2048 679 679 none +s u8 r n n n p 480 679 2048 2048 679 679 none +s s16 u8 r n n n p 480 679 2048 2048 679 679 bias,relu,clip +s s8 u8 r n n n p 480 679 2048 2048 679 679 bias,relu,clip +s u8 u8 r n n n p 480 679 2048 2048 679 679 bias,relu,clip +s s16 r n n n p 481 679 2048 2048 679 679 none +s s8 r n n n p 481 679 2048 2048 679 679 none +s u8 r n n n p 481 679 2048 2048 679 679 none +s s16 u8 r n n n p 481 679 2048 2048 679 679 bias,relu,clip +s s8 u8 r n n n p 481 
679 2048 2048 679 679 bias,relu,clip +s u8 u8 r n n n p 481 679 2048 2048 679 679 bias,relu,clip +s s16 r n n n p 482 679 2048 2048 679 679 none +s s8 r n n n p 482 679 2048 2048 679 679 none +s u8 r n n n p 482 679 2048 2048 679 679 none +s s16 u8 r n n n p 482 679 2048 2048 679 679 bias,relu,clip +s s8 u8 r n n n p 482 679 2048 2048 679 679 bias,relu,clip +s u8 u8 r n n n p 482 679 2048 2048 679 679 bias,relu,clip +s s16 r n n n p 483 679 2048 2048 679 679 none +s s8 r n n n p 483 679 2048 2048 679 679 none +s u8 r n n n p 483 679 2048 2048 679 679 none +s s16 u8 r n n n p 483 679 2048 2048 679 679 bias,relu,clip +s s8 u8 r n n n p 483 679 2048 2048 679 679 bias,relu,clip +s u8 u8 r n n n p 483 679 2048 2048 679 679 bias,relu,clip +s s16 r n n n p 484 679 2048 2048 679 679 none +s s8 r n n n p 484 679 2048 2048 679 679 none +s u8 r n n n p 484 679 2048 2048 679 679 none +s s16 u8 r n n n p 484 679 2048 2048 679 679 bias,relu,clip +s s8 u8 r n n n p 484 679 2048 2048 679 679 bias,relu,clip +s u8 u8 r n n n p 484 679 2048 2048 679 679 bias,relu,clip +s s16 r n n n p 485 679 2048 2048 679 679 none +s s8 r n n n p 485 679 2048 2048 679 679 none +s u8 r n n n p 485 679 2048 2048 679 679 none +s s16 u8 r n n n p 485 679 2048 2048 679 679 bias,relu,clip +s s8 u8 r n n n p 485 679 2048 2048 679 679 bias,relu,clip +s u8 u8 r n n n p 485 679 2048 2048 679 679 bias,relu,clip +s s16 r n n n p 480 690 2048 2048 690 690 none +s s8 r n n n p 480 690 2048 2048 690 690 none +s u8 r n n n p 480 690 2048 2048 690 690 none +s s16 u8 r n n n p 480 690 2048 2048 690 690 bias,relu,clip +s s8 u8 r n n n p 480 690 2048 2048 690 690 bias,relu,clip +s u8 u8 r n n n p 480 690 2048 2048 690 690 bias,relu,clip +s s16 r n n n p 481 690 2048 2048 690 690 none +s s8 r n n n p 481 690 2048 2048 690 690 none +s u8 r n n n p 481 690 2048 2048 690 690 none +s s16 u8 r n n n p 481 690 2048 2048 690 690 bias,relu,clip +s s8 u8 r n n n p 481 690 2048 2048 690 690 bias,relu,clip +s u8 u8 r n n n p 481 690 2048 2048 690 690 bias,relu,clip +s s16 r n n n p 482 690 2048 2048 690 690 none +s s8 r n n n p 482 690 2048 2048 690 690 none +s u8 r n n n p 482 690 2048 2048 690 690 none +s s16 u8 r n n n p 482 690 2048 2048 690 690 bias,relu,clip +s s8 u8 r n n n p 482 690 2048 2048 690 690 bias,relu,clip +s u8 u8 r n n n p 482 690 2048 2048 690 690 bias,relu,clip +s s16 r n n n p 483 690 2048 2048 690 690 none +s s8 r n n n p 483 690 2048 2048 690 690 none +s u8 r n n n p 483 690 2048 2048 690 690 none +s s16 u8 r n n n p 483 690 2048 2048 690 690 bias,relu,clip +s s8 u8 r n n n p 483 690 2048 2048 690 690 bias,relu,clip +s u8 u8 r n n n p 483 690 2048 2048 690 690 bias,relu,clip +s s16 r n n n p 484 690 2048 2048 690 690 none +s s8 r n n n p 484 690 2048 2048 690 690 none +s u8 r n n n p 484 690 2048 2048 690 690 none +s s16 u8 r n n n p 484 690 2048 2048 690 690 bias,relu,clip +s s8 u8 r n n n p 484 690 2048 2048 690 690 bias,relu,clip +s u8 u8 r n n n p 484 690 2048 2048 690 690 bias,relu,clip +s s16 r n n n p 485 690 2048 2048 690 690 none +s s8 r n n n p 485 690 2048 2048 690 690 none +s u8 r n n n p 485 690 2048 2048 690 690 none +s s16 u8 r n n n p 485 690 2048 2048 690 690 bias,relu,clip +s s8 u8 r n n n p 485 690 2048 2048 690 690 bias,relu,clip +s u8 u8 r n n n p 485 690 2048 2048 690 690 bias,relu,clip +s s16 r n n n p 480 656 1024 1024 656 656 none +s s8 r n n n p 480 656 1024 1024 656 656 none +s u8 r n n n p 480 656 1024 1024 656 656 none +s s16 u8 r n n n p 480 656 1024 1024 656 656 bias,relu,clip +s s8 u8 r n n n p 480 
656 1024 1024 656 656 bias,relu,clip +s u8 u8 r n n n p 480 656 1024 1024 656 656 bias,relu,clip +s s16 r n n n p 480 128 3 3 128 128 none +s s8 r n n n p 480 128 3 3 128 128 none +s u8 r n n n p 480 128 3 3 128 128 none +s s16 u8 r n n n p 480 128 3 3 128 128 bias,relu,clip +s s8 u8 r n n n p 480 128 3 3 128 128 bias,relu,clip +s u8 u8 r n n n p 480 128 3 3 128 128 bias,relu,clip +s s16 r n n n p 1024 512 515 515 512 512 none +s s8 r n n n p 1024 512 515 515 512 512 none +s u8 r n n n p 1024 512 515 515 512 512 none +s s16 u8 r n n n p 1024 512 515 515 512 512 bias,relu,clip +s s8 u8 r n n n p 1024 512 515 515 512 512 bias,relu,clip +s u8 u8 r n n n p 1024 512 515 515 512 512 bias,relu,clip +s s16 r n n n p 1024 2048 1024 1024 2048 2048 none +s s8 r n n n p 1024 2048 1024 1024 2048 2048 none +s u8 r n n n p 1024 2048 1024 1024 2048 2048 none +s s16 u8 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip +s s8 u8 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip +s u8 u8 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip +s s16 r n n n p 1024 2048 515 515 2048 2048 none +s s8 r n n n p 1024 2048 515 515 2048 2048 none +s u8 r n n n p 1024 2048 515 515 2048 2048 none +s s16 u8 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip +s s8 u8 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip +s u8 u8 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip +s s16 r n n n p 1024 1040 515 515 1040 1040 none +s s8 r n n n p 1024 1040 515 515 1040 1040 none +s u8 r n n n p 1024 1040 515 515 1040 1040 none +s s16 u8 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip +s s8 u8 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip +s u8 u8 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip +s s16 r n n n p 5 1029 515 515 1029 1029 none +s s8 r n n n p 5 1029 515 515 1029 1029 none +s u8 r n n n p 5 1029 515 515 1029 1029 none +s s16 u8 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip +s s8 u8 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip +s u8 u8 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip +s s16 r n n n p 1024 1029 515 515 1029 1029 none +s s8 r n n n p 1024 1029 515 515 1029 1029 none +s u8 r n n n p 1024 1029 515 515 1029 1029 none +s s16 u8 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip +s s8 u8 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip +s u8 u8 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip +s s16 r n n n p 1024 1040 2050 2050 1040 1040 none +s s8 r n n n p 1024 1040 2050 2050 1040 1040 none +s u8 r n n n p 1024 1040 2050 2050 1040 1040 none +s s16 u8 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip +s s8 u8 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip +s u8 u8 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip +s s16 r n n n p 1029 1029 2050 2050 1029 1029 none +s s8 r n n n p 1029 1029 2050 2050 1029 1029 none +s u8 r n n n p 1029 1029 2050 2050 1029 1029 none +s s16 u8 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip +s s8 u8 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip +s u8 u8 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip +s s16 r n n n R 480 646 2050 2050 646 646 none +s s8 r n n n R 480 646 2050 2050 646 646 none +s u8 r n n n R 480 646 2050 2050 646 646 none +s s16 u8 r n n n R 480 646 2050 2050 646 646 bias,relu,clip +s s8 u8 r n n n R 480 646 2050 2050 646 646 bias,relu,clip +s u8 u8 r n n n R 480 646 2050 2050 646 646 bias,relu,clip +s s16 r n n n R 481 646 2050 2050 646 646 none +s s8 r n n n R 481 646 2050 2050 646 646 none +s u8 r n n n R 481 646 2050 2050 646 646 none +s s16 u8 r n n n R 481 
646 2050 2050 646 646 bias,relu,clip +s s8 u8 r n n n R 481 646 2050 2050 646 646 bias,relu,clip +s u8 u8 r n n n R 481 646 2050 2050 646 646 bias,relu,clip +s s16 r n n n R 482 646 2050 2050 646 646 none +s s8 r n n n R 482 646 2050 2050 646 646 none +s u8 r n n n R 482 646 2050 2050 646 646 none +s s16 u8 r n n n R 482 646 2050 2050 646 646 bias,relu,clip +s s8 u8 r n n n R 482 646 2050 2050 646 646 bias,relu,clip +s u8 u8 r n n n R 482 646 2050 2050 646 646 bias,relu,clip +s s16 r n n n R 483 646 2050 2050 646 646 none +s s8 r n n n R 483 646 2050 2050 646 646 none +s u8 r n n n R 483 646 2050 2050 646 646 none +s s16 u8 r n n n R 483 646 2050 2050 646 646 bias,relu,clip +s s8 u8 r n n n R 483 646 2050 2050 646 646 bias,relu,clip +s u8 u8 r n n n R 483 646 2050 2050 646 646 bias,relu,clip +s s16 r n n n R 484 646 2050 2050 646 646 none +s s8 r n n n R 484 646 2050 2050 646 646 none +s u8 r n n n R 484 646 2050 2050 646 646 none +s s16 u8 r n n n R 484 646 2050 2050 646 646 bias,relu,clip +s s8 u8 r n n n R 484 646 2050 2050 646 646 bias,relu,clip +s u8 u8 r n n n R 484 646 2050 2050 646 646 bias,relu,clip +s s16 r n n n R 485 646 2050 2050 646 646 none +s s8 r n n n R 485 646 2050 2050 646 646 none +s u8 r n n n R 485 646 2050 2050 646 646 none +s s16 u8 r n n n R 485 646 2050 2050 646 646 bias,relu,clip +s s8 u8 r n n n R 485 646 2050 2050 646 646 bias,relu,clip +s u8 u8 r n n n R 485 646 2050 2050 646 646 bias,relu,clip +s s16 r n n n R 481 656 2050 2050 656 656 none +s s8 r n n n R 481 656 2050 2050 656 656 none +s u8 r n n n R 481 656 2050 2050 656 656 none +s s16 u8 r n n n R 481 656 2050 2050 656 656 bias,relu,clip +s s8 u8 r n n n R 481 656 2050 2050 656 656 bias,relu,clip +s u8 u8 r n n n R 481 656 2050 2050 656 656 bias,relu,clip +s s16 r n n n R 482 656 2050 2050 656 656 none +s s8 r n n n R 482 656 2050 2050 656 656 none +s u8 r n n n R 482 656 2050 2050 656 656 none +s s16 u8 r n n n R 482 656 2050 2050 656 656 bias,relu,clip +s s8 u8 r n n n R 482 656 2050 2050 656 656 bias,relu,clip +s u8 u8 r n n n R 482 656 2050 2050 656 656 bias,relu,clip +s s16 r n n n R 483 656 2050 2050 656 656 none +s s8 r n n n R 483 656 2050 2050 656 656 none +s u8 r n n n R 483 656 2050 2050 656 656 none +s s16 u8 r n n n R 483 656 2050 2050 656 656 bias,relu,clip +s s8 u8 r n n n R 483 656 2050 2050 656 656 bias,relu,clip +s u8 u8 r n n n R 483 656 2050 2050 656 656 bias,relu,clip +s s16 r n n n R 484 656 2050 2050 656 656 none +s s8 r n n n R 484 656 2050 2050 656 656 none +s u8 r n n n R 484 656 2050 2050 656 656 none +s s16 u8 r n n n R 484 656 2050 2050 656 656 bias,relu,clip +s s8 u8 r n n n R 484 656 2050 2050 656 656 bias,relu,clip +s u8 u8 r n n n R 484 656 2050 2050 656 656 bias,relu,clip +s s16 r n n n p 485 656 2050 2050 656 656 none +s s8 r n n n p 485 656 2050 2050 656 656 none +s u8 r n n n p 485 656 2050 2050 656 656 none +s s16 u8 r n n n p 485 656 2050 2050 656 656 bias,relu,clip +s s8 u8 r n n n p 485 656 2050 2050 656 656 bias,relu,clip +s u8 u8 r n n n p 485 656 2050 2050 656 656 bias,relu,clip +s s16 r n n n p 480 672 2050 2050 672 672 none +s s8 r n n n p 480 672 2050 2050 672 672 none +s u8 r n n n p 480 672 2050 2050 672 672 none +s s16 u8 r n n n p 480 672 2050 2050 672 672 bias,relu,clip +s s8 u8 r n n n p 480 672 2050 2050 672 672 bias,relu,clip +s u8 u8 r n n n p 480 672 2050 2050 672 672 bias,relu,clip +s s16 r n n n p 481 672 2050 2050 672 672 none +s s8 r n n n p 481 672 2050 2050 672 672 none +s u8 r n n n p 481 672 2050 2050 672 672 none +s s16 u8 r n n n p 481 
672 2050 2050 672 672 bias,relu,clip +s s8 u8 r n n n p 481 672 2050 2050 672 672 bias,relu,clip +s u8 u8 r n n n p 481 672 2050 2050 672 672 bias,relu,clip +s s16 r n n n p 482 672 2050 2050 672 672 none +s s8 r n n n p 482 672 2050 2050 672 672 none +s u8 r n n n p 482 672 2050 2050 672 672 none +s s16 u8 r n n n p 482 672 2050 2050 672 672 bias,relu,clip +s s8 u8 r n n n p 482 672 2050 2050 672 672 bias,relu,clip +s u8 u8 r n n n p 482 672 2050 2050 672 672 bias,relu,clip +s s16 r n n n p 483 672 2050 2050 672 672 none +s s8 r n n n p 483 672 2050 2050 672 672 none +s u8 r n n n p 483 672 2050 2050 672 672 none +s s16 u8 r n n n p 483 672 2050 2050 672 672 bias,relu,clip +s s8 u8 r n n n p 483 672 2050 2050 672 672 bias,relu,clip +s u8 u8 r n n n p 483 672 2050 2050 672 672 bias,relu,clip +s s16 r n n n p 484 672 2050 2050 672 672 none +s s8 r n n n p 484 672 2050 2050 672 672 none +s u8 r n n n p 484 672 2050 2050 672 672 none +s s16 u8 r n n n p 484 672 2050 2050 672 672 bias,relu,clip +s s8 u8 r n n n p 484 672 2050 2050 672 672 bias,relu,clip +s u8 u8 r n n n p 484 672 2050 2050 672 672 bias,relu,clip +s s16 r n n n p 485 672 2050 2050 672 672 none +s s8 r n n n p 485 672 2050 2050 672 672 none +s u8 r n n n p 485 672 2050 2050 672 672 none +s s16 u8 r n n n p 485 672 2050 2050 672 672 bias,relu,clip +s s8 u8 r n n n p 485 672 2050 2050 672 672 bias,relu,clip +s u8 u8 r n n n p 485 672 2050 2050 672 672 bias,relu,clip +s s16 r n n n p 480 688 2050 2050 688 688 none +s s8 r n n n p 480 688 2050 2050 688 688 none +s u8 r n n n p 480 688 2050 2050 688 688 none +s s16 u8 r n n n p 480 688 2050 2050 688 688 bias,relu,clip +s s8 u8 r n n n p 480 688 2050 2050 688 688 bias,relu,clip +s u8 u8 r n n n p 480 688 2050 2050 688 688 bias,relu,clip +s s16 r n n n p 481 688 2050 2050 688 688 none +s s8 r n n n p 481 688 2050 2050 688 688 none +s u8 r n n n p 481 688 2050 2050 688 688 none +s s16 u8 r n n n p 481 688 2050 2050 688 688 bias,relu,clip +s s8 u8 r n n n p 481 688 2050 2050 688 688 bias,relu,clip +s u8 u8 r n n n p 481 688 2050 2050 688 688 bias,relu,clip +s s16 r n n n r 482 688 2050 2050 688 688 none +s s8 r n n n r 482 688 2050 2050 688 688 none +s u8 r n n n r 482 688 2050 2050 688 688 none +s s16 u8 r n n n r 482 688 2050 2050 688 688 bias,relu,clip +s s8 u8 r n n n r 482 688 2050 2050 688 688 bias,relu,clip +s u8 u8 r n n n r 482 688 2050 2050 688 688 bias,relu,clip +s s16 r n n n r 483 688 2050 2050 688 688 none +s s8 r n n n r 483 688 2050 2050 688 688 none +s u8 r n n n r 483 688 2050 2050 688 688 none +s s16 u8 r n n n r 483 688 2050 2050 688 688 bias,relu,clip +s s8 u8 r n n n r 483 688 2050 2050 688 688 bias,relu,clip +s u8 u8 r n n n r 483 688 2050 2050 688 688 bias,relu,clip +s s16 r n n n r 484 688 2050 2050 688 688 none +s s8 r n n n r 484 688 2050 2050 688 688 none +s u8 r n n n r 484 688 2050 2050 688 688 none +s s16 u8 r n n n r 484 688 2050 2050 688 688 bias,relu,clip +s s8 u8 r n n n r 484 688 2050 2050 688 688 bias,relu,clip +s u8 u8 r n n n r 484 688 2050 2050 688 688 bias,relu,clip +s s16 r n n n r 485 688 2050 2050 688 688 none +s s8 r n n n r 485 688 2050 2050 688 688 none +s u8 r n n n r 485 688 2050 2050 688 688 none +s s16 u8 r n n n r 485 688 2050 2050 688 688 bias,relu,clip +s s8 u8 r n n n r 485 688 2050 2050 688 688 bias,relu,clip +s u8 u8 r n n n r 485 688 2050 2050 688 688 bias,relu,clip +s s16 r n n n r 1024 512 64 64 512 512 none +s s8 r n n n r 1024 512 64 64 512 512 none +s u8 r n n n r 1024 512 64 64 512 512 none +s s16 u8 r n n n r 1024 512 64 64 
512 512 bias,relu,clip +s s8 u8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +s u8 u8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +s s16 r n n n r 16 256 512 512 256 256 none +s s8 r n n n r 16 256 512 512 256 256 none +s u8 r n n n r 16 256 512 512 256 256 none +s s16 u8 r n n n r 16 256 512 512 256 256 bias,relu,clip +s s8 u8 r n n n r 16 256 512 512 256 256 bias,relu,clip +s u8 u8 r n n n r 16 256 512 512 256 256 bias,relu,clip +s s16 r n n n r 480 640 512 512 640 640 none +s s8 r n n n r 480 640 512 512 640 640 none +s u8 r n n n r 480 640 512 512 640 640 none +s s16 u8 r n n n r 480 640 512 512 640 640 bias,relu,clip +s s8 u8 r n n n r 480 640 512 512 640 640 bias,relu,clip +s u8 u8 r n n n r 480 640 512 512 640 640 bias,relu,clip +s s16 r n n n r 64 768 512 512 768 768 none +s s8 r n n n r 64 768 512 512 768 768 none +s u8 r n n n r 64 768 512 512 768 768 none +s s16 u8 r n n n r 64 768 512 512 768 768 bias,relu,clip +s s8 u8 r n n n r 64 768 512 512 768 768 bias,relu,clip +s u8 u8 r n n n r 64 768 512 512 768 768 bias,relu,clip +s s16 r n n n r 128 128 128 128 128 128 none +s s8 r n n n r 128 128 128 128 128 128 none +s u8 r n n n r 128 128 128 128 128 128 none +s s16 u8 r n n n r 128 128 128 128 128 128 bias,relu,clip +s s8 u8 r n n n r 128 128 128 128 128 128 bias,relu,clip +s u8 u8 r n n n r 128 128 128 128 128 128 bias,relu,clip +s s16 r n n n r 1024 64 512 512 64 64 none +s s8 r n n n r 1024 64 512 512 64 64 none +s u8 r n n n r 1024 64 512 512 64 64 none +s s16 u8 r n n n r 1024 64 512 512 64 64 bias,relu,clip +s s8 u8 r n n n r 1024 64 512 512 64 64 bias,relu,clip +s u8 u8 r n n n r 1024 64 512 512 64 64 bias,relu,clip +s s16 r n n n r 1024 256 32 32 256 256 none +s s8 r n n n r 1024 256 32 32 256 256 none +s u8 r n n n r 1024 256 32 32 256 256 none +s s16 u8 r n n n r 1024 256 32 32 256 256 bias,relu,clip +s s8 u8 r n n n r 1024 256 32 32 256 256 bias,relu,clip +s u8 u8 r n n n r 1024 256 32 32 256 256 bias,relu,clip +s s16 r n n n r 1024 512 64 64 512 512 none +s s8 r n n n r 1024 512 64 64 512 512 none +s u8 r n n n r 1024 512 64 64 512 512 none +s s16 u8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +s s8 u8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +s u8 u8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +s s16 r n n n r 480 640 512 512 640 640 none +s s8 r n n n r 480 640 512 512 640 640 none +s u8 r n n n r 480 640 512 512 640 640 none +s s16 u8 r n n n r 480 640 512 512 640 640 bias,relu,clip +s s8 u8 r n n n r 480 640 512 512 640 640 bias,relu,clip +s u8 u8 r n n n r 480 640 512 512 640 640 bias,relu,clip +s s16 r n n n p 1024 32 256 256 32 32 none +s s8 r n n n p 1024 32 256 256 32 32 none +s u8 r n n n p 1024 32 256 256 32 32 none +s s16 u8 r n n n p 1024 32 256 256 32 32 bias,relu,clip +s s8 u8 r n n n p 1024 32 256 256 32 32 bias,relu,clip +s u8 u8 r n n n p 1024 32 256 256 32 32 bias,relu,clip +s s16 r n n n P 1024 64 512 512 64 64 none +s s8 r n n n P 1024 64 512 512 64 64 none +s u8 r n n n P 1024 64 512 512 64 64 none +s s16 u8 r n n n P 1024 64 512 512 64 64 bias,relu,clip +s s8 u8 r n n n P 1024 64 512 512 64 64 bias,relu,clip +s u8 u8 r n n n P 1024 64 512 512 64 64 bias,relu,clip +s s16 r n n n P 64 800 320 320 800 800 none +s s8 r n n n P 64 800 320 320 800 800 none +s u8 r n n n P 64 800 320 320 800 800 none +s s16 u8 r n n n P 64 800 320 320 800 800 bias,relu,clip +s s8 u8 r n n n P 64 800 320 320 800 800 bias,relu,clip +s u8 u8 r n n n P 64 800 320 320 800 800 bias,relu,clip +s s16 r n n n P 64 768 512 512 768 768 none +s s8 r n n n P 
64 768 512 512 768 768 none +s u8 r n n n P 64 768 512 512 768 768 none +s s16 u8 r n n n P 64 768 512 512 768 768 bias,relu,clip +s s8 u8 r n n n P 64 768 512 512 768 768 bias,relu,clip +s u8 u8 r n n n P 64 768 512 512 768 768 bias,relu,clip +s s16 r n n n P 16 256 512 512 256 256 none +s s8 r n n n P 16 256 512 512 256 256 none +s u8 r n n n P 16 256 512 512 256 256 none +s s16 u8 r n n n P 16 256 512 512 256 256 bias,relu,clip +s s8 u8 r n n n P 16 256 512 512 256 256 bias,relu,clip +s u8 u8 r n n n P 16 256 512 512 256 256 bias,relu,clip +s s16 r n n n P 128 128 128 128 128 128 none +s s8 r n n n P 128 128 128 128 128 128 none +s u8 r n n n P 128 128 128 128 128 128 none +s s16 u8 r n n n P 128 128 128 128 128 128 bias,relu,clip +s s8 u8 r n n n P 128 128 128 128 128 128 bias,relu,clip +s u8 u8 r n n n P 128 128 128 128 128 128 bias,relu,clip +s s16 r n n n P 256 512 256 256 512 512 none +s s8 r n n n P 256 512 256 256 512 512 none +s u8 r n n n P 256 512 256 256 512 512 none +s s16 u8 r n n n P 256 512 256 256 512 512 bias,relu,clip +s s8 u8 r n n n P 256 512 256 256 512 512 bias,relu,clip +s u8 u8 r n n n P 256 512 256 256 512 512 bias,relu,clip +s s16 r n n n P 1024 1024 1024 1024 1024 1024 none +s s8 r n n n P 1024 1024 1024 1024 1024 1024 none +s u8 r n n n P 1024 1024 1024 1024 1024 1024 none +s s16 u8 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip +s s8 u8 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip +s u8 u8 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip +s s16 r n n n P 480 640 1024 1024 640 640 none +s s8 r n n n P 480 640 1024 1024 640 640 none +s u8 r n n n P 480 640 1024 1024 640 640 none +s s16 u8 r n n n P 480 640 1024 1024 640 640 bias,relu,clip +s s8 u8 r n n n P 480 640 1024 1024 640 640 bias,relu,clip +s u8 u8 r n n n P 480 640 1024 1024 640 640 bias,relu,clip +s s16 r n n n P 480 640 256 256 640 640 none +s s8 r n n n P 480 640 256 256 640 640 none +s u8 r n n n P 480 640 256 256 640 640 none +s s16 u8 r n n n P 480 640 256 256 640 640 bias,relu,clip +s s8 u8 r n n n P 480 640 256 256 640 640 bias,relu,clip +s u8 u8 r n n n P 480 640 256 256 640 640 bias,relu,clip +s s16 r n n n P 8 64 32 32 64 64 none +s s8 r n n n P 8 64 32 32 64 64 none +s u8 r n n n P 8 64 32 32 64 64 none +s s16 u8 r n n n P 8 64 32 32 64 64 bias,relu,clip +s s8 u8 r n n n P 8 64 32 32 64 64 bias,relu,clip +s u8 u8 r n n n P 8 64 32 32 64 64 bias,relu,clip +s s16 r n n n P 9 64 32 32 64 64 none +s s8 r n n n P 9 64 32 32 64 64 none +s u8 r n n n P 9 64 32 32 64 64 none +s s16 u8 r n n n P 9 64 32 32 64 64 bias,relu,clip +s s8 u8 r n n n P 9 64 32 32 64 64 bias,relu,clip +s u8 u8 r n n n P 9 64 32 32 64 64 bias,relu,clip +s s16 r n n n P 10 128 64 64 128 128 none +s s8 r n n n P 10 128 64 64 128 128 none +s u8 r n n n P 10 128 64 64 128 128 none +s s16 u8 r n n n P 10 128 64 64 128 128 bias,relu,clip +s s8 u8 r n n n P 10 128 64 64 128 128 bias,relu,clip +s u8 u8 r n n n P 10 128 64 64 128 128 bias,relu,clip +s s16 r n n n P 8 8 8 8 8 8 none +s s8 r n n n P 8 8 8 8 8 8 none +s u8 r n n n P 8 8 8 8 8 8 none +s s16 u8 r n n n P 8 8 8 8 8 8 bias,relu,clip +s s8 u8 r n n n P 8 8 8 8 8 8 bias,relu,clip +s u8 u8 r n n n P 8 8 8 8 8 8 bias,relu,clip +s s16 r n n n P 12 12 12 12 12 12 none +s s8 r n n n P 12 12 12 12 12 12 none +s u8 r n n n P 12 12 12 12 12 12 none +s s16 u8 r n n n P 12 12 12 12 12 12 bias,relu,clip +s s8 u8 r n n n P 12 12 12 12 12 12 bias,relu,clip +s u8 u8 r n n n P 12 12 12 12 12 12 bias,relu,clip +s s16 r n n n P 25 25 25 25 25 25 none +s s8 r n n n P 
25 25 25 25 25 25 none +s u8 r n n n P 25 25 25 25 25 25 none +s s16 u8 r n n n P 25 25 25 25 25 25 bias,relu,clip +s s8 u8 r n n n P 25 25 25 25 25 25 bias,relu,clip +s u8 u8 r n n n P 25 25 25 25 25 25 bias,relu,clip +s s16 r n n n P 25 25 20 20 25 25 none +s s8 r n n n P 25 25 20 20 25 25 none +s u8 r n n n P 25 25 20 20 25 25 none +s s16 u8 r n n n P 25 25 20 20 25 25 bias,relu,clip +s s8 u8 r n n n P 25 25 20 20 25 25 bias,relu,clip +s u8 u8 r n n n P 25 25 20 20 25 25 bias,relu,clip +i s32 r n n n p 480 20 2050 2050 20 20 none +i s8 r n n n p 480 20 2050 2050 20 20 none +i s32 s8 r n n n p 480 20 2050 2050 20 20 bias,relu,clip +i s8 s8 r n n n p 480 20 2050 2050 20 20 bias,relu,clip +i s32 r n n n p 481 20 2050 2050 20 20 none +i s8 r n n n p 481 20 2050 2050 20 20 none +i s32 s8 r n n n p 481 20 2050 2050 20 20 bias,relu,clip +i s8 s8 r n n n p 481 20 2050 2050 20 20 bias,relu,clip +i s32 r n n n p 482 20 2050 2050 20 20 none +i s8 r n n n p 482 20 2050 2050 20 20 none +i s32 s8 r n n n p 482 20 2050 2050 20 20 bias,relu,clip +i s8 s8 r n n n p 482 20 2050 2050 20 20 bias,relu,clip +i s32 r n n n p 483 20 2050 2050 20 20 none +i s8 r n n n p 483 20 2050 2050 20 20 none +i s32 s8 r n n n p 483 20 2050 2050 20 20 bias,relu,clip +i s8 s8 r n n n p 483 20 2050 2050 20 20 bias,relu,clip +i s32 r n n n R 484 20 2050 2050 20 20 none +i s8 r n n n R 484 20 2050 2050 20 20 none +i s32 s8 r n n n R 484 20 2050 2050 20 20 bias,relu,clip +i s8 s8 r n n n R 484 20 2050 2050 20 20 bias,relu,clip +i s32 r n n n R 485 20 2050 2050 20 20 none +i s8 r n n n R 485 20 2050 2050 20 20 none +i s32 s8 r n n n R 485 20 2050 2050 20 20 bias,relu,clip +i s8 s8 r n n n R 485 20 2050 2050 20 20 bias,relu,clip +i s32 r n n n R 480 39 2050 2050 39 39 none +i s8 r n n n R 480 39 2050 2050 39 39 none +i s32 s8 r n n n R 480 39 2050 2050 39 39 bias,relu,clip +i s8 s8 r n n n R 480 39 2050 2050 39 39 bias,relu,clip +i s32 r n n n R 481 39 2050 2050 39 39 none +i s8 r n n n R 481 39 2050 2050 39 39 none +i s32 s8 r n n n R 481 39 2050 2050 39 39 bias,relu,clip +i s8 s8 r n n n R 481 39 2050 2050 39 39 bias,relu,clip +i s32 r n n n R 482 39 2050 2050 39 39 none +i s8 r n n n R 482 39 2050 2050 39 39 none +i s32 s8 r n n n R 482 39 2050 2050 39 39 bias,relu,clip +i s8 s8 r n n n R 482 39 2050 2050 39 39 bias,relu,clip +i s32 r n n n R 483 39 2050 2050 39 39 none +i s8 r n n n R 483 39 2050 2050 39 39 none +i s32 s8 r n n n R 483 39 2050 2050 39 39 bias,relu,clip +i s8 s8 r n n n R 483 39 2050 2050 39 39 bias,relu,clip +i s32 r n n n R 484 39 2050 2050 39 39 none +i s8 r n n n R 484 39 2050 2050 39 39 none +i s32 s8 r n n n R 484 39 2050 2050 39 39 bias,relu,clip +i s8 s8 r n n n R 484 39 2050 2050 39 39 bias,relu,clip +i s32 r n n n p 485 39 2050 2050 39 39 none +i s8 r n n n p 485 39 2050 2050 39 39 none +i s32 s8 r n n n p 485 39 2050 2050 39 39 bias,relu,clip +i s8 s8 r n n n p 485 39 2050 2050 39 39 bias,relu,clip +i s32 r n n n p 480 50 2050 2050 50 50 none +i s8 r n n n p 480 50 2050 2050 50 50 none +i s32 s8 r n n n p 480 50 2050 2050 50 50 bias,relu,clip +i s8 s8 r n n n p 480 50 2050 2050 50 50 bias,relu,clip +i s32 r n n n p 481 50 2050 2050 50 50 none +i s8 r n n n p 481 50 2050 2050 50 50 none +i s32 s8 r n n n p 481 50 2050 2050 50 50 bias,relu,clip +i s8 s8 r n n n p 481 50 2050 2050 50 50 bias,relu,clip +i s32 r n n n p 482 50 2050 2050 50 50 none +i s8 r n n n p 482 50 2050 2050 50 50 none +i s32 s8 r n n n p 482 50 2050 2050 50 50 bias,relu,clip +i s8 s8 r n n n p 482 50 2050 2050 50 50 bias,relu,clip 
+i s32 r n n n p 483 50 2050 2050 50 50 none +i s8 r n n n p 483 50 2050 2050 50 50 none +i s32 s8 r n n n p 483 50 2050 2050 50 50 bias,relu,clip +i s8 s8 r n n n p 483 50 2050 2050 50 50 bias,relu,clip +i s32 r n n n p 484 50 2050 2050 50 50 none +i s8 r n n n p 484 50 2050 2050 50 50 none +i s32 s8 r n n n p 484 50 2050 2050 50 50 bias,relu,clip +i s8 s8 r n n n p 484 50 2050 2050 50 50 bias,relu,clip +i s32 r n n n p 485 50 2050 2050 50 50 none +i s8 r n n n p 485 50 2050 2050 50 50 none +i s32 s8 r n n n p 485 50 2050 2050 50 50 bias,relu,clip +i s8 s8 r n n n p 485 50 2050 2050 50 50 bias,relu,clip +i s32 r n n n R 480 1108 2050 2050 1108 1108 none +i s8 r n n n R 480 1108 2050 2050 1108 1108 none +i s32 s8 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip +i s8 s8 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip +i s32 r n n n R 481 1108 2050 2050 1108 1108 none +i s8 r n n n R 481 1108 2050 2050 1108 1108 none +i s32 s8 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip +i s8 s8 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip +i s32 r n n n R 482 1108 2050 2050 1108 1108 none +i s8 r n n n R 482 1108 2050 2050 1108 1108 none +i s32 s8 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip +i s8 s8 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip +i s32 r n n n R 483 1108 2050 2050 1108 1108 none +i s8 r n n n R 483 1108 2050 2050 1108 1108 none +i s32 s8 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip +i s8 s8 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip +i s32 r n n n R 484 1108 2050 2050 1108 1108 none +i s8 r n n n R 484 1108 2050 2050 1108 1108 none +i s32 s8 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip +i s8 s8 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip +i s32 r n n n R 485 1108 2050 2050 1108 1108 none +i s8 r n n n R 485 1108 2050 2050 1108 1108 none +i s32 s8 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip +i s8 s8 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip +i s32 r n n n R 480 1127 2050 2050 1127 1127 none +i s8 r n n n R 480 1127 2050 2050 1127 1127 none +i s32 s8 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip +i s8 s8 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip +i s32 r n n n R 481 1127 2050 2050 1127 1127 none +i s8 r n n n R 481 1127 2050 2050 1127 1127 none +i s32 s8 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip +i s8 s8 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip +i s32 r n n n R 482 1127 2050 2050 1127 1127 none +i s8 r n n n R 482 1127 2050 2050 1127 1127 none +i s32 s8 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip +i s8 s8 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip +i s32 r n n n R 483 1127 2050 2050 1127 1127 none +i s8 r n n n R 483 1127 2050 2050 1127 1127 none +i s32 s8 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip +i s8 s8 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip +i s32 r n n n p 484 1127 2050 2050 1127 1127 none +i s8 r n n n p 484 1127 2050 2050 1127 1127 none +i s32 s8 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip +i s8 s8 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip +i s32 r n n n p 485 1127 2050 2050 1127 1127 none +i s8 r n n n p 485 1127 2050 2050 1127 1127 none +i s32 s8 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip +i s8 s8 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip +i s32 r n n n p 480 1138 2050 2050 1138 1138 none +i s8 r n n n p 480 1138 2050 2050 1138 1138 none +i s32 s8 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip +i s8 s8 r n n n p 480 1138 2050 2050 
1138 1138 bias,relu,clip +i s32 r n n n p 481 1138 2050 2050 1138 1138 none +i s8 r n n n p 481 1138 2050 2050 1138 1138 none +i s32 s8 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip +i s8 s8 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip +i s32 r n n n p 482 1138 2050 2050 1138 1138 none +i s8 r n n n p 482 1138 2050 2050 1138 1138 none +i s32 s8 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip +i s8 s8 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip +i s32 r n n n p 483 1138 2050 2050 1138 1138 none +i s8 r n n n p 483 1138 2050 2050 1138 1138 none +i s32 s8 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip +i s8 s8 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip +i s32 r n n n p 484 1138 2050 2050 1138 1138 none +i s8 r n n n p 484 1138 2050 2050 1138 1138 none +i s32 s8 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip +i s8 s8 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip +i s32 r n n n p 485 1138 2050 2050 1138 1138 none +i s8 r n n n p 485 1138 2050 2050 1138 1138 none +i s32 s8 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip +i s8 s8 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip +i s32 r n n n p 1 1 3 3 1 1 none +i s8 r n n n p 1 1 3 3 1 1 none +i s32 s8 r n n n p 1 1 3 3 1 1 bias,relu,clip +i s8 s8 r n n n p 1 1 3 3 1 1 bias,relu,clip +i s32 r n n n p 1 9 3 3 9 9 none +i s8 r n n n p 1 9 3 3 9 9 none +i s32 s8 r n n n p 1 9 3 3 9 9 bias,relu,clip +i s8 s8 r n n n p 1 9 3 3 9 9 bias,relu,clip +i s32 r n n n p 1 2048 3 3 2048 2048 none +i s8 r n n n p 1 2048 3 3 2048 2048 none +i s32 s8 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip +i s8 s8 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip +i s32 r n n n p 1 2048 5192 5192 2048 2048 none +i s8 r n n n p 1 2048 5192 5192 2048 2048 none +i s32 s8 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip +i s8 s8 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip +i s32 r n n n p 9 1 3 3 1 1 none +i s8 r n n n p 9 1 3 3 1 1 none +i s32 s8 r n n n p 9 1 3 3 1 1 bias,relu,clip +i s8 s8 r n n n p 9 1 3 3 1 1 bias,relu,clip +i s32 r n n n p 576 1 3500 3500 1 1 none +i s8 r n n n p 576 1 3500 3500 1 1 none +i s32 s8 r n n n p 576 1 3500 3500 1 1 bias,relu,clip +i s8 s8 r n n n p 576 1 3500 3500 1 1 bias,relu,clip +i s32 r n n n p 1 1 1 1 1 1 none +i s8 r n n n p 1 1 1 1 1 1 none +i s32 s8 r n n n p 1 1 1 1 1 1 bias,relu,clip +i s8 s8 r n n n p 1 1 1 1 1 1 bias,relu,clip +i s32 r n n n p 102 1088 1024 1024 1088 1088 none +i s8 r n n n p 102 1088 1024 1024 1088 1088 none +i s32 s8 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip +i s8 s8 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip +i s32 r n n n p 102 2048 1024 1024 2048 2048 none +i s8 r n n n p 102 2048 1024 1024 2048 2048 none +i s32 s8 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip +i s8 s8 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip +i s32 r n n n p 485 656 1024 1024 656 656 none +i s8 r n n n p 485 656 1024 1024 656 656 none +i s32 s8 r n n n p 485 656 1024 1024 656 656 bias,relu,clip +i s8 s8 r n n n p 485 656 1024 1024 656 656 bias,relu,clip +i s32 r n n n p 483 656 1024 1024 656 656 none +i s8 r n n n p 483 656 1024 1024 656 656 none +i s32 s8 r n n n p 483 656 1024 1024 656 656 bias,relu,clip +i s8 s8 r n n n p 483 656 1024 1024 656 656 bias,relu,clip +i s32 r n n n p 81 128 3 3 128 128 none +i s8 r n n n p 81 128 3 3 128 128 none +i s32 s8 r n n n p 81 128 3 3 128 128 bias,relu,clip +i s8 s8 r n n n p 81 128 3 3 128 128 bias,relu,clip +i s32 r n n n p 1022 512 515 515 512 512 none +i s8 r n n n p 
1022 512 515 515 512 512 none +i s32 s8 r n n n p 1022 512 515 515 512 512 bias,relu,clip +i s8 s8 r n n n p 1022 512 515 515 512 512 bias,relu,clip +i s32 r n n n p 74 512 515 515 512 512 none +i s8 r n n n p 74 512 515 515 512 512 none +i s32 s8 r n n n p 74 512 515 515 512 512 bias,relu,clip +i s8 s8 r n n n p 74 512 515 515 512 512 bias,relu,clip +i s32 r n n n p 253 2048 515 515 2048 2048 none +i s8 r n n n p 253 2048 515 515 2048 2048 none +i s32 s8 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip +i s8 s8 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip +i s32 r n n n p 8192 1040 515 515 1040 1040 none +i s8 r n n n p 8192 1040 515 515 1040 1040 none +i s32 s8 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip +i s8 s8 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip +i s32 r n n n p 10 1029 515 515 1029 1029 none +i s8 r n n n p 10 1029 515 515 1029 1029 none +i s32 s8 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip +i s8 s8 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip +i s32 r n n n p 24 1040 2050 2050 1040 1040 none +i s8 r n n n p 24 1040 2050 2050 1040 1040 none +i s32 s8 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip +i s8 s8 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip +i s32 r n n n p 1024 1029 2050 2050 1029 1029 none +i s8 r n n n p 1024 1029 2050 2050 1029 1029 none +i s32 s8 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip +i s8 s8 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip +i s32 r n n n p 480 660 2050 2050 660 660 none +i s8 r n n n p 480 660 2050 2050 660 660 none +i s32 s8 r n n n p 480 660 2050 2050 660 660 bias,relu,clip +i s8 s8 r n n n p 480 660 2050 2050 660 660 bias,relu,clip +i s32 r n n n p 481 660 2050 2050 660 660 none +i s8 r n n n p 481 660 2050 2050 660 660 none +i s32 s8 r n n n p 481 660 2050 2050 660 660 bias,relu,clip +i s8 s8 r n n n p 481 660 2050 2050 660 660 bias,relu,clip +i s32 r n n n p 482 660 2050 2050 660 660 none +i s8 r n n n p 482 660 2050 2050 660 660 none +i s32 s8 r n n n p 482 660 2050 2050 660 660 bias,relu,clip +i s8 s8 r n n n p 482 660 2050 2050 660 660 bias,relu,clip +i s32 r n n n p 483 660 2050 2050 660 660 none +i s8 r n n n p 483 660 2050 2050 660 660 none +i s32 s8 r n n n p 483 660 2050 2050 660 660 bias,relu,clip +i s8 s8 r n n n p 483 660 2050 2050 660 660 bias,relu,clip +i s32 r n n n p 484 660 2050 2050 660 660 none +i s8 r n n n p 484 660 2050 2050 660 660 none +i s32 s8 r n n n p 484 660 2050 2050 660 660 bias,relu,clip +i s8 s8 r n n n p 484 660 2050 2050 660 660 bias,relu,clip +i s32 r n n n p 485 660 2050 2050 660 660 none +i s8 r n n n p 485 660 2050 2050 660 660 none +i s32 s8 r n n n p 485 660 2050 2050 660 660 bias,relu,clip +i s8 s8 r n n n p 485 660 2050 2050 660 660 bias,relu,clip +i s32 r n n n p 480 679 2050 2050 679 679 none +i s8 r n n n p 480 679 2050 2050 679 679 none +i s32 s8 r n n n p 480 679 2050 2050 679 679 bias,relu,clip +i s8 s8 r n n n p 480 679 2050 2050 679 679 bias,relu,clip +i s32 r n n n p 481 679 2050 2050 679 679 none +i s8 r n n n p 481 679 2050 2050 679 679 none +i s32 s8 r n n n p 481 679 2050 2050 679 679 bias,relu,clip +i s8 s8 r n n n p 481 679 2050 2050 679 679 bias,relu,clip +i s32 r n n n p 482 679 2050 2050 679 679 none +i s8 r n n n p 482 679 2050 2050 679 679 none +i s32 s8 r n n n p 482 679 2050 2050 679 679 bias,relu,clip +i s8 s8 r n n n p 482 679 2050 2050 679 679 bias,relu,clip +i s32 r n n n p 483 679 2050 2050 679 679 none +i s8 r n n n p 483 679 2050 2050 679 679 none +i s32 s8 r n n n p 483 679 2050 2050 679 679 
bias,relu,clip +i s8 s8 r n n n p 483 679 2050 2050 679 679 bias,relu,clip +i s32 r n n n p 484 679 2050 2050 679 679 none +i s8 r n n n p 484 679 2050 2050 679 679 none +i s32 s8 r n n n p 484 679 2050 2050 679 679 bias,relu,clip +i s8 s8 r n n n p 484 679 2050 2050 679 679 bias,relu,clip +i s32 r n n n p 485 679 2050 2050 679 679 none +i s8 r n n n p 485 679 2050 2050 679 679 none +i s32 s8 r n n n p 485 679 2050 2050 679 679 bias,relu,clip +i s8 s8 r n n n p 485 679 2050 2050 679 679 bias,relu,clip +i s32 r n n n p 480 690 2050 2050 690 690 none +i s8 r n n n p 480 690 2050 2050 690 690 none +i s32 s8 r n n n p 480 690 2050 2050 690 690 bias,relu,clip +i s8 s8 r n n n p 480 690 2050 2050 690 690 bias,relu,clip +i s32 r n n n p 481 690 2050 2050 690 690 none +i s8 r n n n p 481 690 2050 2050 690 690 none +i s32 s8 r n n n p 481 690 2050 2050 690 690 bias,relu,clip +i s8 s8 r n n n p 481 690 2050 2050 690 690 bias,relu,clip +i s32 r n n n p 482 690 2050 2050 690 690 none +i s8 r n n n p 482 690 2050 2050 690 690 none +i s32 s8 r n n n p 482 690 2050 2050 690 690 bias,relu,clip +i s8 s8 r n n n p 482 690 2050 2050 690 690 bias,relu,clip +i s32 r n n n p 483 690 2050 2050 690 690 none +i s8 r n n n p 483 690 2050 2050 690 690 none +i s32 s8 r n n n p 483 690 2050 2050 690 690 bias,relu,clip +i s8 s8 r n n n p 483 690 2050 2050 690 690 bias,relu,clip +i s32 r n n n p 484 690 2050 2050 690 690 none +i s8 r n n n p 484 690 2050 2050 690 690 none +i s32 s8 r n n n p 484 690 2050 2050 690 690 bias,relu,clip +i s8 s8 r n n n p 484 690 2050 2050 690 690 bias,relu,clip +i s32 r n n n p 485 690 2050 2050 690 690 none +i s8 r n n n p 485 690 2050 2050 690 690 none +i s32 s8 r n n n p 485 690 2050 2050 690 690 bias,relu,clip +i s8 s8 r n n n p 485 690 2050 2050 690 690 bias,relu,clip +i s32 r n n n p 480 660 2048 2048 660 660 none +i s8 r n n n p 480 660 2048 2048 660 660 none +i s32 s8 r n n n p 480 660 2048 2048 660 660 bias,relu,clip +i s8 s8 r n n n p 480 660 2048 2048 660 660 bias,relu,clip +i s32 r n n n p 481 660 2048 2048 660 660 none +i s8 r n n n p 481 660 2048 2048 660 660 none +i s32 s8 r n n n p 481 660 2048 2048 660 660 bias,relu,clip +i s8 s8 r n n n p 481 660 2048 2048 660 660 bias,relu,clip +i s32 r n n n p 482 660 2048 2048 660 660 none +i s8 r n n n p 482 660 2048 2048 660 660 none +i s32 s8 r n n n p 482 660 2048 2048 660 660 bias,relu,clip +i s8 s8 r n n n p 482 660 2048 2048 660 660 bias,relu,clip +i s32 r n n n p 483 660 2048 2048 660 660 none +i s8 r n n n p 483 660 2048 2048 660 660 none +i s32 s8 r n n n p 483 660 2048 2048 660 660 bias,relu,clip +i s8 s8 r n n n p 483 660 2048 2048 660 660 bias,relu,clip +i s32 r n n n p 484 660 2048 2048 660 660 none +i s8 r n n n p 484 660 2048 2048 660 660 none +i s32 s8 r n n n p 484 660 2048 2048 660 660 bias,relu,clip +i s8 s8 r n n n p 484 660 2048 2048 660 660 bias,relu,clip +i s32 r n n n p 485 660 2048 2048 660 660 none +i s8 r n n n p 485 660 2048 2048 660 660 none +i s32 s8 r n n n p 485 660 2048 2048 660 660 bias,relu,clip +i s8 s8 r n n n p 485 660 2048 2048 660 660 bias,relu,clip +i s32 r n n n p 480 679 2048 2048 679 679 none +i s8 r n n n p 480 679 2048 2048 679 679 none +i s32 s8 r n n n p 480 679 2048 2048 679 679 bias,relu,clip +i s8 s8 r n n n p 480 679 2048 2048 679 679 bias,relu,clip +i s32 r n n n p 481 679 2048 2048 679 679 none +i s8 r n n n p 481 679 2048 2048 679 679 none +i s32 s8 r n n n p 481 679 2048 2048 679 679 bias,relu,clip +i s8 s8 r n n n p 481 679 2048 2048 679 679 bias,relu,clip +i s32 r n n n p 482 
679 2048 2048 679 679 none +i s8 r n n n p 482 679 2048 2048 679 679 none +i s32 s8 r n n n p 482 679 2048 2048 679 679 bias,relu,clip +i s8 s8 r n n n p 482 679 2048 2048 679 679 bias,relu,clip +i s32 r n n n p 483 679 2048 2048 679 679 none +i s8 r n n n p 483 679 2048 2048 679 679 none +i s32 s8 r n n n p 483 679 2048 2048 679 679 bias,relu,clip +i s8 s8 r n n n p 483 679 2048 2048 679 679 bias,relu,clip +i s32 r n n n p 484 679 2048 2048 679 679 none +i s8 r n n n p 484 679 2048 2048 679 679 none +i s32 s8 r n n n p 484 679 2048 2048 679 679 bias,relu,clip +i s8 s8 r n n n p 484 679 2048 2048 679 679 bias,relu,clip +i s32 r n n n p 485 679 2048 2048 679 679 none +i s8 r n n n p 485 679 2048 2048 679 679 none +i s32 s8 r n n n p 485 679 2048 2048 679 679 bias,relu,clip +i s8 s8 r n n n p 485 679 2048 2048 679 679 bias,relu,clip +i s32 r n n n p 480 690 2048 2048 690 690 none +i s8 r n n n p 480 690 2048 2048 690 690 none +i s32 s8 r n n n p 480 690 2048 2048 690 690 bias,relu,clip +i s8 s8 r n n n p 480 690 2048 2048 690 690 bias,relu,clip +i s32 r n n n p 481 690 2048 2048 690 690 none +i s8 r n n n p 481 690 2048 2048 690 690 none +i s32 s8 r n n n p 481 690 2048 2048 690 690 bias,relu,clip +i s8 s8 r n n n p 481 690 2048 2048 690 690 bias,relu,clip +i s32 r n n n p 482 690 2048 2048 690 690 none +i s8 r n n n p 482 690 2048 2048 690 690 none +i s32 s8 r n n n p 482 690 2048 2048 690 690 bias,relu,clip +i s8 s8 r n n n p 482 690 2048 2048 690 690 bias,relu,clip +i s32 r n n n p 483 690 2048 2048 690 690 none +i s8 r n n n p 483 690 2048 2048 690 690 none +i s32 s8 r n n n p 483 690 2048 2048 690 690 bias,relu,clip +i s8 s8 r n n n p 483 690 2048 2048 690 690 bias,relu,clip +i s32 r n n n p 484 690 2048 2048 690 690 none +i s8 r n n n p 484 690 2048 2048 690 690 none +i s32 s8 r n n n p 484 690 2048 2048 690 690 bias,relu,clip +i s8 s8 r n n n p 484 690 2048 2048 690 690 bias,relu,clip +i s32 r n n n p 485 690 2048 2048 690 690 none +i s8 r n n n p 485 690 2048 2048 690 690 none +i s32 s8 r n n n p 485 690 2048 2048 690 690 bias,relu,clip +i s8 s8 r n n n p 485 690 2048 2048 690 690 bias,relu,clip +i s32 r n n n p 480 656 1024 1024 656 656 none +i s8 r n n n p 480 656 1024 1024 656 656 none +i s32 s8 r n n n p 480 656 1024 1024 656 656 bias,relu,clip +i s8 s8 r n n n p 480 656 1024 1024 656 656 bias,relu,clip +i s32 r n n n p 480 128 3 3 128 128 none +i s8 r n n n p 480 128 3 3 128 128 none +i s32 s8 r n n n p 480 128 3 3 128 128 bias,relu,clip +i s8 s8 r n n n p 480 128 3 3 128 128 bias,relu,clip +i s32 r n n n p 1024 512 515 515 512 512 none +i s8 r n n n p 1024 512 515 515 512 512 none +i s32 s8 r n n n p 1024 512 515 515 512 512 bias,relu,clip +i s8 s8 r n n n p 1024 512 515 515 512 512 bias,relu,clip +i s32 r n n n p 1024 2048 1024 1024 2048 2048 none +i s8 r n n n p 1024 2048 1024 1024 2048 2048 none +i s32 s8 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip +i s8 s8 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip +i s32 r n n n p 1024 2048 515 515 2048 2048 none +i s8 r n n n p 1024 2048 515 515 2048 2048 none +i s32 s8 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip +i s8 s8 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip +i s32 r n n n p 1024 1040 515 515 1040 1040 none +i s8 r n n n p 1024 1040 515 515 1040 1040 none +i s32 s8 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip +i s8 s8 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip +i s32 r n n n p 5 1029 515 515 1029 1029 none +i s8 r n n n p 5 1029 515 515 1029 1029 none +i s32 s8 r n n n p 5 
1029 515 515 1029 1029 bias,relu,clip +i s8 s8 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip +i s32 r n n n p 1024 1029 515 515 1029 1029 none +i s8 r n n n p 1024 1029 515 515 1029 1029 none +i s32 s8 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip +i s8 s8 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip +i s32 r n n n p 1024 1040 2050 2050 1040 1040 none +i s8 r n n n p 1024 1040 2050 2050 1040 1040 none +i s32 s8 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip +i s8 s8 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip +i s32 r n n n p 1029 1029 2050 2050 1029 1029 none +i s8 r n n n p 1029 1029 2050 2050 1029 1029 none +i s32 s8 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip +i s8 s8 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip +i s32 r n n n R 480 646 2050 2050 646 646 none +i s8 r n n n R 480 646 2050 2050 646 646 none +i s32 s8 r n n n R 480 646 2050 2050 646 646 bias,relu,clip +i s8 s8 r n n n R 480 646 2050 2050 646 646 bias,relu,clip +i s32 r n n n R 481 646 2050 2050 646 646 none +i s8 r n n n R 481 646 2050 2050 646 646 none +i s32 s8 r n n n R 481 646 2050 2050 646 646 bias,relu,clip +i s8 s8 r n n n R 481 646 2050 2050 646 646 bias,relu,clip +i s32 r n n n R 482 646 2050 2050 646 646 none +i s8 r n n n R 482 646 2050 2050 646 646 none +i s32 s8 r n n n R 482 646 2050 2050 646 646 bias,relu,clip +i s8 s8 r n n n R 482 646 2050 2050 646 646 bias,relu,clip +i s32 r n n n R 483 646 2050 2050 646 646 none +i s8 r n n n R 483 646 2050 2050 646 646 none +i s32 s8 r n n n R 483 646 2050 2050 646 646 bias,relu,clip +i s8 s8 r n n n R 483 646 2050 2050 646 646 bias,relu,clip +i s32 r n n n R 484 646 2050 2050 646 646 none +i s8 r n n n R 484 646 2050 2050 646 646 none +i s32 s8 r n n n R 484 646 2050 2050 646 646 bias,relu,clip +i s8 s8 r n n n R 484 646 2050 2050 646 646 bias,relu,clip +i s32 r n n n R 485 646 2050 2050 646 646 none +i s8 r n n n R 485 646 2050 2050 646 646 none +i s32 s8 r n n n R 485 646 2050 2050 646 646 bias,relu,clip +i s8 s8 r n n n R 485 646 2050 2050 646 646 bias,relu,clip +i s32 r n n n R 481 656 2050 2050 656 656 none +i s8 r n n n R 481 656 2050 2050 656 656 none +i s32 s8 r n n n R 481 656 2050 2050 656 656 bias,relu,clip +i s8 s8 r n n n R 481 656 2050 2050 656 656 bias,relu,clip +i s32 r n n n R 482 656 2050 2050 656 656 none +i s8 r n n n R 482 656 2050 2050 656 656 none +i s32 s8 r n n n R 482 656 2050 2050 656 656 bias,relu,clip +i s8 s8 r n n n R 482 656 2050 2050 656 656 bias,relu,clip +i s32 r n n n R 483 656 2050 2050 656 656 none +i s8 r n n n R 483 656 2050 2050 656 656 none +i s32 s8 r n n n R 483 656 2050 2050 656 656 bias,relu,clip +i s8 s8 r n n n R 483 656 2050 2050 656 656 bias,relu,clip +i s32 r n n n R 484 656 2050 2050 656 656 none +i s8 r n n n R 484 656 2050 2050 656 656 none +i s32 s8 r n n n R 484 656 2050 2050 656 656 bias,relu,clip +i s8 s8 r n n n R 484 656 2050 2050 656 656 bias,relu,clip +i s32 r n n n p 485 656 2050 2050 656 656 none +i s8 r n n n p 485 656 2050 2050 656 656 none +i s32 s8 r n n n p 485 656 2050 2050 656 656 bias,relu,clip +i s8 s8 r n n n p 485 656 2050 2050 656 656 bias,relu,clip +i s32 r n n n p 480 672 2050 2050 672 672 none +i s8 r n n n p 480 672 2050 2050 672 672 none +i s32 s8 r n n n p 480 672 2050 2050 672 672 bias,relu,clip +i s8 s8 r n n n p 480 672 2050 2050 672 672 bias,relu,clip +i s32 r n n n p 481 672 2050 2050 672 672 none +i s8 r n n n p 481 672 2050 2050 672 672 none +i s32 s8 r n n n p 481 672 2050 2050 672 672 bias,relu,clip +i s8 s8 r n n n p 
481 672 2050 2050 672 672 bias,relu,clip +i s32 r n n n p 482 672 2050 2050 672 672 none +i s8 r n n n p 482 672 2050 2050 672 672 none +i s32 s8 r n n n p 482 672 2050 2050 672 672 bias,relu,clip +i s8 s8 r n n n p 482 672 2050 2050 672 672 bias,relu,clip +i s32 r n n n p 483 672 2050 2050 672 672 none +i s8 r n n n p 483 672 2050 2050 672 672 none +i s32 s8 r n n n p 483 672 2050 2050 672 672 bias,relu,clip +i s8 s8 r n n n p 483 672 2050 2050 672 672 bias,relu,clip +i s32 r n n n p 484 672 2050 2050 672 672 none +i s8 r n n n p 484 672 2050 2050 672 672 none +i s32 s8 r n n n p 484 672 2050 2050 672 672 bias,relu,clip +i s8 s8 r n n n p 484 672 2050 2050 672 672 bias,relu,clip +i s32 r n n n p 485 672 2050 2050 672 672 none +i s8 r n n n p 485 672 2050 2050 672 672 none +i s32 s8 r n n n p 485 672 2050 2050 672 672 bias,relu,clip +i s8 s8 r n n n p 485 672 2050 2050 672 672 bias,relu,clip +i s32 r n n n p 480 688 2050 2050 688 688 none +i s8 r n n n p 480 688 2050 2050 688 688 none +i s32 s8 r n n n p 480 688 2050 2050 688 688 bias,relu,clip +i s8 s8 r n n n p 480 688 2050 2050 688 688 bias,relu,clip +i s32 r n n n p 481 688 2050 2050 688 688 none +i s8 r n n n p 481 688 2050 2050 688 688 none +i s32 s8 r n n n p 481 688 2050 2050 688 688 bias,relu,clip +i s8 s8 r n n n p 481 688 2050 2050 688 688 bias,relu,clip +i s32 r n n n r 482 688 2050 2050 688 688 none +i s8 r n n n r 482 688 2050 2050 688 688 none +i s32 s8 r n n n r 482 688 2050 2050 688 688 bias,relu,clip +i s8 s8 r n n n r 482 688 2050 2050 688 688 bias,relu,clip +i s32 r n n n r 483 688 2050 2050 688 688 none +i s8 r n n n r 483 688 2050 2050 688 688 none +i s32 s8 r n n n r 483 688 2050 2050 688 688 bias,relu,clip +i s8 s8 r n n n r 483 688 2050 2050 688 688 bias,relu,clip +i s32 r n n n r 484 688 2050 2050 688 688 none +i s8 r n n n r 484 688 2050 2050 688 688 none +i s32 s8 r n n n r 484 688 2050 2050 688 688 bias,relu,clip +i s8 s8 r n n n r 484 688 2050 2050 688 688 bias,relu,clip +i s32 r n n n r 485 688 2050 2050 688 688 none +i s8 r n n n r 485 688 2050 2050 688 688 none +i s32 s8 r n n n r 485 688 2050 2050 688 688 bias,relu,clip +i s8 s8 r n n n r 485 688 2050 2050 688 688 bias,relu,clip +i s32 r n n n r 1024 512 64 64 512 512 none +i s8 r n n n r 1024 512 64 64 512 512 none +i s32 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +i s8 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +i s32 r n n n r 16 256 512 512 256 256 none +i s8 r n n n r 16 256 512 512 256 256 none +i s32 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip +i s8 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip +i s32 r n n n r 480 640 512 512 640 640 none +i s8 r n n n r 480 640 512 512 640 640 none +i s32 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip +i s8 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip +i s32 r n n n r 64 768 512 512 768 768 none +i s8 r n n n r 64 768 512 512 768 768 none +i s32 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip +i s8 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip +i s32 r n n n r 128 128 128 128 128 128 none +i s8 r n n n r 128 128 128 128 128 128 none +i s32 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip +i s8 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip +i s32 r n n n r 1024 64 512 512 64 64 none +i s8 r n n n r 1024 64 512 512 64 64 none +i s32 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip +i s8 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip +i s32 r n n n r 1024 256 32 32 256 256 none +i s8 r n n n r 1024 256 32 32 256 256 none +i s32 s8 r n n n r 1024 256 32 32 
256 256 bias,relu,clip +i s8 s8 r n n n r 1024 256 32 32 256 256 bias,relu,clip +i s32 r n n n r 1024 512 64 64 512 512 none +i s8 r n n n r 1024 512 64 64 512 512 none +i s32 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +i s8 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +i s32 r n n n r 480 640 512 512 640 640 none +i s8 r n n n r 480 640 512 512 640 640 none +i s32 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip +i s8 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip +i s32 r n n n p 1024 32 256 256 32 32 none +i s8 r n n n p 1024 32 256 256 32 32 none +i s32 s8 r n n n p 1024 32 256 256 32 32 bias,relu,clip +i s8 s8 r n n n p 1024 32 256 256 32 32 bias,relu,clip +i s32 r n n n P 1024 64 512 512 64 64 none +i s8 r n n n P 1024 64 512 512 64 64 none +i s32 s8 r n n n P 1024 64 512 512 64 64 bias,relu,clip +i s8 s8 r n n n P 1024 64 512 512 64 64 bias,relu,clip +i s32 r n n n P 64 800 320 320 800 800 none +i s8 r n n n P 64 800 320 320 800 800 none +i s32 s8 r n n n P 64 800 320 320 800 800 bias,relu,clip +i s8 s8 r n n n P 64 800 320 320 800 800 bias,relu,clip +i s32 r n n n P 64 768 512 512 768 768 none +i s8 r n n n P 64 768 512 512 768 768 none +i s32 s8 r n n n P 64 768 512 512 768 768 bias,relu,clip +i s8 s8 r n n n P 64 768 512 512 768 768 bias,relu,clip +i s32 r n n n P 16 256 512 512 256 256 none +i s8 r n n n P 16 256 512 512 256 256 none +i s32 s8 r n n n P 16 256 512 512 256 256 bias,relu,clip +i s8 s8 r n n n P 16 256 512 512 256 256 bias,relu,clip +i s32 r n n n P 128 128 128 128 128 128 none +i s8 r n n n P 128 128 128 128 128 128 none +i s32 s8 r n n n P 128 128 128 128 128 128 bias,relu,clip +i s8 s8 r n n n P 128 128 128 128 128 128 bias,relu,clip +i s32 r n n n P 256 512 256 256 512 512 none +i s8 r n n n P 256 512 256 256 512 512 none +i s32 s8 r n n n P 256 512 256 256 512 512 bias,relu,clip +i s8 s8 r n n n P 256 512 256 256 512 512 bias,relu,clip +i s32 r n n n P 1024 1024 1024 1024 1024 1024 none +i s8 r n n n P 1024 1024 1024 1024 1024 1024 none +i s32 s8 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip +i s8 s8 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip +i s32 r n n n P 480 640 1024 1024 640 640 none +i s8 r n n n P 480 640 1024 1024 640 640 none +i s32 s8 r n n n P 480 640 1024 1024 640 640 bias,relu,clip +i s8 s8 r n n n P 480 640 1024 1024 640 640 bias,relu,clip +i s32 r n n n P 480 640 256 256 640 640 none +i s8 r n n n P 480 640 256 256 640 640 none +i s32 s8 r n n n P 480 640 256 256 640 640 bias,relu,clip +i s8 s8 r n n n P 480 640 256 256 640 640 bias,relu,clip +i s32 r n n n P 8 64 32 32 64 64 none +i s8 r n n n P 8 64 32 32 64 64 none +i s32 s8 r n n n P 8 64 32 32 64 64 bias,relu,clip +i s8 s8 r n n n P 8 64 32 32 64 64 bias,relu,clip +i s32 r n n n P 9 64 32 32 64 64 none +i s8 r n n n P 9 64 32 32 64 64 none +i s32 s8 r n n n P 9 64 32 32 64 64 bias,relu,clip +i s8 s8 r n n n P 9 64 32 32 64 64 bias,relu,clip +i s32 r n n n P 10 128 64 64 128 128 none +i s8 r n n n P 10 128 64 64 128 128 none +i s32 s8 r n n n P 10 128 64 64 128 128 bias,relu,clip +i s8 s8 r n n n P 10 128 64 64 128 128 bias,relu,clip +i s32 r n n n P 8 8 8 8 8 8 none +i s8 r n n n P 8 8 8 8 8 8 none +i s32 s8 r n n n P 8 8 8 8 8 8 bias,relu,clip +i s8 s8 r n n n P 8 8 8 8 8 8 bias,relu,clip +i s32 r n n n P 12 12 12 12 12 12 none +i s8 r n n n P 12 12 12 12 12 12 none +i s32 s8 r n n n P 12 12 12 12 12 12 bias,relu,clip +i s8 s8 r n n n P 12 12 12 12 12 12 bias,relu,clip +i s32 r n n n P 25 25 25 25 25 25 none +i s8 r n n n P 25 25 25 25 25 25 none 
+i s32 s8 r n n n P 25 25 25 25 25 25 bias,relu,clip +i s8 s8 r n n n P 25 25 25 25 25 25 bias,relu,clip +i s32 r n n n P 25 25 20 20 25 25 none +i s8 r n n n P 25 25 20 20 25 25 none +i s32 s8 r n n n P 25 25 20 20 25 25 bias,relu,clip +i s8 s8 r n n n P 25 25 20 20 25 25 bias,relu,clip +f f32 r n n n p 480 20 2050 2050 20 20 none +f f32 f32 r n n n p 480 20 2050 2050 20 20 bias,relu,clip +f f32 r n n n p 481 20 2050 2050 20 20 none +f f32 f32 r n n n p 481 20 2050 2050 20 20 bias,relu,clip +f f32 r n n n p 482 20 2050 2050 20 20 none +f f32 f32 r n n n p 482 20 2050 2050 20 20 bias,relu,clip +f f32 r n n n p 483 20 2050 2050 20 20 none +f f32 f32 r n n n p 483 20 2050 2050 20 20 bias,relu,clip +f f32 r n n n R 484 20 2050 2050 20 20 none +f f32 f32 r n n n R 484 20 2050 2050 20 20 bias,relu,clip +f f32 r n n n R 485 20 2050 2050 20 20 none +f f32 f32 r n n n R 485 20 2050 2050 20 20 bias,relu,clip +f f32 r n n n R 480 39 2050 2050 39 39 none +f f32 f32 r n n n R 480 39 2050 2050 39 39 bias,relu,clip +f f32 r n n n R 481 39 2050 2050 39 39 none +f f32 f32 r n n n R 481 39 2050 2050 39 39 bias,relu,clip +f f32 r n n n R 482 39 2050 2050 39 39 none +f f32 f32 r n n n R 482 39 2050 2050 39 39 bias,relu,clip +f f32 r n n n R 483 39 2050 2050 39 39 none +f f32 f32 r n n n R 483 39 2050 2050 39 39 bias,relu,clip +f f32 r n n n R 484 39 2050 2050 39 39 none +f f32 f32 r n n n R 484 39 2050 2050 39 39 bias,relu,clip +f f32 r n n n p 485 39 2050 2050 39 39 none +f f32 f32 r n n n p 485 39 2050 2050 39 39 bias,relu,clip +f f32 r n n n p 480 50 2050 2050 50 50 none +f f32 f32 r n n n p 480 50 2050 2050 50 50 bias,relu,clip +f f32 r n n n p 481 50 2050 2050 50 50 none +f f32 f32 r n n n p 481 50 2050 2050 50 50 bias,relu,clip +f f32 r n n n p 482 50 2050 2050 50 50 none +f f32 f32 r n n n p 482 50 2050 2050 50 50 bias,relu,clip +f f32 r n n n p 483 50 2050 2050 50 50 none +f f32 f32 r n n n p 483 50 2050 2050 50 50 bias,relu,clip +f f32 r n n n p 484 50 2050 2050 50 50 none +f f32 f32 r n n n p 484 50 2050 2050 50 50 bias,relu,clip +f f32 r n n n p 485 50 2050 2050 50 50 none +f f32 f32 r n n n p 485 50 2050 2050 50 50 bias,relu,clip +f f32 r n n n R 480 1108 2050 2050 1108 1108 none +f f32 f32 r n n n R 480 1108 2050 2050 1108 1108 bias,relu,clip +f f32 r n n n R 481 1108 2050 2050 1108 1108 none +f f32 f32 r n n n R 481 1108 2050 2050 1108 1108 bias,relu,clip +f f32 r n n n R 482 1108 2050 2050 1108 1108 none +f f32 f32 r n n n R 482 1108 2050 2050 1108 1108 bias,relu,clip +f f32 r n n n R 483 1108 2050 2050 1108 1108 none +f f32 f32 r n n n R 483 1108 2050 2050 1108 1108 bias,relu,clip +f f32 r n n n R 484 1108 2050 2050 1108 1108 none +f f32 f32 r n n n R 484 1108 2050 2050 1108 1108 bias,relu,clip +f f32 r n n n R 485 1108 2050 2050 1108 1108 none +f f32 f32 r n n n R 485 1108 2050 2050 1108 1108 bias,relu,clip +f f32 r n n n R 480 1127 2050 2050 1127 1127 none +f f32 f32 r n n n R 480 1127 2050 2050 1127 1127 bias,relu,clip +f f32 r n n n R 481 1127 2050 2050 1127 1127 none +f f32 f32 r n n n R 481 1127 2050 2050 1127 1127 bias,relu,clip +f f32 r n n n R 482 1127 2050 2050 1127 1127 none +f f32 f32 r n n n R 482 1127 2050 2050 1127 1127 bias,relu,clip +f f32 r n n n R 483 1127 2050 2050 1127 1127 none +f f32 f32 r n n n R 483 1127 2050 2050 1127 1127 bias,relu,clip +f f32 r n n n p 484 1127 2050 2050 1127 1127 none +f f32 f32 r n n n p 484 1127 2050 2050 1127 1127 bias,relu,clip +f f32 r n n n p 485 1127 2050 2050 1127 1127 none +f f32 f32 r n n n p 485 1127 2050 2050 1127 1127 bias,relu,clip 
+f f32 r n n n p 480 1138 2050 2050 1138 1138 none +f f32 f32 r n n n p 480 1138 2050 2050 1138 1138 bias,relu,clip +f f32 r n n n p 481 1138 2050 2050 1138 1138 none +f f32 f32 r n n n p 481 1138 2050 2050 1138 1138 bias,relu,clip +f f32 r n n n p 482 1138 2050 2050 1138 1138 none +f f32 f32 r n n n p 482 1138 2050 2050 1138 1138 bias,relu,clip +f f32 r n n n p 483 1138 2050 2050 1138 1138 none +f f32 f32 r n n n p 483 1138 2050 2050 1138 1138 bias,relu,clip +f f32 r n n n p 484 1138 2050 2050 1138 1138 none +f f32 f32 r n n n p 484 1138 2050 2050 1138 1138 bias,relu,clip +f f32 r n n n p 485 1138 2050 2050 1138 1138 none +f f32 f32 r n n n p 485 1138 2050 2050 1138 1138 bias,relu,clip +f f32 r n n n p 1 1 3 3 1 1 none +f f32 f32 r n n n p 1 1 3 3 1 1 bias,relu,clip +f f32 r n n n p 1 9 3 3 9 9 none +f f32 f32 r n n n p 1 9 3 3 9 9 bias,relu,clip +f f32 r n n n p 1 2048 3 3 2048 2048 none +f f32 f32 r n n n p 1 2048 3 3 2048 2048 bias,relu,clip +f f32 r n n n p 1 2048 5192 5192 2048 2048 none +f f32 f32 r n n n p 1 2048 5192 5192 2048 2048 bias,relu,clip +f f32 r n n n p 9 1 3 3 1 1 none +f f32 f32 r n n n p 9 1 3 3 1 1 bias,relu,clip +f f32 r n n n p 576 1 3500 3500 1 1 none +f f32 f32 r n n n p 576 1 3500 3500 1 1 bias,relu,clip +f f32 r n n n p 1 1 1 1 1 1 none +f f32 f32 r n n n p 1 1 1 1 1 1 bias,relu,clip +f f32 r n n n p 102 1088 1024 1024 1088 1088 none +f f32 f32 r n n n p 102 1088 1024 1024 1088 1088 bias,relu,clip +f f32 r n n n p 102 2048 1024 1024 2048 2048 none +f f32 f32 r n n n p 102 2048 1024 1024 2048 2048 bias,relu,clip +f f32 r n n n p 485 656 1024 1024 656 656 none +f f32 f32 r n n n p 485 656 1024 1024 656 656 bias,relu,clip +f f32 r n n n p 483 656 1024 1024 656 656 none +f f32 f32 r n n n p 483 656 1024 1024 656 656 bias,relu,clip +f f32 r n n n p 81 128 3 3 128 128 none +f f32 f32 r n n n p 81 128 3 3 128 128 bias,relu,clip +f f32 r n n n p 1022 512 515 515 512 512 none +f f32 f32 r n n n p 1022 512 515 515 512 512 bias,relu,clip +f f32 r n n n p 74 512 515 515 512 512 none +f f32 f32 r n n n p 74 512 515 515 512 512 bias,relu,clip +f f32 r n n n p 253 2048 515 515 2048 2048 none +f f32 f32 r n n n p 253 2048 515 515 2048 2048 bias,relu,clip +f f32 r n n n p 8192 1040 515 515 1040 1040 none +f f32 f32 r n n n p 8192 1040 515 515 1040 1040 bias,relu,clip +f f32 r n n n p 10 1029 515 515 1029 1029 none +f f32 f32 r n n n p 10 1029 515 515 1029 1029 bias,relu,clip +f f32 r n n n p 24 1040 2050 2050 1040 1040 none +f f32 f32 r n n n p 24 1040 2050 2050 1040 1040 bias,relu,clip +f f32 r n n n p 1024 1029 2050 2050 1029 1029 none +f f32 f32 r n n n p 1024 1029 2050 2050 1029 1029 bias,relu,clip +f f32 r n n n p 480 660 2050 2050 660 660 none +f f32 f32 r n n n p 480 660 2050 2050 660 660 bias,relu,clip +f f32 r n n n p 481 660 2050 2050 660 660 none +f f32 f32 r n n n p 481 660 2050 2050 660 660 bias,relu,clip +f f32 r n n n p 482 660 2050 2050 660 660 none +f f32 f32 r n n n p 482 660 2050 2050 660 660 bias,relu,clip +f f32 r n n n p 483 660 2050 2050 660 660 none +f f32 f32 r n n n p 483 660 2050 2050 660 660 bias,relu,clip +f f32 r n n n p 484 660 2050 2050 660 660 none +f f32 f32 r n n n p 484 660 2050 2050 660 660 bias,relu,clip +f f32 r n n n p 485 660 2050 2050 660 660 none +f f32 f32 r n n n p 485 660 2050 2050 660 660 bias,relu,clip +f f32 r n n n p 480 679 2050 2050 679 679 none +f f32 f32 r n n n p 480 679 2050 2050 679 679 bias,relu,clip +f f32 r n n n p 481 679 2050 2050 679 679 none +f f32 f32 r n n n p 481 679 2050 2050 679 679 bias,relu,clip +f f32 r n n 
n p 482 679 2050 2050 679 679 none +f f32 f32 r n n n p 482 679 2050 2050 679 679 bias,relu,clip +f f32 r n n n p 483 679 2050 2050 679 679 none +f f32 f32 r n n n p 483 679 2050 2050 679 679 bias,relu,clip +f f32 r n n n p 484 679 2050 2050 679 679 none +f f32 f32 r n n n p 484 679 2050 2050 679 679 bias,relu,clip +f f32 r n n n p 485 679 2050 2050 679 679 none +f f32 f32 r n n n p 485 679 2050 2050 679 679 bias,relu,clip +f f32 r n n n p 480 690 2050 2050 690 690 none +f f32 f32 r n n n p 480 690 2050 2050 690 690 bias,relu,clip +f f32 r n n n p 481 690 2050 2050 690 690 none +f f32 f32 r n n n p 481 690 2050 2050 690 690 bias,relu,clip +f f32 r n n n p 482 690 2050 2050 690 690 none +f f32 f32 r n n n p 482 690 2050 2050 690 690 bias,relu,clip +f f32 r n n n p 483 690 2050 2050 690 690 none +f f32 f32 r n n n p 483 690 2050 2050 690 690 bias,relu,clip +f f32 r n n n p 484 690 2050 2050 690 690 none +f f32 f32 r n n n p 484 690 2050 2050 690 690 bias,relu,clip +f f32 r n n n p 485 690 2050 2050 690 690 none +f f32 f32 r n n n p 485 690 2050 2050 690 690 bias,relu,clip +f f32 r n n n p 480 660 2048 2048 660 660 none +f f32 f32 r n n n p 480 660 2048 2048 660 660 bias,relu,clip +f f32 r n n n p 481 660 2048 2048 660 660 none +f f32 f32 r n n n p 481 660 2048 2048 660 660 bias,relu,clip +f f32 r n n n p 482 660 2048 2048 660 660 none +f f32 f32 r n n n p 482 660 2048 2048 660 660 bias,relu,clip +f f32 r n n n p 483 660 2048 2048 660 660 none +f f32 f32 r n n n p 483 660 2048 2048 660 660 bias,relu,clip +f f32 r n n n p 484 660 2048 2048 660 660 none +f f32 f32 r n n n p 484 660 2048 2048 660 660 bias,relu,clip +f f32 r n n n p 485 660 2048 2048 660 660 none +f f32 f32 r n n n p 485 660 2048 2048 660 660 bias,relu,clip +f f32 r n n n p 480 679 2048 2048 679 679 none +f f32 f32 r n n n p 480 679 2048 2048 679 679 bias,relu,clip +f f32 r n n n p 481 679 2048 2048 679 679 none +f f32 f32 r n n n p 481 679 2048 2048 679 679 bias,relu,clip +f f32 r n n n p 482 679 2048 2048 679 679 none +f f32 f32 r n n n p 482 679 2048 2048 679 679 bias,relu,clip +f f32 r n n n p 483 679 2048 2048 679 679 none +f f32 f32 r n n n p 483 679 2048 2048 679 679 bias,relu,clip +f f32 r n n n p 484 679 2048 2048 679 679 none +f f32 f32 r n n n p 484 679 2048 2048 679 679 bias,relu,clip +f f32 r n n n p 485 679 2048 2048 679 679 none +f f32 f32 r n n n p 485 679 2048 2048 679 679 bias,relu,clip +f f32 r n n n p 480 690 2048 2048 690 690 none +f f32 f32 r n n n p 480 690 2048 2048 690 690 bias,relu,clip +f f32 r n n n p 481 690 2048 2048 690 690 none +f f32 f32 r n n n p 481 690 2048 2048 690 690 bias,relu,clip +f f32 r n n n p 482 690 2048 2048 690 690 none +f f32 f32 r n n n p 482 690 2048 2048 690 690 bias,relu,clip +f f32 r n n n p 483 690 2048 2048 690 690 none +f f32 f32 r n n n p 483 690 2048 2048 690 690 bias,relu,clip +f f32 r n n n p 484 690 2048 2048 690 690 none +f f32 f32 r n n n p 484 690 2048 2048 690 690 bias,relu,clip +f f32 r n n n p 485 690 2048 2048 690 690 none +f f32 f32 r n n n p 485 690 2048 2048 690 690 bias,relu,clip +f f32 r n n n p 480 656 1024 1024 656 656 none +f f32 f32 r n n n p 480 656 1024 1024 656 656 bias,relu,clip +f f32 r n n n p 480 128 3 3 128 128 none +f f32 f32 r n n n p 480 128 3 3 128 128 bias,relu,clip +f f32 r n n n p 1024 512 515 515 512 512 none +f f32 f32 r n n n p 1024 512 515 515 512 512 bias,relu,clip +f f32 r n n n p 1024 2048 1024 1024 2048 2048 none +f f32 f32 r n n n p 1024 2048 1024 1024 2048 2048 bias,relu,clip +f f32 r n n n p 1024 2048 515 515 2048 2048 none +f 
f32 f32 r n n n p 1024 2048 515 515 2048 2048 bias,relu,clip +f f32 r n n n p 1024 1040 515 515 1040 1040 none +f f32 f32 r n n n p 1024 1040 515 515 1040 1040 bias,relu,clip +f f32 r n n n p 5 1029 515 515 1029 1029 none +f f32 f32 r n n n p 5 1029 515 515 1029 1029 bias,relu,clip +f f32 r n n n p 1024 1029 515 515 1029 1029 none +f f32 f32 r n n n p 1024 1029 515 515 1029 1029 bias,relu,clip +f f32 r n n n p 1024 1040 2050 2050 1040 1040 none +f f32 f32 r n n n p 1024 1040 2050 2050 1040 1040 bias,relu,clip +f f32 r n n n p 1029 1029 2050 2050 1029 1029 none +f f32 f32 r n n n p 1029 1029 2050 2050 1029 1029 bias,relu,clip +f f32 r n n n R 480 646 2050 2050 646 646 none +f f32 f32 r n n n R 480 646 2050 2050 646 646 bias,relu,clip +f f32 r n n n R 481 646 2050 2050 646 646 none +f f32 f32 r n n n R 481 646 2050 2050 646 646 bias,relu,clip +f f32 r n n n R 482 646 2050 2050 646 646 none +f f32 f32 r n n n R 482 646 2050 2050 646 646 bias,relu,clip +f f32 r n n n R 483 646 2050 2050 646 646 none +f f32 f32 r n n n R 483 646 2050 2050 646 646 bias,relu,clip +f f32 r n n n R 484 646 2050 2050 646 646 none +f f32 f32 r n n n R 484 646 2050 2050 646 646 bias,relu,clip +f f32 r n n n R 485 646 2050 2050 646 646 none +f f32 f32 r n n n R 485 646 2050 2050 646 646 bias,relu,clip +f f32 r n n n R 481 656 2050 2050 656 656 none +f f32 f32 r n n n R 481 656 2050 2050 656 656 bias,relu,clip +f f32 r n n n R 482 656 2050 2050 656 656 none +f f32 f32 r n n n R 482 656 2050 2050 656 656 bias,relu,clip +f f32 r n n n R 483 656 2050 2050 656 656 none +f f32 f32 r n n n R 483 656 2050 2050 656 656 bias,relu,clip +f f32 r n n n R 484 656 2050 2050 656 656 none +f f32 f32 r n n n R 484 656 2050 2050 656 656 bias,relu,clip +f f32 r n n n p 485 656 2050 2050 656 656 none +f f32 f32 r n n n p 485 656 2050 2050 656 656 bias,relu,clip +f f32 r n n n p 480 672 2050 2050 672 672 none +f f32 f32 r n n n p 480 672 2050 2050 672 672 bias,relu,clip +f f32 r n n n p 481 672 2050 2050 672 672 none +f f32 f32 r n n n p 481 672 2050 2050 672 672 bias,relu,clip +f f32 r n n n p 482 672 2050 2050 672 672 none +f f32 f32 r n n n p 482 672 2050 2050 672 672 bias,relu,clip +f f32 r n n n p 483 672 2050 2050 672 672 none +f f32 f32 r n n n p 483 672 2050 2050 672 672 bias,relu,clip +f f32 r n n n p 484 672 2050 2050 672 672 none +f f32 f32 r n n n p 484 672 2050 2050 672 672 bias,relu,clip +f f32 r n n n p 485 672 2050 2050 672 672 none +f f32 f32 r n n n p 485 672 2050 2050 672 672 bias,relu,clip +f f32 r n n n p 480 688 2050 2050 688 688 none +f f32 f32 r n n n p 480 688 2050 2050 688 688 bias,relu,clip +f f32 r n n n p 481 688 2050 2050 688 688 none +f f32 f32 r n n n p 481 688 2050 2050 688 688 bias,relu,clip +f f32 r n n n r 482 688 2050 2050 688 688 none +f f32 f32 r n n n r 482 688 2050 2050 688 688 bias,relu,clip +f f32 r n n n r 483 688 2050 2050 688 688 none +f f32 f32 r n n n r 483 688 2050 2050 688 688 bias,relu,clip +f f32 r n n n r 484 688 2050 2050 688 688 none +f f32 f32 r n n n r 484 688 2050 2050 688 688 bias,relu,clip +f f32 r n n n r 485 688 2050 2050 688 688 none +f f32 f32 r n n n r 485 688 2050 2050 688 688 bias,relu,clip +f f32 r n n n r 1024 512 64 64 512 512 none +f f32 f32 r n n n r 1024 512 64 64 512 512 bias,relu,clip +f f32 r n n n r 16 256 512 512 256 256 none +f f32 f32 r n n n r 16 256 512 512 256 256 bias,relu,clip +f f32 r n n n r 480 640 512 512 640 640 none +f f32 f32 r n n n r 480 640 512 512 640 640 bias,relu,clip +f f32 r n n n r 64 768 512 512 768 768 none +f f32 f32 r n n n r 64 768 512 
512 768 768 bias,relu,clip +f f32 r n n n r 128 128 128 128 128 128 none +f f32 f32 r n n n r 128 128 128 128 128 128 bias,relu,clip +f f32 r n n n r 1024 64 512 512 64 64 none +f f32 f32 r n n n r 1024 64 512 512 64 64 bias,relu,clip +f f32 r n n n r 1024 256 32 32 256 256 none +f f32 f32 r n n n r 1024 256 32 32 256 256 bias,relu,clip +f f32 r n n n r 1024 512 64 64 512 512 none +f f32 f32 r n n n r 1024 512 64 64 512 512 bias,relu,clip +f f32 r n n n r 480 640 512 512 640 640 none +f f32 f32 r n n n r 480 640 512 512 640 640 bias,relu,clip +f f32 r n n n p 1024 32 256 256 32 32 none +f f32 f32 r n n n p 1024 32 256 256 32 32 bias,relu,clip +f f32 r n n n P 1024 64 512 512 64 64 none +f f32 f32 r n n n P 1024 64 512 512 64 64 bias,relu,clip +f f32 r n n n P 64 800 320 320 800 800 none +f f32 f32 r n n n P 64 800 320 320 800 800 bias,relu,clip +f f32 r n n n P 64 768 512 512 768 768 none +f f32 f32 r n n n P 64 768 512 512 768 768 bias,relu,clip +f f32 r n n n P 16 256 512 512 256 256 none +f f32 f32 r n n n P 16 256 512 512 256 256 bias,relu,clip +f f32 r n n n P 128 128 128 128 128 128 none +f f32 f32 r n n n P 128 128 128 128 128 128 bias,relu,clip +f f32 r n n n P 256 512 256 256 512 512 none +f f32 f32 r n n n P 256 512 256 256 512 512 bias,relu,clip +f f32 r n n n P 1024 1024 1024 1024 1024 1024 none +f f32 f32 r n n n P 1024 1024 1024 1024 1024 1024 bias,relu,clip +f f32 r n n n P 480 640 1024 1024 640 640 none +f f32 f32 r n n n P 480 640 1024 1024 640 640 bias,relu,clip +f f32 r n n n P 480 640 256 256 640 640 none +f f32 f32 r n n n P 480 640 256 256 640 640 bias,relu,clip +f f32 r n n n P 8 64 32 32 64 64 none +f f32 f32 r n n n P 8 64 32 32 64 64 bias,relu,clip +f f32 r n n n P 9 64 32 32 64 64 none +f f32 f32 r n n n P 9 64 32 32 64 64 bias,relu,clip +f f32 r n n n P 10 128 64 64 128 128 none +f f32 f32 r n n n P 10 128 64 64 128 128 bias,relu,clip +f f32 r n n n P 8 8 8 8 8 8 none +f f32 f32 r n n n P 8 8 8 8 8 8 bias,relu,clip +f f32 r n n n P 12 12 12 12 12 12 none +f f32 f32 r n n n P 12 12 12 12 12 12 bias,relu,clip +f f32 r n n n P 25 25 25 25 25 25 none +f f32 f32 r n n n P 25 25 25 25 25 25 bias,relu,clip +f f32 r n n n P 25 25 20 20 25 25 none +f f32 f32 r n n n P 25 25 20 20 25 25 bias,relu,clip +i s32 r n n n r 4096 256 5 5 256 256 none +i s8 r n n n r 4096 256 5 5 256 256 none +i s32 s8 r n n n r 4096 256 5 5 256 256 bias,relu,clip +i s8 s8 r n n n r 4096 256 5 5 256 256 bias,relu,clip +i s32 r n n n r 3000 256 128 128 256 256 none +i s8 r n n n r 3000 256 128 128 256 256 none +i s32 s8 r n n n r 3000 256 128 128 256 256 bias,relu,clip +i s8 s8 r n n n r 3000 256 128 128 256 256 bias,relu,clip +i s32 r n n n r 4096 1024 512 512 1024 1024 none +i s8 r n n n r 4096 1024 512 512 1024 1024 none +i s32 s8 r n n n r 4096 1024 512 512 1024 1024 bias,relu,clip +i s8 s8 r n n n r 4096 1024 512 512 1024 1024 bias,relu,clip +i s32 r n n n r 144 256 5 5 256 256 none +i s8 r n n n r 144 256 5 5 256 256 none +i s32 s8 r n n n r 144 256 5 5 256 256 bias,relu,clip +i s8 s8 r n n n r 144 256 5 5 256 256 bias,relu,clip +i s32 r n n n r 144 256 128 128 256 256 none +i s8 r n n n r 144 256 128 128 256 256 none +i s32 s8 r n n n r 144 256 128 128 256 256 bias,relu,clip +i s8 s8 r n n n r 144 256 128 128 256 256 bias,relu,clip +i s32 r n n n r 144 1024 512 512 1024 1024 none +i s8 r n n n r 144 1024 512 512 1024 1024 none +i s32 s8 r n n n r 144 1024 512 512 1024 1024 bias,relu,clip +i s8 s8 r n n n r 144 1024 512 512 1024 1024 bias,relu,clip +i s32 r n n n r 480 688 256 256 688 688 
none +i s8 r n n n r 480 688 256 256 688 688 none +i s32 s8 r n n n r 480 688 256 256 688 688 bias,relu,clip +i s8 s8 r n n n r 480 688 256 256 688 688 bias,relu,clip +i s32 r n n n r 480 640 512 512 640 640 none +i s8 r n n n r 480 640 512 512 640 640 none +i s32 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip +i s8 s8 r n n n r 480 640 512 512 640 640 bias,relu,clip +i s32 r n n n r 480 640 1024 1024 640 640 none +i s8 r n n n r 480 640 1024 1024 640 640 none +i s32 s8 r n n n r 480 640 1024 1024 640 640 bias,relu,clip +i s8 s8 r n n n r 480 640 1024 1024 640 640 bias,relu,clip +i s32 r n n n r 64 800 320 320 800 800 none +i s8 r n n n r 64 800 320 320 800 800 none +i s32 s8 r n n n r 64 800 320 320 800 800 bias,relu,clip +i s8 s8 r n n n r 64 800 320 320 800 800 bias,relu,clip +i s32 r n n n r 64 768 512 512 768 768 none +i s8 r n n n r 64 768 512 512 768 768 none +i s32 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip +i s8 s8 r n n n r 64 768 512 512 768 768 bias,relu,clip +i s32 r n n n r 16 256 512 512 256 256 none +i s8 r n n n r 16 256 512 512 256 256 none +i s32 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip +i s8 s8 r n n n r 16 256 512 512 256 256 bias,relu,clip +i s32 r n n n r 128 128 128 128 128 128 none +i s8 r n n n r 128 128 128 128 128 128 none +i s32 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip +i s8 s8 r n n n r 128 128 128 128 128 128 bias,relu,clip +i s32 r n n n r 256 512 256 256 512 512 none +i s8 r n n n r 256 512 256 256 512 512 none +i s32 s8 r n n n r 256 512 256 256 512 512 bias,relu,clip +i s8 s8 r n n n r 256 512 256 256 512 512 bias,relu,clip +i s32 r n n n r 1024 1024 1024 1024 1024 1024 none +i s8 r n n n r 1024 1024 1024 1024 1024 1024 none +i s32 s8 r n n n r 1024 1024 1024 1024 1024 1024 bias,relu,clip +i s8 s8 r n n n r 1024 1024 1024 1024 1024 1024 bias,relu,clip +i s32 r n n n r 1024 32 256 256 32 32 none +i s8 r n n n r 1024 32 256 256 32 32 none +i s32 s8 r n n n r 1024 32 256 256 32 32 bias,relu,clip +i s8 s8 r n n n r 1024 32 256 256 32 32 bias,relu,clip +i s32 r n n n r 1024 64 512 512 64 64 none +i s8 r n n n r 1024 64 512 512 64 64 none +i s32 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip +i s8 s8 r n n n r 1024 64 512 512 64 64 bias,relu,clip +i s32 r n n n r 1024 256 32 32 256 256 none +i s8 r n n n r 1024 256 32 32 256 256 none +i s32 s8 r n n n r 1024 256 32 32 256 256 bias,relu,clip +i s8 s8 r n n n r 1024 256 32 32 256 256 bias,relu,clip +i s32 r n n n r 1024 512 64 64 512 512 none +i s8 r n n n r 1024 512 64 64 512 512 none +i s32 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +i s8 s8 r n n n r 1024 512 64 64 512 512 bias,relu,clip +i s32 r n n n r 512 32 256 256 32 32 none +i s8 r n n n r 512 32 256 256 32 32 none +i s32 s8 r n n n r 512 32 256 256 32 32 bias,relu,clip +i s8 s8 r n n n r 512 32 256 256 32 32 bias,relu,clip +i s32 r n n n r 512 768 512 512 768 768 none +i s8 r n n n r 512 768 512 512 768 768 none +i s32 s8 r n n n r 512 768 512 512 768 768 bias,relu,clip +i s8 s8 r n n n r 512 768 512 512 768 768 bias,relu,clip +i s32 r n n n r 512 256 32 32 256 256 none +i s8 r n n n r 512 256 32 32 256 256 none +i s32 s8 r n n n r 512 256 32 32 256 256 bias,relu,clip +i s8 s8 r n n n r 512 256 32 32 256 256 bias,relu,clip +i s32 r n n n r 512 512 64 64 512 512 none +i s8 r n n n r 512 512 64 64 512 512 none +i s32 s8 r n n n r 512 512 64 64 512 512 bias,relu,clip +i s8 s8 r n n n r 512 512 64 64 512 512 bias,relu,clip +i s32 r n n n r 512 256 768 768 256 256 none +i s8 r n n n r 512 256 768 768 256 256 none 
+i s32 s8 r n n n r 512 256 768 768 256 256 bias,relu,clip +i s8 s8 r n n n r 512 256 768 768 256 256 bias,relu,clip +i s32 r n n n r 768 768 1024 1024 768 768 none +i s8 r n n n r 768 768 1024 1024 768 768 none +i s32 s8 r n n n r 768 768 1024 1024 768 768 bias,relu,clip +i s8 s8 r n n n r 768 768 1024 1024 768 768 bias,relu,clip +i s32 r n n n r 768 768 768 768 768 768 none +i s8 r n n n r 768 768 768 768 768 768 none +i s32 s8 r n n n r 768 768 768 768 768 768 bias,relu,clip +i s8 s8 r n n n r 768 768 768 768 768 768 bias,relu,clip +i s32 r n n n r 2048 2048 2048 2048 2048 2048 none +i s8 r n n n r 2048 2048 2048 2048 2048 2048 none +i s32 s8 r n n n r 2048 2048 2048 2048 2048 2048 bias,relu,clip +i s8 s8 r n n n r 2048 2048 2048 2048 2048 2048 bias,relu,clip +i s32 r n n n r 4096 4096 4096 4096 4096 4096 none +i s8 r n n n r 4096 4096 4096 4096 4096 4096 none +i s32 s8 r n n n r 4096 4096 4096 4096 4096 4096 bias,relu,clip +i s8 s8 r n n n r 4096 4096 4096 4096 4096 4096 bias,relu,clip +f f32 r n n n r 4096 256 5 5 256 256 none +f f32 f32 r n n n r 4096 256 5 5 256 256 bias,relu,clip +f f32 r n n n r 3000 256 128 128 256 256 none +f f32 f32 r n n n r 3000 256 128 128 256 256 bias,relu,clip +f f32 r n n n r 4096 1024 512 512 1024 1024 none +f f32 f32 r n n n r 4096 1024 512 512 1024 1024 bias,relu,clip +f f32 r n n n r 144 256 5 5 256 256 none +f f32 f32 r n n n r 144 256 5 5 256 256 bias,relu,clip +f f32 r n n n r 144 256 128 128 256 256 none +f f32 f32 r n n n r 144 256 128 128 256 256 bias,relu,clip +f f32 r n n n r 144 1024 512 512 1024 1024 none +f f32 f32 r n n n r 144 1024 512 512 1024 1024 bias,relu,clip +f f32 r n n n r 480 688 256 256 688 688 none +f f32 f32 r n n n r 480 688 256 256 688 688 bias,relu,clip +f f32 r n n n r 480 640 512 512 640 640 none +f f32 f32 r n n n r 480 640 512 512 640 640 bias,relu,clip +f f32 r n n n r 480 640 1024 1024 640 640 none +f f32 f32 r n n n r 480 640 1024 1024 640 640 bias,relu,clip +f f32 r n n n r 64 800 320 320 800 800 none +f f32 f32 r n n n r 64 800 320 320 800 800 bias,relu,clip +f f32 r n n n r 64 768 512 512 768 768 none +f f32 f32 r n n n r 64 768 512 512 768 768 bias,relu,clip +f f32 r n n n r 16 256 512 512 256 256 none +f f32 f32 r n n n r 16 256 512 512 256 256 bias,relu,clip +f f32 r n n n r 128 128 128 128 128 128 none +f f32 f32 r n n n r 128 128 128 128 128 128 bias,relu,clip +f f32 r n n n r 256 512 256 256 512 512 none +f f32 f32 r n n n r 256 512 256 256 512 512 bias,relu,clip +f f32 r n n n r 1024 1024 1024 1024 1024 1024 none +f f32 f32 r n n n r 1024 1024 1024 1024 1024 1024 bias,relu,clip +f f32 r n n n r 1024 32 256 256 32 32 none +f f32 f32 r n n n r 1024 32 256 256 32 32 bias,relu,clip +f f32 r n n n r 1024 64 512 512 64 64 none +f f32 f32 r n n n r 1024 64 512 512 64 64 bias,relu,clip +f f32 r n n n r 1024 256 32 32 256 256 none +f f32 f32 r n n n r 1024 256 32 32 256 256 bias,relu,clip +f f32 r n n n r 1024 512 64 64 512 512 none +f f32 f32 r n n n r 1024 512 64 64 512 512 bias,relu,clip +f f32 r n n n r 512 32 256 256 32 32 none +f f32 f32 r n n n r 512 32 256 256 32 32 bias,relu,clip +f f32 r n n n r 512 768 512 512 768 768 none +f f32 f32 r n n n r 512 768 512 512 768 768 bias,relu,clip +f f32 r n n n r 512 256 32 32 256 256 none +f f32 f32 r n n n r 512 256 32 32 256 256 bias,relu,clip +f f32 r n n n r 512 512 64 64 512 512 none +f f32 f32 r n n n r 512 512 64 64 512 512 bias,relu,clip +f f32 r n n n r 512 256 768 768 256 256 none +f f32 f32 r n n n r 512 256 768 768 256 256 bias,relu,clip +f f32 r n n n r 
768 768 1024 1024 768 768 none +f f32 f32 r n n n r 768 768 1024 1024 768 768 bias,relu,clip +f f32 r n n n r 768 768 768 768 768 768 none +f f32 f32 r n n n r 768 768 768 768 768 768 bias,relu,clip +f f32 r n n n r 2048 2048 2048 2048 2048 2048 none +f f32 f32 r n n n r 2048 2048 2048 2048 2048 2048 bias,relu,clip +f f32 r n n n r 4096 4096 4096 4096 4096 4096 none +f f32 f32 r n n n r 4096 4096 4096 4096 4096 4096 bias,relu,clip +f f32 r n n n r 2048 1024 1024 1024 1024 1024 none +f f32 f32 r n n n r 2048 1024 1024 1024 1024 1024 bias,relu,clip +f f32 r n n n r 2048 4096 1024 1024 4096 4096 none +f f32 f32 r n n n r 2048 4096 1024 1024 4096 4096 bias,relu,clip +f f32 r n n n r 2048 1024 4096 4096 1024 1024 none +f f32 f32 r n n n r 2048 1024 4096 4096 1024 1024 bias,relu,clip +f f32 r n n n r 2048 1024 2 2 1024 1024 none +f f32 f32 r n n n r 2048 1024 2 2 1024 1024 bias,relu,clip +f f32 r n n n r 128 1024 1024 1024 1024 1024 none +f f32 f32 r n n n r 128 1024 1024 1024 1024 1024 bias,relu,clip +f f32 r n n n r 1536 768 768 768 768 768 none +f f32 f32 r n n n r 1536 768 768 768 768 768 bias,relu,clip +f f32 r n n n r 1536 3072 768 768 3072 3072 none +f f32 f32 r n n n r 1536 3072 768 768 3072 3072 bias,relu,clip +f f32 r n n n r 1536 768 3072 3072 768 768 none +f f32 f32 r n n n r 1536 768 3072 3072 768 768 bias,relu,clip +f f32 r n n n r 1536 768 2 2 768 768 none +f f32 f32 r n n n r 1536 768 2 2 768 768 bias,relu,clip +f f32 r n n n r 128 768 768 768 768 768 none +f f32 f32 r n n n r 128 768 768 768 768 768 bias,relu,clip +f f32 r n n n r 1024 8 13 13 8 8 none +f f32 f32 r n n n r 1024 8 13 13 8 8 bias,relu,clip +f f32 r n n n r 1024 4 8 8 4 4 none +f f32 f32 r n n n r 1024 4 8 8 4 4 bias,relu,clip +f f32 r n n n r 1024 128 355 355 128 128 none +f f32 f32 r n n n r 1024 128 355 355 128 128 bias,relu,clip +f f32 r n n n r 1024 64 128 128 64 64 none +f f32 f32 r n n n r 1024 64 128 128 64 64 bias,relu,clip +f f32 r n n n r 1024 1 64 64 1 1 none +f f32 f32 r n n n r 1024 1 64 64 1 1 bias,relu,clip +f f32 r n n n r 480 1 256 256 1 1 none +f f32 f32 r n n n r 480 1 256 256 1 1 bias,relu,clip +f f32 r n n n r 480 256 512 512 256 256 none +f f32 f32 r n n n r 480 256 512 512 256 256 bias,relu,clip +f f32 r n n n r 480 1024 845 845 1024 1024 none +f f32 f32 r n n n r 480 1024 845 845 1024 1024 bias,relu,clip +f f32 r n n n r 480 512 1024 1024 512 512 none +f f32 f32 r n n n r 480 512 1024 1024 512 512 bias,relu,clip +f f32 r n n n r 10 17191 128 128 17191 17191 none +f f32 f32 r n n n r 10 17191 128 128 17191 17191 bias,relu,clip +f f32 r n n n r 10 512 256 256 512 512 none +f f32 f32 r n n n r 10 512 256 256 512 512 bias,relu,clip diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c index 7dd049b159..bb70a087b2 100644 --- a/bench/bench_aocl_gemm/bench_lpgemm.c +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -43,8 +43,11 @@ #include "blis.h" -#define S8_MIN (-128) -#define S8_MAX (+127) + +// Used to clip downscaled output, will be set in the main loop based +// on the accumulation and C data type. 
+int64_t DSCALE_CLIP_MIN = 0; +int64_t DSCALE_CLIP_MAX = 0; // Mode can be one of the follwoing: // 1. p - performance, used for benchmarks. @@ -63,40 +66,139 @@ dim_t num_eltwise = 0; // To keep track of eltwise operations. #define GEN_FUNC_NAME(prototype,ctype) prototype ## ctype -inline void float_to_bf16( float* float_value, bfloat16* bf16_val ) +static inline void float_to_bf16( float* float_value, bfloat16* bf16_val ) +{ + /*Set offset 2 to copy most significant 2 bytes of float + to convert float values to bf16 values*/ + memcpy( ( bf16_val ), (char *)( float_value ) + 2, sizeof ( bfloat16 ) ); +} + +static inline void convert_float_arr_to_bf16( float* array, bfloat16* array_bf16, int size ) +{ + for (int i=0; i< size; i++) + { + float_to_bf16( ( array + i ), ( array_bf16 + i ) ); + } +} + + +static inline void bfloat16_to_float( bfloat16 bf16_val, float* float_val ) { - /*Set offset 2 to copy most significant 2 bytes of float - to convert float values to bf16 values*/ - memcpy( ( bf16_val ), (char *)( float_value ) + 2, sizeof ( bfloat16 ) ); + int32_t inter_temp = *( ( int16_t* ) &bf16_val ); + inter_temp = inter_temp << 16; + memcpy( float_val, &inter_temp, sizeof( int32_t ) ); } -inline float bf16_to_float +#define CONVERT_TO_FLOAT(ctype) \ +static inline void GEN_FUNC_NAME(ctype,_to_float) ( ctype val, float* float_val ) \ +{ \ + *float_val = (float) val; \ +} \ + +CONVERT_TO_FLOAT(uint8_t) +CONVERT_TO_FLOAT(int8_t) +CONVERT_TO_FLOAT(int16_t) +CONVERT_TO_FLOAT(float) +CONVERT_TO_FLOAT(int32_t) + + + +/* Helper functions to print matrices when debugging */ +void print_matrix_bfloat16 ( - bfloat16 bf16_val + bfloat16* a, + dim_t m, + dim_t n, + dim_t rs_a, + dim_t cs_a ) { - int32_t inter_temp = *( ( int16_t* ) &bf16_val ); - inter_temp = inter_temp << 16; - float float_value = *( float* ) ( &inter_temp ); - return float_value; + for(dim_t i = 0; i < m; i++) + { + for(dim_t j = 0; j < n; j++) + { + float temp; + bfloat16_to_float(*(a + i*(rs_a) + j *cs_a), &temp); + printf("%f ", temp); + } + printf("\n"); + } +} + +#define PRINT_MATRIX(ctype) \ +void print_matrix_## ctype ( ctype* a, int32_t m, int32_t n, int32_t rs, int32_t cs) \ +{ \ + for(int32_t i = 0; i < m; i++) \ + { \ + for(int32_t j = 0; j < n; j++) \ + { \ + printf("%f ", (float) (*(a + i * ( rs ) + j * cs ) ) ); \ + } \ + printf("\n"); \ + } \ +} \ + +PRINT_MATRIX(uint8_t) +PRINT_MATRIX(int8_t) +PRINT_MATRIX(int16_t) +PRINT_MATRIX(float) +PRINT_MATRIX(int32_t) + +void* lpgemm_malloc( int32_t size ) +{ + void* p; + // creating a dummy buffer of size 4 bytes in case + // size of the matrix is negative. 
+ if( size <= 0 ) + { + p = malloc( 4 ); + return p; + } + + if( bench_mode == 'a' ) + { + p = malloc(size); + } + else + { + err_t err = BLIS_SUCCESS; + p = bli_malloc_user(size, &err); + } + if ( p == NULL ) + { + printf("Unable to allocate memory.\n"); + exit(1); + } + return p; } -inline void convert_float_arr_to_bf16( float* array, bfloat16* array_bf16, int size ) +void lpgemm_free( void* p ) { - for (int i=0; i< size; i++) - { - float_to_bf16( ( array + i ), ( array_bf16 + i ) ); - } + if( p == NULL) + { + printf("Attempt to free null pointer\n"); + return; + } + + if( bench_mode == 'a' ) + { + free(p); + } + else + { + bli_free_user(p); + } } #define GEN_FILL_ARRAY_FUNC(ctype) \ void fill_array_ ## ctype ( void* arr, dim_t size ) \ { \ - ctype* temp_arr = ( ctype* ) arr; \ - for ( dim_t i = 0; i < size; ++i ) \ - { \ - temp_arr[i] = ( ctype )( i % 10 ); \ - } \ + if( size < 0 ) return; \ + ctype* temp_arr = ( ctype* ) arr; \ + for ( dim_t i = 0; i < size; ++i ) \ + { \ + temp_arr[i] = ( ctype )( i % 5 ); \ + } \ } \ GEN_FILL_ARRAY_FUNC(uint8_t) @@ -107,26 +209,28 @@ GEN_FILL_ARRAY_FUNC(int32_t) void fill_array_bfloat16( void* arr, dim_t size ) { - float* c_float = ( float* ) bli_malloc_user( sizeof( float ) * size ); - for ( dim_t i = 0; i < size; ++i ) - { - c_float[i] = 2.0; - } - convert_float_arr_to_bf16( c_float, arr, size ); - if ( c_float != NULL ) - { - bli_free_user( c_float ); - } + err_t bli_errors = BLIS_SUCCESS; + if( size < 0 ) return; + float* c_float = ( float* ) bli_malloc_user( sizeof( float ) * size, &bli_errors ); + for ( dim_t i = 0; i < size; ++i ) + { + c_float[i] = i % 5; + } + convert_float_arr_to_bf16( c_float, arr, size ); + if ( c_float != NULL ) + { + bli_free_user( c_float ); + } } #define GEN_FILL_ARRAY_POST_OPS_FUNC(ctype) \ void fill_array_post_ops_ ## ctype ( void* arr, dim_t size ) \ { \ - ctype* temp_arr = ( ctype* ) arr; \ - for ( dim_t i = 0; i < size; ++i ) \ - { \ - temp_arr[i] = ( ctype )( i % 20 ); \ - } \ + ctype* temp_arr = ( ctype* ) arr; \ + for ( dim_t i = 0; i < size; ++i ) \ + { \ + temp_arr[i] = ( ctype )( i % 20 ); \ + } \ } \ GEN_FILL_ARRAY_POST_OPS_FUNC(int16_t) @@ -137,7 +241,10 @@ GEN_FILL_ARRAY_POST_OPS_FUNC(float) void mat_mul_ ## BLAS_SFX \ ( \ char stor_order, \ - char op_t, \ + char transa, \ + char transb, \ + char op_a, \ + char op_b, \ dim_t m, \ dim_t n, \ dim_t k, \ @@ -152,90 +259,72 @@ void mat_mul_ ## BLAS_SFX \ aocl_post_op* post_op\ ) \ { \ - char storage = stor_order; \ - char transa = 'n'; \ - char transb = 'n'; \ - char reordera = 'n'; \ - char reorderb = 'n'; \ - \ - if ( ( op_t == 'p' ) || ( op_t == 'P' ) ) \ - { \ - /* No reordering of B.*/ \ - reordera = 'n'; \ - reorderb = 'n'; \ - } \ - else if ( ( op_t == 'r' ) || ( op_t == 'R' ) ) \ - { \ - /* Reordered B.*/ \ - reordera = 'n'; \ - reorderb = 'r'; \ - } \ + aocl_gemm_ ## BLAS_SFX( stor_order, transa, transb, m, n, k, \ + alpha, \ + a, lda, op_a, \ + b, ldb, op_b, \ + beta, \ + c, ldc, post_op ); \ \ - aocl_gemm_ ## BLAS_SFX( storage, transa, transb, m, n, k, \ - alpha, \ - a, lda, reordera, \ - b, ldb, reorderb, \ - beta, \ - c, ldc, post_op ); \ + /*dim_t MR = 6; \ + dim_t NR = 16; \ \ - /*dim_t MR = 6; \ - dim_t NR = 16; \ + __m512i selector1; \ + __m512i all_zero = _mm512_setzero_epi32(); \ + __m512i c0; \ + __m512i c1; \ + __m512i c2; \ + __m512i c3; \ + __m512i c4; \ + __m512i c5; \ \ - __m512i selector1; \ - __m512i all_zero = _mm512_setzero_epi32(); \ - __m512i c0; \ - __m512i c1; \ - __m512i c2; \ - __m512i c3; \ - __m512i c4; \ - __m512i c5; \ 
- \ - for ( dim_t i = 0; i < m; i += MR ) \ - { \ - if ( ( i + MR ) > m ) \ - { \ - break; \ - } \ - for ( dim_t j = 0; j < n; j += NR ) \ - { \ - if ( ( j + NR ) > n ) \ - { \ - break; \ - } \ - selector1 = _mm512_loadu_epi32( (int32_t*)post_op->bias.bias + j ); \ - c0 = _mm512_loadu_epi32( c + ( ( i + 0 ) * ldc ) + j ); \ - c1 = _mm512_loadu_epi32( c + ( ( i + 1 ) * ldc ) + j ); \ - c2 = _mm512_loadu_epi32( c + ( ( i + 2 ) * ldc ) + j ); \ - c3 = _mm512_loadu_epi32( c + ( ( i + 3 ) * ldc ) + j ); \ - c4 = _mm512_loadu_epi32( c + ( ( i + 4 ) * ldc ) + j ); \ - c5 = _mm512_loadu_epi32( c + ( ( i + 5 ) * ldc ) + j ); \ + for ( dim_t i = 0; i < m; i += MR ) \ + { \ + if ( ( i + MR ) > m ) \ + { \ + break; \ + } \ + for ( dim_t j = 0; j < n; j += NR ) \ + { \ + if ( ( j + NR ) > n ) \ + { \ + break; \ + } \ + selector1 = _mm512_loadu_epi32( (int32_t*)post_op->bias.bias + j ); \ + c0 = _mm512_loadu_epi32( c + ( ( i + 0 ) * ldc ) + j ); \ + c1 = _mm512_loadu_epi32( c + ( ( i + 1 ) * ldc ) + j ); \ + c2 = _mm512_loadu_epi32( c + ( ( i + 2 ) * ldc ) + j ); \ + c3 = _mm512_loadu_epi32( c + ( ( i + 3 ) * ldc ) + j ); \ + c4 = _mm512_loadu_epi32( c + ( ( i + 4 ) * ldc ) + j ); \ + c5 = _mm512_loadu_epi32( c + ( ( i + 5 ) * ldc ) + j ); \ \ - c0 = _mm512_add_epi32( selector1, c0 ); \ - c1 = _mm512_add_epi32( selector1, c1 ); \ - c2 = _mm512_add_epi32( selector1, c2 ); \ - c3 = _mm512_add_epi32( selector1, c3 ); \ - c4 = _mm512_add_epi32( selector1, c4 ); \ - c5 = _mm512_add_epi32( selector1, c5 ); \ + c0 = _mm512_add_epi32( selector1, c0 ); \ + c1 = _mm512_add_epi32( selector1, c1 ); \ + c2 = _mm512_add_epi32( selector1, c2 ); \ + c3 = _mm512_add_epi32( selector1, c3 ); \ + c4 = _mm512_add_epi32( selector1, c4 ); \ + c5 = _mm512_add_epi32( selector1, c5 ); \ \ - c0 = _mm512_max_epi32( all_zero, c0 ); \ - c1 = _mm512_max_epi32( all_zero, c1 ); \ - c2 = _mm512_max_epi32( all_zero, c2 ); \ - c3 = _mm512_max_epi32( all_zero, c3 ); \ - c4 = _mm512_max_epi32( all_zero, c4 ); \ - c5 = _mm512_max_epi32( all_zero, c5 ); \ + c0 = _mm512_max_epi32( all_zero, c0 ); \ + c1 = _mm512_max_epi32( all_zero, c1 ); \ + c2 = _mm512_max_epi32( all_zero, c2 ); \ + c3 = _mm512_max_epi32( all_zero, c3 ); \ + c4 = _mm512_max_epi32( all_zero, c4 ); \ + c5 = _mm512_max_epi32( all_zero, c5 ); \ \ - _mm512_storeu_epi32( c + ( ( i + 0 ) * ldc ) + j, c0 ); \ - _mm512_storeu_epi32( c + ( ( i + 1 ) * ldc ) + j, c1 ); \ - _mm512_storeu_epi32( c + ( ( i + 2 ) * ldc ) + j, c2 ); \ - _mm512_storeu_epi32( c + ( ( i + 3 ) * ldc ) + j, c3 ); \ - _mm512_storeu_epi32( c + ( ( i + 4 ) * ldc ) + j, c4 ); \ - _mm512_storeu_epi32( c + ( ( i + 5 ) * ldc ) + j, c5 ); \ - } \ - } */\ + _mm512_storeu_epi32( c + ( ( i + 0 ) * ldc ) + j, c0 ); \ + _mm512_storeu_epi32( c + ( ( i + 1 ) * ldc ) + j, c1 ); \ + _mm512_storeu_epi32( c + ( ( i + 2 ) * ldc ) + j, c2 ); \ + _mm512_storeu_epi32( c + ( ( i + 3 ) * ldc ) + j, c3 ); \ + _mm512_storeu_epi32( c + ( ( i + 4 ) * ldc ) + j, c4 ); \ + _mm512_storeu_epi32( c + ( ( i + 5 ) * ldc ) + j, c5 ); \ + } \ + } */\ } \ GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) +GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8) GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) GEN_BLIS_MAT_MUL_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32) @@ -254,13 +343,15 @@ double get_gflops double runtime ) { - return ( ( 2.0 * m * n * k ) / ( runtime * 1.0e9 
) ); + return ( ( 2.0 * m * n * k ) / ( runtime * 1.0e9 ) ); } void print_result ( const char* msg, int32_t n_repeats, + char transa, + char transb, dim_t m, dim_t n, dim_t k, @@ -270,17 +361,20 @@ void print_result double runtime ) { - double gflops = get_gflops( m, n, k, runtime ); - printf("%s m: %ld, n: %ld, k: %ld, lda: %ld, ldb: %ld, ldc: %ld," \ - " Gops: %f, n_repeats: %d\n", - msg, m, n, k, lda, ldb, ldc, gflops, n_repeats); + double gflops = get_gflops( m, n, k, runtime ); + printf("%s transa:%c, transb:%c, m: %ld, n: %ld, k: %ld, lda: %ld, ldb: %ld, ldc: %ld," \ + " Gops: %f, n_repeats: %d\n", + msg, transa, transb, m, n, k, lda, ldb, ldc, gflops, n_repeats); } #define GEN_MAT_MUL_BENCH_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,BLAS_SFX) \ void mat_mul_bench_driver_ ## BLAS_SFX \ ( \ char stor_order, \ - char op_t, \ + char transa, \ + char transb, \ + char op_a, \ + char op_b, \ int32_t n_repeats, \ dim_t m, \ dim_t n, \ @@ -296,41 +390,43 @@ void mat_mul_bench_driver_ ## BLAS_SFX \ aocl_post_op* post_op\ ) \ { \ - double min_time_diff = DBL_MAX; \ - for ( int32_t nr = 0; nr < n_repeats; ++nr ) \ - { \ - if ( bench_mode == 'a' ) \ - { \ - GEN_FUNC_NAME(fill_array_,C_type)( c, ( m * n ) ); \ - } \ + double min_time_diff = DBL_MAX; \ + for ( int32_t nr = 0; nr < n_repeats; ++nr ) \ + { \ + if ( bench_mode == 'a' ) \ + { \ + int32_t size_C = ( ( stor_order == 'r') || ( stor_order == 'R' ) )? m * ldc : n * ldc; \ + GEN_FUNC_NAME(fill_array_,C_type)( c, ( size_C ) ); \ + } \ \ - struct timespec tstart={0,0}, tend={0,0}; \ - clock_gettime(CLOCK_MONOTONIC, &tstart); \ + struct timespec tstart={0,0}, tend={0,0}; \ + clock_gettime(CLOCK_MONOTONIC, &tstart); \ \ - GEN_FUNC_NAME(mat_mul_,BLAS_SFX) \ - ( \ - stor_order, op_t, m, n, k, \ - alpha, \ - a, lda, \ - b, ldb, \ - beta, \ - c, ldc, \ - post_op \ - ); \ + GEN_FUNC_NAME(mat_mul_,BLAS_SFX) \ + ( \ + stor_order, transa, transb, op_a, op_b, m, n, k, \ + alpha, \ + a, lda, \ + b, ldb, \ + beta, \ + c, ldc, \ + post_op \ + ); \ \ - clock_gettime(CLOCK_MONOTONIC, &tend); \ + clock_gettime(CLOCK_MONOTONIC, &tend); \ \ - double diff = \ - ( ( double ) tend.tv_sec + ( 1.0e-9 * tend.tv_nsec ) ) - \ - ( ( double ) tstart.tv_sec + ( 1.0e-9 * tstart.tv_nsec ) ); \ - min_time_diff = ( diff < min_time_diff ) ? diff : min_time_diff; \ - } \ + double diff = \ + ( ( double ) tend.tv_sec + ( 1.0e-9 * tend.tv_nsec ) ) - \ + ( ( double ) tstart.tv_sec + ( 1.0e-9 * tstart.tv_nsec ) ); \ + min_time_diff = ( diff < min_time_diff ) ? diff : min_time_diff; \ + } \ \ - print_result( XSTR(BLAS_SFX), n_repeats, m, n, k, lda, ldb, ldc, min_time_diff); \ + print_result( XSTR(BLAS_SFX), n_repeats, transa, transb, m, n, k, lda, ldb, ldc, min_time_diff); \ } \ GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8) GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32) @@ -343,44 +439,50 @@ GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) int max (int a, int b) { - return ( a > b ? a : b ); + return ( a > b ? a : b ); } int min (int a, int b) { - return ( a < b ? a : b ); + return ( a < b ? 
a : b ); } -#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \ -inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \ +#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \ +static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \ (\ ACCUM_type temp_accum,\ aocl_post_op* post_op, \ dim_t j \ )\ -{\ - ACCUM_type out_temp_accum = ( ACCUM_type ) min ( max ( nearbyintf( ( SCALE_type )temp_accum * \ - ( *( ( SCALE_type* )post_op->sum.scale_factor + j ) ) ), S8_MIN ), S8_MAX ) ; \ - return out_temp_accum; \ +{ \ + ACCUM_type out_temp_accum = \ + ( ACCUM_type )min( \ + max( nearbyintf( ( SCALE_type )( temp_accum ) * \ + ( *( ( SCALE_type* )post_op->sum.scale_factor + j ) ) ) + \ + *( ( C_type* )post_op->sum.zero_point + j ), \ + DSCALE_CLIP_MIN ), \ + DSCALE_CLIP_MAX ); \ + return out_temp_accum; \ }\ -GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int16_t,float,u8s8s16os8) -GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int32_t,float,u8s8s32os8) -GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int32_t,float,s8s8s32os8) -GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int16_t,float,s8s8s16os8) +GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,u8s8s16os8) +GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int16_t,float,u8s8s16ou8) +GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,u8s8s32os8) +GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,s8s8s32os8) +GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,s8s8s16os8) -inline float mat_mul_accuracy_check_downscale_bf16bf16f32obf16 +static inline float mat_mul_accuracy_check_downscale_bf16bf16f32obf16 ( float temp_accum, aocl_post_op* post_op, dim_t j ) { - return temp_accum; + return temp_accum; } #define GEN_MAT_MUL_ACC_CHK_ACCUM(A_type, B_type, C_type,ACCUM_type,BLAS_SFX) \ -inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \ +static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \ (\ A_type* a, \ B_type* b, \ @@ -399,18 +501,19 @@ inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \ dim_t k \ )\ {\ - for ( dim_t p = 0; p < k; ++p) \ - { \ - temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * \ - *( b + ( rs_b * p ) + ( cs_b * j ) ) ); \ - } \ + for ( dim_t p = 0; p < k; ++p) \ + { \ + temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * \ + *( b + ( rs_b * p ) + ( cs_b * j ) ) ); \ + } \ \ - temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) \ - + ( alpha * temp_accum ); \ - return temp_accum; \ + temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) \ + + ( alpha * temp_accum ); \ + return temp_accum; \ }\ GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) +GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8) GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) @@ -420,7 +523,7 @@ GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) -inline float mat_mul_accuracy_check_accum_bf16bf16f32of32 +static inline float mat_mul_accuracy_check_accum_bf16bf16f32of32 ( bfloat16* a, bfloat16* b, @@ -439,18 +542,19 @@ inline float mat_mul_accuracy_check_accum_bf16bf16f32of32 dim_t k ) { - for ( dim_t p = 0; p < k; ++p) - { - float a_float = bf16_to_float( *( a + i * rs_a + p * cs_a ) ); - float b_float = 
bf16_to_float( *( b + p * rs_b + j * cs_b ) ); - temp_accum += ( ( a_float ) * ( b_float ) ); - } - temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) - + ( alpha * temp_accum ); - return temp_accum; + for ( dim_t p = 0; p < k; ++p) + { + float a_float, b_float; + bfloat16_to_float( *( a + i * rs_a + p * cs_a ) , &a_float); + bfloat16_to_float( *( b + p * rs_b + j * cs_b ) , &b_float); + temp_accum += ( ( a_float ) * ( b_float ) ); + } + temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) + + ( alpha * temp_accum ); + return temp_accum; } -inline float mat_mul_accuracy_check_accum_bf16bf16f32obf16 +static inline float mat_mul_accuracy_check_accum_bf16bf16f32obf16 ( bfloat16* a, bfloat16* b, @@ -469,32 +573,35 @@ inline float mat_mul_accuracy_check_accum_bf16bf16f32obf16 dim_t k ) { - for ( dim_t p = 0; p < k; ++p) - { - float a_float = bf16_to_float( *( a + i*rs_a + p*cs_a ) ); - float b_float = bf16_to_float( *( b + p*rs_b + j*cs_b ) ); - temp_accum += ( ( a_float ) * ( b_float ) ); - } - float c_ref_float = bf16_to_float( *( c_ref + i*rs_c_ref + j*cs_c_ref ) ); - temp_accum = ( beta * ( c_ref_float ) ) + ( alpha * temp_accum ); - - return temp_accum; + for ( dim_t p = 0; p < k; ++p) + { + float a_float, b_float; + bfloat16_to_float( *( a + i*rs_a + p*cs_a ), &a_float ); + bfloat16_to_float( *( b + p*rs_b + j*cs_b ), &b_float ); + temp_accum += ( ( a_float ) * ( b_float ) ); + } + float c_ref_float; + bfloat16_to_float( *( c_ref + i*rs_c_ref + j*cs_c_ref ), &c_ref_float ); + temp_accum = ( beta * ( c_ref_float ) ) + ( alpha * temp_accum ); + + return temp_accum; } #define GEN_GELU_TANH_POSTOP_INT(ACCUM_type,BLAS_SFX) \ -inline ACCUM_type GELU_TANH_post_op_ ## BLAS_SFX \ +static inline ACCUM_type GELU_TANH_post_op_ ## BLAS_SFX \ (\ ACCUM_type temp_accum \ )\ {\ - float gelu_reference = 0.5 *(double)temp_accum * (1 + tanhf( 0.797884 * ( (double)temp_accum + \ - ( 0.044715 * ((double)temp_accum * (double)temp_accum * \ - (double)temp_accum ) ) ) ) ); \ - temp_accum = round (gelu_reference); \ - return temp_accum; \ + float gelu_reference = 0.5 *(double)temp_accum * (1 + tanhf( 0.797884 * ( (double)temp_accum + \ + ( 0.044715 * ((double)temp_accum * (double)temp_accum * \ + (double)temp_accum ) ) ) ) ); \ + temp_accum = round (gelu_reference); \ + return temp_accum; \ }\ GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16os8) +GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16ou8) GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16os16) GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32os8) GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32os32) @@ -504,15 +611,15 @@ GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os8) GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os16) #define GEN_GELU_TANH_POSTOP_FLOAT(BLAS_SFX) \ -inline float GELU_TANH_post_op_ ## BLAS_SFX \ +static inline float GELU_TANH_post_op_ ## BLAS_SFX \ (\ float temp_accum \ )\ {\ - temp_accum = 0.5 *(double)temp_accum * (1 + tanhf( 0.797884 * ( (double)temp_accum + \ - ( 0.044715 * ((double)temp_accum * (double)temp_accum * \ - (double)temp_accum ) ) ) ) ); \ - return temp_accum; \ + temp_accum = 0.5 *(double)temp_accum * (1 + tanhf( 0.797884 * ( (double)temp_accum + \ + ( 0.044715 * ((double)temp_accum * (double)temp_accum * \ + (double)temp_accum ) ) ) ) ); \ + return temp_accum; \ }\ GEN_GELU_TANH_POSTOP_FLOAT(f32f32f32of32) @@ -520,17 +627,18 @@ GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32of32) GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32obf16) #define GEN_GELU_ERF_POSTOP_INT(ACCUM_type,BLAS_SFX) \ -inline ACCUM_type GELU_ERF_post_op_ ## BLAS_SFX \ 
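/* For readers tracing the reference GeLU post-ops: 0.797884 approximates
   sqrt(2/pi) and 0.707107 approximates 1/sqrt(2). A minimal standalone
   sketch of the same formulas, assuming only <math.h> (helper names are
   illustrative and not part of the library):

       static inline float gelu_tanh_ref( float x )
       {
           // tanh-based approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
           return 0.5f * x * ( 1.0f + tanhf( 0.797884f *
                               ( x + 0.044715f * x * x * x ) ) );
       }

       static inline float gelu_erf_ref( float x )
       {
           // erf-based form: 0.5*x*(1 + erf(x/sqrt(2)))
           return 0.5f * x * ( 1.0f + erff( x * 0.707107f ) );
       }
*/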
+static inline ACCUM_type GELU_ERF_post_op_ ## BLAS_SFX \ (\ ACCUM_type temp_accum \ )\ {\ - float gelu_reference = 0.5 *(double)temp_accum * (1 + erff( (double)temp_accum * 0.707107 )); \ - temp_accum = round (gelu_reference); \ - return temp_accum; \ + float gelu_reference = 0.5 *(double)temp_accum * (1 + erff( (double)temp_accum * 0.707107 )); \ + temp_accum = round (gelu_reference); \ + return temp_accum; \ }\ GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16os8) +GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16ou8) GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16os16) GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32os8) GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32os32) @@ -540,13 +648,13 @@ GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os8) GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os16) #define GEN_GELU_ERF_POSTOP_FLOAT(BLAS_SFX) \ -inline float GELU_ERF_post_op_ ## BLAS_SFX \ +static inline float GELU_ERF_post_op_ ## BLAS_SFX \ (\ float temp_accum \ )\ {\ - temp_accum = 0.5 *(double)temp_accum * (1 + erff( (double)temp_accum * 0.707107 )); \ - return temp_accum; \ + temp_accum = 0.5 *(double)temp_accum * (1 + erff( (double)temp_accum * 0.707107 )); \ + return temp_accum; \ }\ GEN_GELU_ERF_POSTOP_FLOAT(f32f32f32of32) @@ -560,13 +668,14 @@ void mat_mul_get_output_type_val ## ACCUM_type ## C_type \ ACCUM_type* temp_accum \ ) \ { \ - ( *out_temp_accum ) = ( C_type )( *temp_accum ); \ + ( *out_temp_accum ) = ( C_type )( *temp_accum ); \ } \ GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,int32_t) GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int32_t) GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int16_t,int16_t) GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int16_t) +GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(uint8_t,int16_t) GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(float,float) void mat_mul_get_output_type_valfloatbfloat16 @@ -575,7 +684,7 @@ void mat_mul_get_output_type_valfloatbfloat16 float* temp_accum ) { - float_to_bf16( temp_accum, out_temp_accum ); + float_to_bf16( temp_accum, out_temp_accum ); } #define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,SCALE_type,BLAS_SFX,BLAS_DOWNSCALE_SFX) \ @@ -583,6 +692,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ ( \ FILE* fout, \ const char stor_order, \ + char transa, \ + char transb, \ dim_t m, \ dim_t n, \ dim_t k, \ @@ -599,130 +710,170 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ aocl_post_op* post_op\ ) \ { \ - dim_t rs_a = lda; \ - dim_t cs_a = 1; \ - dim_t rs_b = ldb; \ - dim_t cs_b = 1; \ - dim_t rs_c = ldc; \ - dim_t cs_c = 1; \ - dim_t rs_c_ref = ldc_ref; \ - dim_t cs_c_ref = 1; \ + dim_t rs_a, cs_a; \ + if( ( transa == 'n' ) || ( transa == 'N' ) ) \ + { \ + rs_a = lda; \ + cs_a = 1; \ + } \ + else \ + { \ + rs_a = 1; \ + cs_a = lda; \ + } \ + dim_t rs_b, cs_b; \ + if( ( transb == 'n' ) || ( transb == 'N' ) ) \ + { \ + rs_b = ldb; \ + cs_b = 1; \ + } \ + else \ + { \ + rs_b = 1; \ + cs_b = ldb; \ + } \ + dim_t rs_c = ldc; \ + dim_t cs_c = 1; \ + dim_t rs_c_ref = ldc_ref; \ + dim_t cs_c_ref = 1; \ \ - if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \ - { \ - rs_a = 1; \ - cs_a = lda; \ - rs_b = 1; \ - cs_b = ldb; \ - rs_c = 1; \ - cs_c = ldc; \ - rs_c_ref = 1; \ - cs_c_ref = ldc_ref; \ - } \ + if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \ + { \ + if( transa == 'n' || transa == 'N') \ + { \ + rs_a = 1; \ + cs_a = lda; \ + } \ + else \ + { \ + rs_a = lda; \ + cs_a = 1; \ + } \ + if( ( transb == 'n' ) || ( transb == 'N' ) ) \ + { \ + rs_b = 1; \ + cs_b = ldb; \ + } \ + else \ + { \ + rs_b = ldb; \ + cs_b = 1; \ + } \ + rs_c = 1; \ + cs_c = ldc; \ + rs_c_ref = 1; \ + cs_c_ref = 
ldc_ref; \ + } \ \ - for ( dim_t i = 0; i < m; ++i ) \ - { \ - for ( dim_t j = 0; j < n; ++j ) \ - { \ - ACCUM_type temp_accum = 0; \ - C_type out_temp_accum = 0; \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ACCUM_type temp_accum = 0; \ + C_type out_temp_accum = 0; \ \ - temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_accum_,BLAS_SFX) \ - (a,b,c_ref,temp_accum,alpha,beta,rs_a,rs_b,cs_a,cs_b,rs_c_ref,cs_c_ref,i,j,k); \ + temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_accum_,BLAS_SFX) \ + (a,b,c_ref,temp_accum,alpha,beta,rs_a,rs_b,cs_a,cs_b,rs_c_ref,cs_c_ref,i,j,k); \ \ - if ( post_op != NULL ) \ - { \ - dim_t ele_i = 0; \ - for ( dim_t op_id = 0; op_id < post_op->seq_length; ++op_id ) \ - { \ - if ( post_op->seq_vector[op_id] == BIAS ) \ - { \ - temp_accum += ( *( ( ACCUM_type* )post_op->bias.bias + j ) ); \ - } \ - else if ( post_op->seq_vector[op_id] == ELTWISE ) \ - { \ - if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ - PRELU ) /* PReLU*/ \ - { \ - temp_accum = ( temp_accum > 0 ) ? \ - temp_accum : \ - ( temp_accum * \ - *( ( ACCUM_type* ) ( post_op->eltwise + ele_i )->algo.alpha ) ); \ - ele_i += 1; \ - } \ - else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ - GELU_TANH ) /* TANH GeLU*/ \ - { \ - temp_accum = GEN_FUNC_NAME(GELU_TANH_post_op_,BLAS_SFX) (temp_accum);\ - ele_i += 1; \ - } \ - else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ - GELU_ERF ) /* ERF GeLU*/ \ - { \ - temp_accum = GEN_FUNC_NAME(GELU_ERF_post_op_,BLAS_SFX) (temp_accum);\ - ele_i += 1; \ - } \ - else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ - RELU ) /* ReLU*/ \ - { \ - temp_accum = ( temp_accum > 0 ) ? temp_accum : 0 ; \ - ele_i += 1; \ - } \ - else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ - CLIP ) /* CLIP*/ \ - { \ - temp_accum = \ - min \ - ( \ - max \ - ( \ - temp_accum, \ - *( ( ACCUM_type* ) \ - ( post_op->eltwise + ele_i )->algo.alpha ) \ - ), \ - *( ( ACCUM_type* ) \ - ( post_op->eltwise + ele_i )->algo.beta) \ - ); \ - ele_i += 1; \ - } \ - else \ - {} \ - } \ - else if ( post_op->seq_vector[op_id] == SCALE ) \ - { \ - temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \ - (temp_accum, post_op, j); \ - } \ - else \ - {} \ - } \ - } \ - /* Need to convert to downscaled type if required.*/ \ - mat_mul_get_output_type_val ## ACCUM_type ## C_type \ - ( \ - &out_temp_accum, &temp_accum \ - ); \ + if ( post_op != NULL ) \ + { \ + dim_t ele_i = 0; \ + for ( dim_t op_id = 0; op_id < post_op->seq_length; ++op_id ) \ + { \ + if ( post_op->seq_vector[op_id] == BIAS ) \ + { \ + temp_accum += ( *( ( ACCUM_type* )post_op->bias.bias + j ) ); \ + } \ + else if ( post_op->seq_vector[op_id] == ELTWISE ) \ + { \ + if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + PRELU ) /* PReLU*/ \ + { \ + temp_accum = ( temp_accum > 0 ) ? \ + temp_accum : \ + ( temp_accum * \ + *( ( ACCUM_type* ) ( post_op->eltwise + ele_i )->algo.alpha ) ); \ + ele_i += 1; \ + } \ + else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + GELU_TANH ) /* TANH GeLU*/ \ + { \ + temp_accum = GEN_FUNC_NAME(GELU_TANH_post_op_,BLAS_SFX) (temp_accum);\ + ele_i += 1; \ + } \ + else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + GELU_ERF ) /* ERF GeLU*/ \ + { \ + temp_accum = GEN_FUNC_NAME(GELU_ERF_post_op_,BLAS_SFX) (temp_accum);\ + ele_i += 1; \ + } \ + else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + RELU ) /* ReLU*/ \ + { \ + temp_accum = ( temp_accum > 0 ) ? 
temp_accum : 0 ; \ + ele_i += 1; \ + } \ + else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + CLIP ) /* CLIP*/ \ + { \ + temp_accum = \ + min \ + ( \ + max \ + ( \ + temp_accum, \ + *( ( ACCUM_type* ) \ + ( post_op->eltwise + ele_i )->algo.alpha ) \ + ), \ + *( ( ACCUM_type* ) \ + ( post_op->eltwise + ele_i )->algo.beta) \ + ); \ + ele_i += 1; \ + } \ + else \ + {} \ + } \ + else if ( post_op->seq_vector[op_id] == SCALE ) \ + { \ + temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \ + (temp_accum, post_op, j); \ + } \ + else \ + {} \ + } \ + } \ + /* Need to convert to downscaled type if required.*/ \ + mat_mul_get_output_type_val ## ACCUM_type ## C_type \ + ( \ + &out_temp_accum, &temp_accum \ + ); \ \ - if ( *( c + ( rs_c * i ) + ( cs_c * j ) ) != out_temp_accum ) \ - { \ - if ( fout ) \ - { \ - fprintf( fout, "%s Failure input m: %ld, n: %ld, k: %ld," \ - " lda: %ld, ldb: %ld, ldc: %ld\n", \ - XSTR(BLAS_SFX), m, n, k, lda, ldb, ldc ); \ - fflush( fout ); \ - } \ - printf("failure, m: %ld, n: %ld, k: %ld\n", i, j, k); \ - goto cleanup_acc; \ - } \ - } \ - } \ + if ( *( c + ( rs_c * i ) + ( cs_c * j ) ) != out_temp_accum ) \ + { \ + float comp_float, ref_float; \ + GEN_FUNC_NAME(C_type,_to_float)(*( c + ( rs_c * i ) + ( cs_c * j ) ), &comp_float); \ + GEN_FUNC_NAME(C_type,_to_float)(out_temp_accum, &ref_float); \ + if ( fout ) \ + { \ + fprintf( fout, "%s Failure input m: %ld, n: %ld, k: %ld," \ + " lda: %ld, ldb: %ld, ldc: %ld, computed:%f, ref:%f, diff:%f\n", \ + XSTR(BLAS_SFX), m, n, k, lda, ldb, ldc, comp_float, \ + ref_float, comp_float - ref_float); \ + fflush( fout ); \ + } \ + printf("failure, m: %ld, n: %ld, k: %ld, computed:%f, ref:%f, diff:%f\n", i, j, k, \ + comp_float, ref_float, comp_float-ref_float); \ + goto cleanup_acc; \ + } \ + } \ + } \ cleanup_acc: \ - return; \ + return; \ } \ GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,float,u8s8s16os16,u8s8s16os8) GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int16_t,float,u8s8s16os8,u8s8s16os8) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,uint8_t,int16_t,float,u8s8s16ou8,u8s8s16ou8) GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,float,u8s8s32os32,u8s8s32os8) GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,float,u8s8s32os8,u8s8s32os8) GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,float,float,float,bf16bf16f32of32,bf16bf16f32obf16) @@ -733,8 +884,7 @@ GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,float,s8s8s32os8,s8s8s GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int16_t,int16_t,float,s8s8s16os16,s8s8s16os8) GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int16_t,float,s8s8s16os8,s8s8s16os8) -/* Only supports bias followed by RELU and vice versa for now.*/ \ -#define GEN_MAT_MUL_POST_OPS_CREATOR(C_type,DSCALE_type,BLAS_SFX) \ +#define GEN_MAT_MUL_POST_OPS_CREATOR(C_DSCALE_type,C_type,DSCALE_type,BLAS_SFX) \ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ ( \ dim_t m, \ @@ -742,280 +892,298 @@ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ char* post_ops_str \ ) \ { \ - aocl_post_op* post_ops = NULL; \ - post_ops = ( aocl_post_op* ) malloc( sizeof( aocl_post_op ) ); \ + aocl_post_op* post_ops = NULL; \ + post_ops = ( aocl_post_op* ) malloc( sizeof( aocl_post_op ) ); \ \ - if ( ( post_ops == NULL ) && ( global_dscale_out == 'n' ) ) \ - { \ - return NULL; \ - } \ + if ( ( post_ops == NULL ) && ( global_dscale_out == 'n' ) ) \ + { \ + return NULL; \ + } \ \ - /* Only supporting 5 post ops at max for now.*/ \ - dim_t 
max_post_ops_seq_length = 5; \ - post_ops->seq_vector = ( AOCL_POST_OP_TYPE* ) \ - malloc \ - ( \ - max_post_ops_seq_length * \ - sizeof( AOCL_POST_OP_TYPE ) \ - ); \ + /* Only supporting 5 post ops at max for now.*/ \ + dim_t max_post_ops_seq_length = 5; \ + post_ops->seq_vector = ( AOCL_POST_OP_TYPE* ) \ + malloc \ + ( \ + max_post_ops_seq_length * \ + sizeof( AOCL_POST_OP_TYPE ) \ + ); \ \ - if ( post_ops->seq_vector == NULL ) \ - { \ - free( post_ops ); \ - return NULL; \ - } \ + if ( post_ops->seq_vector == NULL ) \ + { \ + free( post_ops ); \ + return NULL; \ + } \ \ - /* Parse post ops list.*/ \ - dim_t cur_op_index = 0; \ - /* Ensure the buffers that use NULL check in deinit code is properly set to NULL.*/ \ - post_ops->eltwise = NULL; \ - post_ops->bias.bias = NULL; \ - post_ops->sum.scale_factor = NULL; \ - if ( post_ops_str != NULL ) \ - { \ - char* ops_tok = strtok(post_ops_str, ", " ); \ - bool is_relu = FALSE; \ - bool is_param_relu = FALSE; \ - bool is_gelu_tanh = FALSE; \ - bool is_gelu_erf = FALSE; \ - bool is_clip = FALSE; \ - dim_t activator_idx = 0; \ - dim_t clip_idx = 0; \ + /* Parse post ops list.*/ \ + dim_t cur_op_index = 0; \ + /* Ensure the buffers that use NULL check in deinit code is properly set to NULL.*/ \ + post_ops->eltwise = NULL; \ + post_ops->bias.bias = NULL; \ + post_ops->sum.scale_factor = NULL; \ + post_ops->sum.buff = NULL; \ + post_ops->sum.zero_point = NULL; \ + if ( post_ops_str != NULL ) \ + { \ + char* ops_tok = strtok(post_ops_str, ", " ); \ + bool is_relu = FALSE; \ + bool is_param_relu = FALSE; \ + bool is_gelu_tanh = FALSE; \ + bool is_gelu_erf = FALSE; \ + bool is_clip = FALSE; \ + dim_t activator_idx = 0; \ + dim_t clip_idx = 0; \ \ - /* Ensure only one activator is used as an eltwise post-op.*/ \ - bool is_activator_set = FALSE; \ - num_eltwise = 0; \ - while ( ops_tok ) \ - { \ - if ( strcmp( ops_tok, "bias") == 0 ) \ - { \ - post_ops->seq_vector[cur_op_index] = BIAS; \ - cur_op_index++; \ - } \ - else if ( ( strcmp( ops_tok, "relu") == 0 ) && \ - ( is_activator_set == FALSE ) ) \ - { \ - post_ops->seq_vector[cur_op_index] = ELTWISE; \ - is_relu = TRUE; \ - is_activator_set = TRUE; \ - num_eltwise += 1; \ - activator_idx = cur_op_index; \ - cur_op_index++; \ - } \ - else if ( ( strcmp( ops_tok, "prelu") == 0 ) && \ - ( is_activator_set == FALSE ) ) \ - { \ - post_ops->seq_vector[cur_op_index] = ELTWISE; \ - is_param_relu = TRUE; \ - is_activator_set = TRUE; \ - num_eltwise += 1; \ - activator_idx = cur_op_index; \ - cur_op_index++; \ - } \ - else if ( ( strcmp( ops_tok, "gelu_tanh") == 0 ) && \ - ( is_activator_set == FALSE ) ) \ - { \ - post_ops->seq_vector[cur_op_index] = ELTWISE; \ - is_gelu_tanh = TRUE; \ - is_activator_set = TRUE; \ - num_eltwise += 1; \ - activator_idx = cur_op_index; \ - cur_op_index++; \ - } \ - else if ( ( strcmp( ops_tok, "gelu_erf") == 0 ) && \ - ( is_activator_set == FALSE ) ) \ - { \ - post_ops->seq_vector[cur_op_index] = ELTWISE; \ - is_gelu_erf = TRUE; \ - is_activator_set = TRUE; \ - num_eltwise += 1; \ - activator_idx = cur_op_index; \ - cur_op_index++; \ - } \ - else if ( strcmp( ops_tok, "clip") == 0 ) \ - { \ - post_ops->seq_vector[cur_op_index] = ELTWISE; \ - is_clip = TRUE; \ - num_eltwise += 1; \ - clip_idx = cur_op_index; \ - cur_op_index++; \ - } \ - ops_tok = strtok( NULL, ", " ); \ - } \ + /* Ensure only one activator is used as an eltwise post-op.*/ \ + bool is_activator_set = FALSE; \ + num_eltwise = 0; \ + while ( ops_tok ) \ + { \ + if ( strcmp( ops_tok, "bias" ) == 0 ) \ + { \ + 
post_ops->seq_vector[cur_op_index] = BIAS; \ + cur_op_index++; \ + } \ + else if ( ( strcmp( ops_tok, "relu" ) == 0 ) && \ + ( is_activator_set == FALSE ) ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_relu = TRUE; \ + is_activator_set = TRUE; \ + num_eltwise += 1; \ + activator_idx = cur_op_index; \ + cur_op_index++; \ + } \ + else if ( ( strcmp( ops_tok, "prelu" ) == 0 ) && \ + ( is_activator_set == FALSE ) ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_param_relu = TRUE; \ + is_activator_set = TRUE; \ + num_eltwise += 1; \ + activator_idx = cur_op_index; \ + cur_op_index++; \ + } \ + else if ( ( strcmp( ops_tok, "gelu_tanh" ) == 0 ) && \ + ( is_activator_set == FALSE ) ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_gelu_tanh = TRUE; \ + is_activator_set = TRUE; \ + num_eltwise += 1; \ + activator_idx = cur_op_index; \ + cur_op_index++; \ + } \ + else if ( ( strcmp( ops_tok, "gelu_erf" ) == 0 ) && \ + ( is_activator_set == FALSE ) ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_gelu_erf = TRUE; \ + is_activator_set = TRUE; \ + num_eltwise += 1; \ + activator_idx = cur_op_index; \ + cur_op_index++; \ + } \ + else if ( strcmp( ops_tok, "clip" ) == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_clip = TRUE; \ + num_eltwise += 1; \ + clip_idx = cur_op_index; \ + cur_op_index++; \ + } \ + ops_tok = strtok( NULL, ", " ); \ + } \ \ - /* Allocate bias buffer, return early if alloc fails.*/ \ - post_ops->bias.bias = malloc( n * sizeof( C_type ) ); \ - if ( post_ops->bias.bias == NULL ) \ - { \ - free( post_ops->seq_vector ); \ - free( post_ops ); \ - return NULL; \ - } \ - GEN_FUNC_NAME(fill_array_post_ops_,C_type)( post_ops->bias.bias, n ); \ + /* Allocate bias buffer, return early if alloc fails.*/ \ + post_ops->bias.bias = malloc( n * sizeof( C_type ) ); \ + if ( post_ops->bias.bias == NULL ) \ + { \ + free( post_ops->seq_vector ); \ + free( post_ops ); \ + return NULL; \ + } \ + GEN_FUNC_NAME(fill_array_post_ops_,C_type)( post_ops->bias.bias, n ); \ \ - post_ops->eltwise = malloc( num_eltwise * sizeof( aocl_post_op_eltwise ) ); \ - if ( post_ops->eltwise == NULL ) \ - { \ - free( post_ops->bias.bias ); \ - free( post_ops->seq_vector ); \ - free( post_ops ); \ - return NULL; \ - } \ + post_ops->eltwise = malloc( num_eltwise * sizeof( aocl_post_op_eltwise ) ); \ + if ( post_ops->eltwise == NULL ) \ + { \ + free( post_ops->bias.bias ); \ + free( post_ops->seq_vector ); \ + free( post_ops ); \ + return NULL; \ + } \ \ - if ( num_eltwise > 0 ) \ - { \ - if ( num_eltwise > 1 ) \ - { \ - if ( activator_idx < clip_idx ) \ - { \ - activator_idx = 0; \ - clip_idx = 1; \ - } \ - else \ - { \ - activator_idx = 1; \ - clip_idx = 0; \ - } \ - } \ - else \ - { \ - activator_idx = 0; \ - clip_idx = 0; \ - } \ - } \ - /* Only one of relu,prelu,gelu_tanh,gelu_erf allowed as an activator.*/ \ - if ( is_relu == TRUE ) \ - { \ - ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ - ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ - ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ - ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ - ( post_ops->eltwise + activator_idx )->algo.algo_type = RELU; \ - } \ - else if ( is_param_relu == TRUE ) \ - { \ - ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ - ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ - ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ - ( post_ops->eltwise + activator_idx 
)->algo.alpha = malloc( sizeof( C_type ) ); \ - *( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )6; \ - ( post_ops->eltwise + activator_idx )->algo.algo_type = PRELU; \ - } \ - else if ( is_gelu_tanh == TRUE ) \ - { \ - ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ - ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ - ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ - ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ - ( post_ops->eltwise + activator_idx )->algo.algo_type = GELU_TANH; \ - } \ - else if ( is_gelu_erf == TRUE ) \ - { \ - ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ - ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ - ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ - ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ - ( post_ops->eltwise + activator_idx )->algo.algo_type = GELU_ERF; \ - } \ - if ( is_clip == TRUE ) \ - { \ - ( post_ops->eltwise + clip_idx )->is_power_of_2 = FALSE; \ - ( post_ops->eltwise + clip_idx )->scale_factor = NULL; \ - ( post_ops->eltwise + clip_idx )->algo.alpha = malloc( sizeof( C_type ) ); \ - ( post_ops->eltwise + clip_idx )->algo.beta = malloc( sizeof( C_type ) ); \ - *( ( C_type* ) ( post_ops->eltwise + clip_idx )->algo.alpha ) = ( C_type ) ( -64 ); \ - *( ( C_type* ) ( post_ops->eltwise + clip_idx )->algo.beta ) = ( C_type ) ( 3 ); \ - ( post_ops->eltwise + clip_idx )->algo.algo_type = CLIP; \ - } \ - } \ + if ( num_eltwise > 0 ) \ + { \ + if ( num_eltwise > 1 ) \ + { \ + if ( activator_idx < clip_idx ) \ + { \ + activator_idx = 0; \ + clip_idx = 1; \ + } \ + else \ + { \ + activator_idx = 1; \ + clip_idx = 0; \ + } \ + } \ + else \ + { \ + activator_idx = 0; \ + clip_idx = 0; \ + } \ + } \ + /* Only one of relu,prelu,gelu_tanh,gelu_erf allowed as an activator.*/ \ + if ( is_relu == TRUE ) \ + { \ + ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.algo_type = RELU; \ + } \ + else if ( is_param_relu == TRUE ) \ + { \ + ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \ + *( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )6; \ + ( post_ops->eltwise + activator_idx )->algo.algo_type = PRELU; \ + } \ + else if ( is_gelu_tanh == TRUE ) \ + { \ + ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.algo_type = GELU_TANH; \ + } \ + else if ( is_gelu_erf == TRUE ) \ + { \ + ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.algo_type = GELU_ERF; \ + } \ + if ( is_clip == TRUE ) \ + { \ + ( post_ops->eltwise + clip_idx )->is_power_of_2 = FALSE; \ + ( 
post_ops->eltwise + clip_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + clip_idx )->algo.alpha = malloc( sizeof( C_type ) ); \ + ( post_ops->eltwise + clip_idx )->algo.beta = malloc( sizeof( C_type ) ); \ + *( ( C_type* ) ( post_ops->eltwise + clip_idx )->algo.alpha ) = ( C_type ) ( -64 ); \ + *( ( C_type* ) ( post_ops->eltwise + clip_idx )->algo.beta ) = ( C_type ) ( 23 ); \ + ( post_ops->eltwise + clip_idx )->algo.algo_type = CLIP; \ + } \ + } \ \ - if ( global_dscale_out == 'y' ) \ - { \ - post_ops->seq_vector[cur_op_index] = SCALE; \ - cur_op_index++; \ + if ( global_dscale_out == 'y' ) \ + { \ + post_ops->seq_vector[cur_op_index] = SCALE; \ + cur_op_index++; \ \ - post_ops->sum.is_power_of_2 = FALSE; \ - post_ops->sum.scale_factor = NULL; \ - post_ops->sum.buff = NULL; \ - post_ops->sum.zero_point = NULL; \ - if ( global_dscale_out == 'y' ) \ - { \ - /* Allocate scale buffer, return early if alloc fails.*/ \ - post_ops->sum.scale_factor = malloc( n * sizeof( DSCALE_type ) ); \ - if ( post_ops->sum.scale_factor == NULL ) \ - { \ - free ( post_ops->eltwise ); \ - free ( post_ops->bias.bias ); \ - free( post_ops->seq_vector ); \ - free( post_ops ); \ - return NULL; \ - } \ - /* Fill scale factor.*/ \ - DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )post_ops->sum.scale_factor; \ - for ( dim_t i = 0; i < n; ++i ) \ - { \ - temp_dscale_ptr[i] = ( ( DSCALE_type )1 )/ ( ( DSCALE_type )1000 ); \ - } \ - } \ - } \ + post_ops->sum.is_power_of_2 = FALSE; \ + if ( global_dscale_out == 'y' ) \ + { \ + /* Allocate scale buffer, return early if alloc fails.*/ \ + post_ops->sum.scale_factor = malloc( n * sizeof( DSCALE_type ) ); \ + post_ops->sum.zero_point = malloc( n * sizeof( C_DSCALE_type ) ); \ + if ( ( post_ops->sum.scale_factor == NULL ) || \ + ( post_ops->sum.zero_point == NULL ) ) \ + { \ + free ( post_ops->eltwise ); \ + free ( post_ops->bias.bias ); \ + free( post_ops->seq_vector ); \ + if ( post_ops->sum.zero_point != NULL ) \ + { \ + free( post_ops->sum.zero_point ); \ + } \ + if ( post_ops->sum.scale_factor != NULL ) \ + { \ + free( post_ops->sum.scale_factor ); \ + } \ + free( post_ops ); \ + return NULL; \ + } \ + /* Fill scale factor and zero points.*/ \ + DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )post_ops->sum.scale_factor; \ + C_DSCALE_type* temp_dzero_point_ptr = ( C_DSCALE_type* )post_ops->sum.zero_point; \ + for ( dim_t i = 0; i < n; ++i ) \ + { \ + temp_dscale_ptr[i] = ( ( DSCALE_type )1 )/ ( ( DSCALE_type )1000 ); \ + temp_dzero_point_ptr[i] = (C_DSCALE_type)( i % 126 ); \ + } \ + } \ + } \ \ - post_ops->seq_length = cur_op_index; \ + post_ops->seq_length = cur_op_index; \ \ - return post_ops; \ + return post_ops; \ } \ -GEN_MAT_MUL_POST_OPS_CREATOR(int16_t,float,u8s8s16os16) -GEN_MAT_MUL_POST_OPS_CREATOR(int32_t,float,u8s8s32os32) -GEN_MAT_MUL_POST_OPS_CREATOR(float,float,bf16bf16f32of32) -GEN_MAT_MUL_POST_OPS_CREATOR(float,float,f32f32f32of32) -GEN_MAT_MUL_POST_OPS_CREATOR(int32_t,float,s8s8s32os32) -GEN_MAT_MUL_POST_OPS_CREATOR(int16_t,float,s8s8s16os16) +GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,u8s8s16os16) +GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,u8s8s32os32) +GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bf16bf16f32of32) +GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,f32f32f32of32) +GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,s8s8s32os32) +GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,s8s8s16os16) void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops ) { - if ( post_ops == NULL ) - { - return; - } - - if ( 
post_ops->eltwise != NULL ) - { - for ( dim_t i = 0; i < num_eltwise; ++i ) - { - if ( ( post_ops->eltwise + i )->algo.alpha != NULL ) - { - free( ( post_ops->eltwise + i )->algo.alpha ); - } - if ( ( post_ops->eltwise + i )->algo.beta != NULL ) - { - free( ( post_ops->eltwise + i )->algo.beta ); - } - } - free( post_ops->eltwise ); - } - if ( post_ops->sum.scale_factor != NULL ) - { - free( post_ops->sum.scale_factor ); - } - if ( post_ops->bias.bias != NULL ) - { - free( post_ops->bias.bias ); - } - if( post_ops->seq_vector != NULL ) - { - free( post_ops->seq_vector ); - } - - free( post_ops ); + if ( post_ops == NULL ) + { + return; + } + + if ( post_ops->eltwise != NULL ) + { + for ( dim_t i = 0; i < num_eltwise; ++i ) + { + if ( ( post_ops->eltwise + i )->algo.alpha != NULL ) + { + free( ( post_ops->eltwise + i )->algo.alpha ); + } + if ( ( post_ops->eltwise + i )->algo.beta != NULL ) + { + free( ( post_ops->eltwise + i )->algo.beta ); + } + } + free( post_ops->eltwise ); + } + if ( post_ops->sum.scale_factor != NULL ) + { + free( post_ops->sum.scale_factor ); + } + if ( post_ops->sum.zero_point != NULL ) + { + free( post_ops->sum.zero_point ); + } + if ( post_ops->bias.bias != NULL ) + { + free( post_ops->bias.bias ); + } + if( post_ops->seq_vector != NULL ) + { + free( post_ops->seq_vector ); + } + + free( post_ops ); } -#define GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type,B_type,C_type,BLAS_SFX,REORDER_SFX) \ +#define GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type, B_type, C_type, Sum_type, BLAS_SFX, REORDER_SFX) \ void mat_mul_bench_main_ ## BLAS_SFX \ ( \ FILE* fin, \ FILE* fout, \ char stor_order, \ - char op_t, \ + char transa, \ + char transb, \ + char op_a, \ + char op_b, \ int32_t m, \ int32_t n, \ int32_t k, \ @@ -1025,591 +1193,500 @@ void mat_mul_bench_main_ ## BLAS_SFX \ char* post_ops_str \ ) \ { \ - if ( ( op_t != 'p' ) && ( op_t != 'P' ) && ( op_t != 'r' ) && ( op_t != 'R' ) ) \ - { \ - printf("The op_t ( 2nd arg in input.txt) is not valid\n"); \ - return; \ - } \ - \ - int32_t n_repeats = bli_max( 30, bli_min(( 3e10 / ( ( int64_t )m * n * k )), 100 )); \ - if ( global_n_repeat > 0 ) \ - { \ - n_repeats = global_n_repeat; \ - } \ - \ - /* Get 64 byte aligned memory.*/ \ - A_type* a = ( A_type* ) bli_malloc_user( sizeof( A_type ) * m * k ); \ - \ - B_type* b = ( B_type* ) bli_malloc_user( sizeof( B_type ) * n * k ); \ - \ - C_type* c = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ - memset( ( void* ) c, 0, sizeof( C_type ) * m * n ); \ - \ - C_type* c_ref = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ - memset( ( void* ) c_ref, 0, sizeof( C_type ) * m * n ); \ - \ - GEN_FUNC_NAME(fill_array_,A_type)( a, ( m * k ) ); \ - GEN_FUNC_NAME(fill_array_,B_type)( b, ( k * n ) ); \ - \ - if ( bench_mode == 'a' ) \ - { \ - GEN_FUNC_NAME(fill_array_,C_type)( c, ( m * n ) ); \ - GEN_FUNC_NAME(fill_array_,C_type)( c_ref, ( m * n ) ); \ - } \ - \ - C_type alpha; \ - C_type beta; \ - if ( bench_mode == 'p' ) \ - { \ - alpha = 1; \ - beta = 0; \ - } \ - else if ( bench_mode == 'a' ) \ - { \ - alpha = 2; \ - beta = 9; \ - } \ - \ - aocl_post_op* post_op = NULL; \ - if ( ( post_ops_str != NULL ) || ( global_dscale_out == 'y' ) ) \ - { \ - post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,REORDER_SFX)( m, n, post_ops_str ); \ - if ( post_op == NULL ) \ - { \ - printf(" post op struct allocation failure, returning.\n"); \ - return; \ - } \ - } \ - \ - if ( ( op_t == 'p' ) || ( op_t == 'P' ) ) \ - { \ - /* No reordering of B.*/ \ - GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ - 
( \ - stor_order, op_t, n_repeats, m, n, k, \ - alpha, \ - a, stride_a, \ - b, stride_b, \ - beta, \ - c, stride_c, \ - post_op \ - ); \ - } \ - else if ( ( op_t == 'r' ) || ( op_t == 'R' ) ) \ - { \ - /* Reorder B.*/ \ - siz_t b_reorder_buf_siz_req = \ - GEN_FUNC_NAME(aocl_get_reorder_buf_size_,REORDER_SFX)( 'B', k, n ); \ - \ - B_type* b_reorder = ( B_type* ) bli_malloc_user( b_reorder_buf_siz_req ); \ - GEN_FUNC_NAME(aocl_reorder_,REORDER_SFX)( 'B', b, b_reorder, k, n, stride_b ); \ - \ - GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ - ( \ - stor_order, op_t, n_repeats, m, n, k, \ - alpha, \ - a, stride_a, \ - b_reorder, stride_b, \ - beta, \ - c, stride_c, \ - post_op \ - ); \ - \ - bli_free_user( b_reorder ); \ - } \ - \ - if ( bench_mode == 'a' ) \ - { \ - printf("Running accuracy check.\n"); \ - GEN_FUNC_NAME(mat_mul_accuracy_check_driver_,BLAS_SFX) \ - ( \ - fout, stor_order, m, n, k, \ - alpha, \ - a, stride_a, \ - b, stride_b, \ - beta, \ - c, stride_c, \ - c_ref, stride_c, \ - post_op \ - ); \ - } \ - \ - lpgemm_destroy_post_ops_struct( post_op ); \ - \ - if ( a != NULL ) \ - { \ - bli_free_user( a ); \ - } \ - if ( b != NULL ) \ - { \ - bli_free_user( b ); \ - } \ - if ( c != NULL ) \ - { \ - bli_free_user( c ); \ - } \ - if ( c_ref != NULL ) \ - { \ - bli_free_user( c_ref ); \ - } \ -} \ - -GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int16_t,u8s8s16os16,u8s8s16os16) -GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,u8s8s16os8,u8s8s16os16) -GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32,u8s8s32os32) -GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,u8s8s32os8,u8s8s32os32) -GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,f32f32f32of32,f32f32f32of32) -GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int32_t,s8s8s32os32,s8s8s32os32) -GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,s8s8s32os8,s8s8s32os32) -GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int16_t,s8s8s16os16,s8s8s16os16) -GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,s8s8s16os8,s8s8s16os16) - -#define GEN_MAT_MUL_BENCH_MAIN_FUNC_BF16(C_type, BLAS_SFX) \ -void mat_mul_bench_main_ ## BLAS_SFX \ - ( \ - FILE* fin, \ - FILE* fout, \ - char stor_order, \ - char op_t, \ - int32_t m, \ - int32_t n, \ - int32_t k, \ - int32_t stride_a, \ - int32_t stride_b, \ - int32_t stride_c, \ - char* post_ops_str \ - ) \ -{ \ - if ( ( op_t != 'p' ) && ( op_t != 'P' ) && ( op_t != 'r' ) && ( op_t != 'R' ) ) \ - { \ - printf("The op_t ( 2nd arg in input.txt) is not valid\n");\ - return; \ - } \ - \ - int32_t n_repeats = bli_max( 30, bli_min(( 3e10 / ( ( int64_t )m * n * k )), 1000 )); \ - if ( global_n_repeat > 0 ) \ - { \ - n_repeats = global_n_repeat; \ - } \ + int32_t n_repeats = bli_max( 30, bli_min(( 3e10 / ( ( int64_t )m * n * k )), 1000 )); \ + if ( global_n_repeat > 0 ) \ + { \ + n_repeats = global_n_repeat; \ + } \ \ - /* Get 64 byte aligned memory.*/ \ - bfloat16* a = ( bfloat16* ) bli_malloc_user( sizeof( bfloat16 ) * m * k ); \ - float *a_float = bli_malloc_user( m * k * sizeof( float )); \ - for ( int32_t i = 0; i < m*k; ++i ) \ + int32_t size_A = 0; \ + int32_t size_B = 0; \ + int32_t size_C = 0; \ + if( ( stor_order == 'r' ) || ( stor_order == 'R' ) ) \ + { \ + size_A = ( ( transa == 'n' ) || ( transa == 'N' ) ) ? m * stride_a : k * stride_a; \ + size_B = ( ( transb == 'n' ) || ( transb == 'N' ) ) ? k * stride_b : n * stride_b; \ + size_C = m * stride_c; \ + } \ + else \ { \ - a_float[i] = ( float ) ( i % 5 ); \ + size_A = ( ( transa == 'n' ) || ( transa == 'N' ) ) ? 
k * stride_a : m * stride_a; \ + size_B = ( ( transb == 'n' ) || ( transb == 'N' ) ) ? n * stride_b : k * stride_b; \ + size_C = n * stride_c; \ } \ - convert_float_arr_to_bf16( a_float, a, m * k ); \ + A_type* a = ( A_type* ) lpgemm_malloc( sizeof( A_type ) * size_A ); \ + GEN_FUNC_NAME(fill_array_,A_type)(a, size_A ); \ \ - bfloat16* b = ( bfloat16* ) bli_malloc_user( sizeof( bfloat16 ) * n * k ); \ - float *b_float = bli_malloc_user( k * n * sizeof( float )); \ - for ( int32_t i = 0; i < k*n; ++i ) \ - { \ - b_float[i] = ( float ) ( i % 5 );\ - } \ - convert_float_arr_to_bf16( b_float, b, k * n ); \ + B_type* b = ( B_type* ) lpgemm_malloc( sizeof( B_type ) * size_B ); \ + GEN_FUNC_NAME(fill_array_,B_type)(b, size_B ); \ \ - C_type* c = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ - memset( ( void* ) c, 0, sizeof( C_type ) * m * n ); \ + C_type* c = ( C_type* ) lpgemm_malloc( sizeof( C_type ) * size_C ); \ \ - C_type* c_ref = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ - memset( ( void* ) c_ref, 0, sizeof( C_type ) * m * n ); \ + C_type* c_ref = ( C_type* ) lpgemm_malloc( sizeof( C_type ) * size_C ); \ \ - if ( bench_mode == 'a' ) \ - { \ - GEN_FUNC_NAME(fill_array_,C_type)( c, ( m * n ) ); \ - GEN_FUNC_NAME(fill_array_,C_type)( c_ref, ( m * n ) ); \ - } \ + if ( bench_mode == 'a' ) \ + { \ + GEN_FUNC_NAME(fill_array_,C_type)( c, ( size_C ) ); \ + GEN_FUNC_NAME(fill_array_,C_type)( c_ref, ( size_C ) ); \ + } \ + else \ + { \ + memset( ( void* ) c, 0, sizeof( C_type ) * size_C ); \ + memset( ( void* ) c_ref, 0, sizeof( C_type ) * size_C ); \ + } \ \ - float alpha; \ - float beta; \ - if ( bench_mode == 'p' ) \ - { \ - alpha = 1; \ - beta = 0; \ - } \ - else if ( bench_mode == 'a' ) \ - { \ - alpha = 2; \ - beta = 9; \ - } \ + Sum_type alpha = 0; \ + Sum_type beta = 0; \ + if ( bench_mode == 'p' ) \ + { \ + alpha = 2; \ + beta = 9; \ + } \ + else if ( bench_mode == 'a' ) \ + { \ + n_repeats = 1; \ + alpha = 2; \ + beta = 9; \ + } \ \ - aocl_post_op* post_op = NULL; \ - if ( ( post_ops_str != NULL ) || ( global_dscale_out == 'y' ) ) \ - { \ - post_op = lpgemm_create_post_ops_struct_bf16bf16f32of32( m, n, post_ops_str ); \ - if ( post_op == NULL ) \ - { \ - printf(" post op struct allocation failure, returning.\n"); \ - return; \ - } \ - } \ + aocl_post_op* post_op = NULL; \ + if ( ( post_ops_str != NULL ) || ( global_dscale_out == 'y' ) ) \ + { \ + post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,REORDER_SFX)( m, n, post_ops_str ); \ + if ( post_op == NULL ) \ + { \ + printf(" post op struct allocation failure, returning.\n"); \ + return; \ + } \ + } \ \ - if ( ( op_t == 'p' ) || ( op_t == 'P' ) ) \ - { \ - /* No reordering of B.*/ \ - GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ - ( \ - stor_order, op_t, n_repeats, m, n, k, \ - alpha, \ - a, stride_a, \ - b, stride_b, \ - beta, \ - c, stride_c, \ - post_op \ - ); \ - } \ - else if ( ( op_t == 'r' ) || ( op_t == 'R' ) ) \ - { \ - /* Reorder B.*/ \ - siz_t b_reorder_buf_siz_req = \ - aocl_get_reorder_buf_size_bf16bf16f32of32( 'B', k, n ); \ + if ( ( op_b == 'p' ) || ( op_b == 'P' ) || ( op_b == 'n' ) || ( op_b == 'N' ) ) \ + { \ + /* No reordering of B.*/ \ + GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ + ( \ + stor_order, transa, transb, op_a, op_b, n_repeats, m, n, k, \ + alpha, \ + a, stride_a, \ + b, stride_b, \ + beta, \ + c, stride_c, \ + post_op \ + ); \ + } \ + else if ( ( op_b == 'r' ) || ( op_b == 'R' ) ) \ + { \ + /* Reorder B.*/ \ + siz_t b_reorder_buf_siz_req = \ + 
GEN_FUNC_NAME(aocl_get_reorder_buf_size_,REORDER_SFX)( stor_order, transb, 'B', k, n ); \ \ - bfloat16* b_reorder = ( bfloat16* ) bli_malloc_user( b_reorder_buf_siz_req ); \ - aocl_reorder_bf16bf16f32of32( 'B', b, b_reorder, k, n, stride_b ); \ + B_type* b_reorder = ( B_type* ) lpgemm_malloc( b_reorder_buf_siz_req ); \ + GEN_FUNC_NAME(aocl_reorder_,REORDER_SFX)( stor_order, transb, 'B', b, b_reorder, k, n, stride_b ); \ \ - GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ - ( \ - stor_order, op_t, n_repeats, m, n, k, \ - alpha, \ - a, stride_a, \ - b_reorder, stride_b, \ - beta, \ - c, stride_c, \ - post_op \ - ); \ - } \ + GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ + ( \ + stor_order, transa, transb, op_a, op_b, n_repeats, m, n, k, \ + alpha, \ + a, stride_a, \ + b_reorder, stride_b, \ + beta, \ + c, stride_c, \ + post_op \ + ); \ + } \ \ - if ( bench_mode == 'a' ) \ - { \ - printf(" Running accuracy check.\n"); \ - GEN_FUNC_NAME(mat_mul_accuracy_check_driver_,BLAS_SFX) \ - ( \ - fout, stor_order, m, n, k, \ - alpha, \ - a, stride_a, \ - b, stride_b, \ - beta, \ - c, stride_c, \ - c_ref, stride_c, \ - post_op \ - ); \ - } \ + if ( bench_mode == 'a' ) \ + { \ + printf(" Running accuracy check.\n"); \ + GEN_FUNC_NAME(mat_mul_accuracy_check_driver_,BLAS_SFX) \ + ( \ + fout, stor_order, transa, transb, m, n, k, \ + alpha, \ + a, stride_a, \ + b, stride_b, \ + beta, \ + c, stride_c, \ + c_ref, stride_c, \ + post_op \ + ); \ + } \ \ - lpgemm_destroy_post_ops_struct( post_op ); \ + lpgemm_destroy_post_ops_struct( post_op ); \ \ - if ( a != NULL ) \ - { \ - bli_free_user( a ); \ - } \ - if ( b != NULL ) \ - { \ - bli_free_user( b ); \ - } \ - if ( a_float != NULL ) \ - { \ - bli_free_user( a_float ); \ - } \ - if ( b_float != NULL ) \ - { \ - bli_free_user( b_float ); \ - } \ - if ( c != NULL ) \ - { \ - bli_free_user( c ); \ - } \ - if ( c_ref != NULL ) \ - { \ - bli_free_user( c_ref ); \ - } \ + lpgemm_free( a ); \ + lpgemm_free( b ); \ + lpgemm_free( c ); \ + lpgemm_free( c_ref ); \ } \ -GEN_MAT_MUL_BENCH_MAIN_FUNC_BF16(float,bf16bf16f32of32) -GEN_MAT_MUL_BENCH_MAIN_FUNC_BF16(bfloat16,bf16bf16f32obf16) - +GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32,bf16bf16f32of32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16,bf16bf16f32of32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16,u8s8s16os16) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8,u8s8s16os16) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8,u8s8s16os16) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32,u8s8s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8,u8s8s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,float,f32f32f32of32,f32f32f32of32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32,s8s8s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8,s8s8s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int16_t,int16_t,s8s8s16os16,s8s8s16os16) +GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,int16_t,s8s8s16os8,s8s8s16os16) int main( int argc, char** argv ) { - FILE* fin = NULL; - if ( argc < 5 ) - { - printf - ( - "Usage: ./bench_lpgemm -i input.txt -m mode < -n 100 -o op1,op2 >\n" \ - "--Mode is either a or p.\n" \ - "\ta is used for accuracy testing.\n" \ - "\tp is used for performance benchmarking.\n" \ - "--n_repeats can be set optionally using -n arg.\n" \ - "--Post ops can be executed optionaly by 
providing a coma separated\n" \ - " list of post-ops after -o arg. Following post-ops are supported:\n" \ - " 1. bias\n" \ - " 2. 4 activators\n" \ - " a. relu\n" \ - " b. prelu\n" \ - " c. gelu_tanh\n" \ - " d. gelu_erf\n" \ - " 3.clip\n" \ - " Atleast one post-op needs to be specified if the -o arg is used.\n" \ - " eg: -o gelu_tanh; -o bias,relu ; -o clip,prelu,bias.\n" \ - " It is to be noted only one activator can be used at a time.\n" \ - " If more than one activator is used, only the first activator is\n" \ - " applied and the other activators are ignored.\n" \ - "--Downscaled version of an API is enabled by using -d arg.\n" \ - " Downscaled api's are used to enable quantization workflows.\n" \ - " Following downscaled api's are supported:\n" \ - " 1. u8s8s32os32 -d = u8s8s32os8.\n" \ - " 2. u8s8s16os16 -d = u8s8s16os8.\n" \ - " 3. bf16bf16f32obf32 -d = bf16bf16f32obf16.\n" \ - " 4. s8s8s32os32 -d = s8s8s32os8.\n" \ - " 5. s8s8s16os16 -d = s8s8s16os8.\n" \ - ); - exit( 1 ); - } - - char* file_name = NULL; - char* post_ops_str = NULL; - char* post_ops_str_dest = NULL; //Strtok is used to parse, need to maintain a copy. - - // Parse CLI arguments. - opterr = 0; - int opt_val; - while ( ( opt_val = getopt( argc, argv, "i:m:n:o:d" ) ) != -1 ) - { - switch ( opt_val ) - { - case 'i': - file_name = optarg; - break; - case 'm': - bench_mode = ( ( ( *optarg ) == 'a' ) || ( ( *optarg ) == 'p' ) ) ? ( *optarg ) : 'p'; - break; - case 'n': - global_n_repeat = ( atoi( optarg ) > 0 ) ? atoi( optarg ) : 0; - break; - case 'o': - post_ops_str = optarg; - break; - case 'd': - global_dscale_out = 'y'; - break; - default: - break; - } - } - - if ( post_ops_str != NULL ) - { - post_ops_str_dest = ( char* )malloc \ - ( ( strlen( post_ops_str) + 1 )* sizeof( char ) ); - strcpy( post_ops_str_dest, post_ops_str ); - } - - if ( bench_mode == 'p' ) - { - printf( "Running bench in performance benchmarking mode.\n" ); - } - else if ( bench_mode == 'a' ) - { - printf( "Running bench in accuracy/correctness testing mode.\n" ); - } - - if ( file_name == NULL ) - { - printf( " File name provided is invalid.\n" ); - exit( 1 ); - } - - fin = fopen( file_name, "r" ); - if (fin == NULL) - { - printf( "Error opening the file %s\n", argv[1] ); - exit( 1 ); - } - - FILE* fout = NULL; - - fout = fopen( "lpgemm_accuracy_test_failures.txt", "w" ); - - char op_type_char; - char op_t; - char stor_order; - int32_t m, n, k; - int32_t stride_a, stride_b, stride_c; - - const dim_t len_list_omp_cores_for_testing = 2; - const dim_t list_omp_cores_for_testing[2] = { 80, 1 }; - - dim_t core_index = 0; - bool can_run = TRUE; - while ( ( can_run == TRUE ) && ( fseek( fin, 0L, SEEK_SET ) == 0 ) ) - { - if ( bench_mode == 'p' ) - { - can_run = FALSE; - } - else if ( bench_mode == 'a' ) - { - // For accuracy testing, we test accuracy using multiple different - // number of cores. This helps uncover any bugs related to over - // subscription or varying thread factorizations. - // Set current number of cores. + FILE* fin = NULL; + if ( argc < 5 ) + { + printf + ( + "Usage: ./bench_lpgemm -i input.txt -m mode < -n 100 -o op1,op2 >\n" \ + "--Mode is either a or p.\n" \ + "\ta is used for accuracy testing.\n" \ + "\tp is used for performance benchmarking.\n" \ + "--n_repeats can be set optionally using -n arg.\n" \ + "--Post ops can be executed optionaly by providing a coma separated\n" \ + " list of post-ops after -o arg. Following post-ops are supported:\n" \ + " 1. bias\n" \ + " 2. 4 activators\n" \ + " a. relu\n" \ + " b. 
prelu\n" \ + " c. gelu_tanh\n" \ + " d. gelu_erf\n" \ + " 3.clip\n" \ + " Atleast one post-op needs to be specified if the -o arg is used.\n" \ + " eg: -o gelu_tanh; -o bias,relu ; -o clip,prelu,bias.\n" \ + " It is to be noted only one activator can be used at a time.\n" \ + " If more than one activator is used, only the first activator is\n" \ + " applied and the other activators are ignored.\n" \ + "--Downscaled version of an API is enabled by using -d arg followed\n" \ + " by the datatype that needs to be downscaled to" + " Downscaled api's are used to enable quantization workflows.\n" \ + " Following downscaled api's are supported:\n" \ + " 1. u8s8s32os32 -d s8 = u8s8s32os8.\n" \ + " 2. u8s8s16os16 -d s8 = u8s8s16os8.\n" \ + " 3. u8s8s16os16 -d u8 = u8s8s16ou8.\n" \ + " 4. bf16bf16f32obf32 -d bf16 = bf16bf16f32obf16.\n" \ + " 5. s8s8s32os32 -d s8 = s8s8s32os8.\n" \ + " 6. s8s8s16os16 -d s8 = s8s8s16os8.\n" \ + " Example: ./bench_lpgemm -m a -n 2 -o bias,relu -d bf16 -i input.txt\n" \ + ); + exit( 1 ); + } + + char* file_name = NULL; + char post_ops_str[50]; + char* post_ops_str_dest = NULL; //Strtok is used to parse, need to maintain a copy. + char dscale_type_str[10]; + + // Parse CLI arguments. + opterr = 0; + int opt_val; + while ( ( opt_val = getopt( argc, argv, "i:m:n:" ) ) != -1 ) + { + switch ( opt_val ) + { + case 'i': + file_name = optarg; + break; + case 'm': + bench_mode = ( ( ( *optarg ) == 'a' ) || ( ( *optarg ) == 'p' ) ) ? ( *optarg ) : 'p'; + break; + case 'n': + global_n_repeat = ( atoi( optarg ) > 0 ) ? atoi( optarg ) : 0; + break; + default: + break; + } + } + + if ( bench_mode == 'p' ) + { + printf( "Running bench in performance benchmarking mode.\n" ); + } + else if ( bench_mode == 'a' ) + { + printf( "Running bench in accuracy/correctness testing mode.\n" ); + } + + if ( file_name == NULL ) + { + printf( " File name provided is invalid.\n" ); + exit( 1 ); + } + + fin = fopen( file_name, "r" ); + if (fin == NULL) + { + printf( "Error opening the file %s\n", argv[1] ); + exit( 1 ); + } + + FILE* fout = NULL; + + fout = fopen( "lpgemm_accuracy_test_failures.txt", "w" ); + + char op_type_char; + char op_a, op_b; + char stor_order; + char transa, transb; + int32_t m, n, k; + int32_t stride_a, stride_b, stride_c; + + const dim_t len_list_omp_cores_for_testing = 2; + const dim_t list_omp_cores_for_testing[2] = { 80, 1 }; + + dim_t core_index = 0; + bool can_run = TRUE; + while ( ( can_run == TRUE ) && ( fseek( fin, 0L, SEEK_SET ) == 0 ) ) + { + if ( bench_mode == 'p' ) + { + can_run = FALSE; + } + else if ( bench_mode == 'a' ) + { + // For accuracy testing, we test accuracy using multiple different + // number of cores. This helps uncover any bugs related to over + // subscription or varying thread factorizations. + // Set current number of cores. #ifdef BLIS_ENABLE_OPENMP - omp_set_num_threads( list_omp_cores_for_testing[core_index] ); + omp_set_num_threads( list_omp_cores_for_testing[core_index] ); #endif - printf( "Accuracy test using %ld threads.\n", - list_omp_cores_for_testing[core_index] ); - - core_index++; - if ( core_index < len_list_omp_cores_for_testing ) - { - can_run = TRUE; - } - else - { - can_run = FALSE; - } - } - - // Input format: data_type stor_type pack/reorder m n k lda ldb ldc - while ( fscanf( fin, "%c %c %c %d %d %d %d %d %d\n", - &op_type_char, &stor_order, &op_t, &m, &n, &k, - &stride_a, &stride_b, &stride_c ) == 9 ) - { - stor_order = ( ( stor_order == 'r' ) || ( stor_order == 'R' ) || - ( stor_order == 'c' ) || ( stor_order == 'C' ) ) ? 
- stor_order : 'r'; - - if ( ( op_type_char == 'i' ) || ( op_type_char == 'I' ) ) - { - if ( global_dscale_out == 'n' ) - { - GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os32) - ( - fin, fout, stor_order, op_t, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - else - { - GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os8) - ( - fin, fout, stor_order, op_t, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - } - else if ( ( op_type_char == 'f' ) || ( op_type_char == 'F' ) ) - { - GEN_FUNC_NAME(mat_mul_bench_main_,f32f32f32of32) - ( - fin, fout, stor_order, op_t, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - else if ((op_type_char == 's') || (op_type_char == 'S')) - { - if ( global_dscale_out == 'n' ) - { - GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os16) - ( - fin, fout, stor_order, op_t, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - else - { - GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os8) - ( - fin, fout, stor_order, op_t, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - } - else if ((op_type_char == 'b') || (op_type_char == 'B')) - { - if ( global_dscale_out == 'n' ) - { - GEN_FUNC_NAME(mat_mul_bench_main_, bf16bf16f32of32) - ( - fin, fout, stor_order, op_t, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - else - { - GEN_FUNC_NAME(mat_mul_bench_main_, bf16bf16f32obf16) - ( - fin, fout, stor_order, op_t, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - } - else if ( ( op_type_char == 'u' ) || ( op_type_char == 'U' ) ) - { - if ( global_dscale_out == 'n' ) - { - GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os32) - ( - fin, fout, stor_order, op_t, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - else - { - GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os8) - ( - fin, fout, stor_order, op_t, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - } - else if ( ( op_type_char == 'v' ) || ( op_type_char == 'V' ) ) - { - if ( global_dscale_out == 'n' ) - { - GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s16os16) - ( - fin, fout, stor_order, op_t, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - else - { - GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s16os8) - ( - fin, fout, stor_order, op_t, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - } - if ( post_ops_str != NULL ) - { - strcpy( post_ops_str_dest, post_ops_str ); - } - } - } - - if ( post_ops_str_dest != NULL ) - { - free( post_ops_str_dest ); - } - if ( fin ) - { - fclose( fin ); - } - if ( fout ) - { - fclose( fout ); - } - return 0; + printf( "Accuracy test using %ld threads.\n", + list_omp_cores_for_testing[core_index] ); + + core_index++; + if ( core_index < len_list_omp_cores_for_testing ) + { + can_run = TRUE; + } + else + { + can_run = FALSE; + } + } + + // Input format: data_type stor_type pack/reorder m n k lda ldb ldc + while ( fscanf( fin, "%c %s %c %c %c %c %c %d %d %d %d %d %d %s\n", + &op_type_char, dscale_type_str, &stor_order, &transa, &transb, &op_a, &op_b, &m, &n, &k, + &stride_a, &stride_b, &stride_c, post_ops_str ) == 14 ) + { + stor_order = ( ( stor_order == 'r' ) || ( stor_order == 'R' ) || + ( stor_order == 'c' ) || ( stor_order == 'C' ) ) ? 
+ stor_order : 'r'; + + if ( strcmp( post_ops_str, "none" ) != 0 ) + { + post_ops_str_dest = ( char* )malloc \ + ( ( strlen( post_ops_str) + 1 )* sizeof( char ) ); + strcpy( post_ops_str_dest, post_ops_str ); + } + + if ( ( op_type_char == 'i' ) || ( op_type_char == 'I' ) ) + { + if ( ( strcmp( dscale_type_str, "S32" ) == 0 ) || + ( strcmp( dscale_type_str, "s32" ) == 0 ) ) + { + global_dscale_out = 'n'; + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os32) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + if ( ( strcmp( dscale_type_str, "S8" ) == 0 ) || + ( strcmp( dscale_type_str, "s8" ) == 0 ) ) + { + global_dscale_out = 'y'; + DSCALE_CLIP_MIN = -128; + DSCALE_CLIP_MAX = +127; + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os8) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + printf("Downscale type not supported.\n"); + } + } + } + else if ( ( op_type_char == 'f' ) || ( op_type_char == 'F' ) ) + { + global_dscale_out = 'n'; + GEN_FUNC_NAME(mat_mul_bench_main_,f32f32f32of32) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else if ((op_type_char == 's') || (op_type_char == 'S')) + { + if ( ( strcmp( dscale_type_str, "S16" ) == 0 ) || + ( strcmp( dscale_type_str, "s16" ) == 0 ) ) + { + global_dscale_out = 'n'; + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os16) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + if ( ( strcmp( dscale_type_str, "S8" ) == 0 ) || + ( strcmp( dscale_type_str, "s8" ) == 0 ) ) + { + global_dscale_out = 'y'; + DSCALE_CLIP_MIN = -128; + DSCALE_CLIP_MAX = +127; + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os8) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else if ( ( strcmp( dscale_type_str, "U8" ) == 0 ) || + ( strcmp( dscale_type_str, "u8" ) == 0 ) ) + { + global_dscale_out = 'y'; + DSCALE_CLIP_MIN = 0; + DSCALE_CLIP_MAX = +255; + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16ou8) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + printf("Downscale type not supported.\n"); + } + } + } + else if ((op_type_char == 'b') || (op_type_char == 'B')) + { + if ( ( strcmp( dscale_type_str, "F32" ) == 0 ) || + ( strcmp( dscale_type_str, "f32" ) == 0 ) ) + { + global_dscale_out = 'n'; + GEN_FUNC_NAME(mat_mul_bench_main_, bf16bf16f32of32) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else if ( ( strcmp( dscale_type_str, "BF16" ) == 0 ) || + ( strcmp( dscale_type_str, "bf16" ) == 0 ) ) + { + global_dscale_out = 'y'; + GEN_FUNC_NAME(mat_mul_bench_main_, bf16bf16f32obf16) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + printf("Downscale type not supported.\n"); + } + } + else if ( ( op_type_char == 'u' ) || ( op_type_char == 'U' ) ) + { + if ( ( strcmp( dscale_type_str, "S32" ) == 0 ) || + ( strcmp( dscale_type_str, "s32" ) == 0 ) ) + { + global_dscale_out = 'n'; + GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os32) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + 
post_ops_str_dest + ); + } + else + { + if ( ( strcmp( dscale_type_str, "S8" ) == 0 ) || + ( strcmp( dscale_type_str, "s8" ) == 0 ) ) + { + global_dscale_out = 'y'; + DSCALE_CLIP_MIN = -128; + DSCALE_CLIP_MAX = +127; + GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os8) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + printf("Downscale type not supported.\n"); + } + } + } + else if ( ( op_type_char == 'v' ) || ( op_type_char == 'V' ) ) + { + if ( ( strcmp( dscale_type_str, "S16" ) == 0 ) || + ( strcmp( dscale_type_str, "s16" ) == 0 ) ) + { + global_dscale_out = 'n'; + GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s16os16) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + if ( ( strcmp( dscale_type_str, "S8" ) == 0 ) || + ( strcmp( dscale_type_str, "s8" ) == 0 ) ) + { + global_dscale_out = 'y'; + DSCALE_CLIP_MIN = -128; + DSCALE_CLIP_MAX = +127; + GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s16os8) + ( + fin, fout, stor_order, transa, transb, op_a, op_b, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + printf("Downscale type not supported.\n"); + } + } + } + if ( strcmp( post_ops_str, "none" ) != 0 ) + { + strcpy( post_ops_str_dest, post_ops_str ); + } + } + } + + if ( post_ops_str_dest != NULL ) + { + free( post_ops_str_dest ); + } + if ( fin ) + { + fclose( fin ); + } + if ( fout ) + { + fclose( fout ); + } + return 0; } diff --git a/bench/bench_aocl_gemm/bench_lpgemm_utils.c b/bench/bench_aocl_gemm/bench_lpgemm_utils.c index dbbdce6703..8ce8104df5 100644 --- a/bench/bench_aocl_gemm/bench_lpgemm_utils.c +++ b/bench/bench_aocl_gemm/bench_lpgemm_utils.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
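A note on the -o post-op string handled by the bench_lpgemm driver above: the help text states that only the first activator in a comma-separated list is honoured, and the copy kept in post_ops_str_dest exists because strtok-style parsing modifies its input. The sketch below illustrates that rule in isolation; it is not the driver's actual parser, and the helper name is_activator, the activator list, and the example string are illustrative assumptions.

// Minimal sketch (not the bench_lpgemm parser): tokenize a post-op string such
// as "bias,relu,gelu_tanh" and honour only the first activator, as the help
// text above specifies.
#include <stdio.h>
#include <string.h>

static int is_activator( const char* tok )
{
    // Illustrative activator list; the real driver may accept more names.
    return ( strcmp( tok, "relu" ) == 0 )      || ( strcmp( tok, "prelu" ) == 0 ) ||
           ( strcmp( tok, "gelu_tanh" ) == 0 ) || ( strcmp( tok, "gelu_erf" ) == 0 );
}

int main( void )
{
    char post_ops[] = "bias,relu,gelu_tanh"; // strtok modifies its input, hence a local copy.
    const char* first_activator = NULL;

    for ( char* tok = strtok( post_ops, "," ); tok != NULL; tok = strtok( NULL, "," ) )
    {
        if ( !is_activator( tok ) )
            printf( "Non-activator post-op: %s\n", tok );
        else if ( first_activator == NULL )
            first_activator = tok;            // Keep the first activator only.
        else
            printf( "Ignoring extra activator: %s\n", tok );
    }

    if ( first_activator != NULL )
        printf( "Applying activator: %s\n", first_activator );

    return 0;
}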
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -157,7 +157,7 @@ void softmax_bench_driver_ ## SOFTMAX_SFX \ GEN_SOFTMAX_BENCH_DRV_FN(float,softmax_f32) -inline float gelu_tanh_f32 +static inline float gelu_tanh_f32 ( float temp_accum ) @@ -168,7 +168,7 @@ inline float gelu_tanh_f32 return temp_accum; }\ -inline float gelu_erf_f32 +static inline float gelu_erf_f32 ( float temp_accum ) @@ -261,10 +261,11 @@ void gelu_bench_main_ ## GELU_SFX \ n_repeats = global_n_repeat; \ } \ \ - V_type* x = ( V_type* ) bli_malloc_user( sizeof( V_type ) * n * incx ); \ + err_t bli_errors = BLIS_SUCCESS; \ + V_type* x = ( V_type* ) bli_malloc_user( sizeof( V_type ) * n * incx, &bli_errors ); \ GEN_FUNC_NAME(fill_array_,V_type)( x, ( n * incx ) ); \ \ - V_type* ref_x = ( V_type* ) bli_malloc_user( sizeof( V_type ) * n * incx ); \ + V_type* ref_x = ( V_type* ) bli_malloc_user( sizeof( V_type ) * n * incx, &bli_errors ); \ GEN_FUNC_NAME(fill_array_,V_type)( ref_x, ( n * incx ) ); \ \ GEN_FUNC_NAME(gelu_bench_driver_,GELU_SFX)(n_repeats,n,x,incx); \ @@ -292,10 +293,11 @@ void softmax_bench_main_ ## SOFTMAX_SFX \ n_repeats = global_n_repeat; \ } \ \ - V_type* x = ( V_type* ) bli_malloc_user( sizeof( V_type ) * n * incx ); \ + err_t bli_errors = BLIS_SUCCESS; \ + V_type* x = ( V_type* ) bli_malloc_user( sizeof( V_type ) * n * incx, &bli_errors ); \ GEN_FUNC_NAME(fill_array_,V_type)( x, ( n * incx ) ); \ \ - V_type* ref_x = ( V_type* ) bli_malloc_user( sizeof( V_type ) * n * incx ); \ + V_type* ref_x = ( V_type* ) bli_malloc_user( sizeof( V_type ) * n * incx, &bli_errors ); \ GEN_FUNC_NAME(fill_array_,V_type)( ref_x, ( n * incx ) ); \ \ GEN_FUNC_NAME(softmax_bench_driver_,SOFTMAX_SFX)(n_repeats,n,x,incx); \ diff --git a/bench/bench_copyv.c b/bench/bench_copyv.c index 7be38907ed..1e7f20e647 100644 --- a/bench/bench_copyv.c +++ b/bench/bench_copyv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/bench/bench_dotv.c b/bench/bench_dotv.c index 0d39594f72..9ca0cd386d 100644 --- a/bench/bench_dotv.c +++ b/bench/bench_dotv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/bench/bench_gemm.c b/bench/bench_gemm.c index d9dc523e92..454b8b0bc0 100755 --- a/bench/bench_gemm.c +++ b/bench/bench_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
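Returning to the gelu_tanh_f32 and gelu_erf_f32 helpers made static inline above: the static qualifier avoids possible link errors, since a plain inline function at file scope has no guaranteed external definition in C99/C11. For reference, the sketch below gives the standard GELU forms on which such helpers are usually based; the names ending in _ref_f32 are mine, and the exact constants or evaluation order in the bench may differ.

#include <math.h>

// Exact (erf-based) GELU definition.
static inline float gelu_erf_ref_f32( float x )
{
    return 0.5f * x * ( 1.0f + erff( x / sqrtf( 2.0f ) ) );
}

// Widely used tanh approximation of GELU.
static inline float gelu_tanh_ref_f32( float x )
{
    const float k = 0.7978845608f; // sqrt( 2 / pi )
    return 0.5f * x * ( 1.0f + tanhf( k * ( x + 0.044715f * x * x * x ) ) );
}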
 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions are
diff --git a/bench/bench_gemm_pack_compute.c b/bench/bench_gemm_pack_compute.c
new file mode 100755
index 0000000000..30236ee859
--- /dev/null
+++ b/bench/bench_gemm_pack_compute.c
@@ -0,0 +1,996 @@
+/*
+
+   BLIS
+   An object-based framework for developing high-performance BLAS-like
+   libraries.
+
+   Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name of The University of Texas nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#ifdef WIN32
+#include <io.h>
+#else
+#include <unistd.h>
+#endif
+#include "blis.h"
+
+
+// Benchmark application to process AOCL logs generated by the BLIS library.
+#ifndef DT
+#define DT BLIS_DOUBLE
+#endif
+
+#ifndef IND
+#define IND BLIS_NAT
+#endif
+
+#ifndef N_REPEAT
+//#define N_REPEAT 100
+#endif
+
+
+#define AOCL_MATRIX_INITIALISATION
+#define BUFFER_SIZE 256
+
+/* For BLIS, since logs are collected at the BLAS interfaces,
+ * we disable the CBLAS interfaces for this benchmark application.
+ */
+
+#ifdef BLIS_ENABLE_CBLAS
+// #define CBLAS
+#endif
+
+// #define PRINT
+
+int main( int argc, char** argv )
+{
+    obj_t a, b, c;
+    obj_t c_save;
+    obj_t alpha, beta, alpha_one;
+    dim_t m, n, k;
+    dim_t p_inc = 0; // to keep track of the number of inputs
+    num_t dt;
+    // ind_t ind;
+    char dt_ch;
+    int r, n_repeats;
+    trans_t transa;
+    trans_t transb;
+
+    double dtime;
+    double dtime_save;
+    double gflops;
+
+    int packA, packB;
+
+    FILE* fin = NULL;
+    FILE* fout = NULL;
+
+    n_repeats = N_REPEAT; // This macro is set from the Makefile.
+ + dt = DT; + + if (argc < 3) + { + printf("Usage: ./test_gemm_pack_compute_XX.x input.csv output.csv\n"); + exit(1); + } + fin = fopen(argv[1], "r"); + if (fin == NULL) + { + printf("Error opening the file %s\n", argv[1]); + exit(1); + } + fout = fopen(argv[2], "w"); + if (fout == NULL) + { + printf("Error opening output file %s\n", argv[2]); + exit(1); + } + if (argc > 3) + { + n_repeats = atoi(argv[3]); + } + + fprintf(fout, "Dt transa transb identifier m n k alphaR alphaI lda ldb betaR betaI ldc gflops\n"); + + // Following variables are needed for scanf to read inputs properly + // however they are not used in bench. + char api_name[BUFFER_SIZE]; // to store function name, line no present in logs + char dummy_buffer[BUFFER_SIZE]; + + // Variables extracted from the logs which are used by bench + char stor_scheme, transA_c, transB_c, packA_c, packB_c; + double alpha_r, beta_r, alpha_i, beta_i; + dim_t m_trans, n_trans; + inc_t lda, ldb, ldc; + + stor_scheme = 'C'; // By default set it to Column Major + + //{S, D, C, Z} transa, transb, packA, packB, m, n, k, alpha_real, + // alpha_imag, lda ldb, beta_real, beta_imag, ldc, + // + // number of threads, execution time, gflops ---> ignored by bench + while (fscanf(fin, "%s %c %c %c %c %c " INT_FS INT_FS INT_FS " %lf %lf " INT_FS INT_FS " %lf %lf " INT_FS"[^\n]", + api_name, &dt_ch, &transA_c, &transB_c, &packA_c, &packB_c, &m, &n, &k, &alpha_r, &alpha_i, + &lda, &ldb, &beta_r, &beta_i, &ldc) == 16) + { + // Discard any extra data on current line in the input file. + fgets(dummy_buffer, BUFFER_SIZE, fin ); + + // At BLAS level only column major order is supported. + stor_scheme = 'C'; + + if (dt_ch == 'D' || dt_ch == 'd') dt = BLIS_DOUBLE; + else if (dt_ch == 'S' || dt_ch == 's') dt = BLIS_FLOAT; + else + { + printf("Invalid data type %c\n", dt_ch); + continue; + } + + if ( transA_c == 'n' || transA_c == 'N' ) transa = BLIS_NO_TRANSPOSE; + else if ( transA_c == 't' || transA_c == 'T' ) transa = BLIS_TRANSPOSE; + else if ( transA_c == 'c' || transA_c == 'C' ) transa = BLIS_CONJ_TRANSPOSE; + else + { + printf("Invalid option for transA \n"); + continue; + } + + if ( transB_c == 'n' || transB_c == 'N' ) transb = BLIS_NO_TRANSPOSE; + else if ( transB_c == 't' || transB_c == 'T' ) transb = BLIS_TRANSPOSE; + else if ( transB_c == 'c' || transB_c == 'C' ) transb = BLIS_CONJ_TRANSPOSE; + else + { + printf("Invalid option for transB \n"); + continue; + } + + if ( packA_c == 'p' || packA_c == 'P' ) packA = TRUE; + else if ( packA_c == 'u' || packA_c == 'U' ) packA = FALSE; + else + { + printf("Invalid option for packA \n"); + continue; + } + + if ( packB_c == 'p' || packB_c == 'P') packB = TRUE; + else if ( packB_c == 'u' || packB_c == 'U') packB = FALSE; + else + { + printf("Invalid option for packB \n"); + continue; + } + + bli_obj_create( dt, 1, 1, 0, 0, &alpha); + bli_obj_create( dt, 1, 1, 0, 0, &beta ); + + bli_obj_create( dt, 1, 1, 0, 0, &alpha_one); + + if( (stor_scheme == 'C') || (stor_scheme == 'c') ) + { + // leading dimension should be greater than number of rows + // if ((m > lda) || (k > ldb) || (m > ldc)) continue; + // Since this bench app is run on logs generated by AOCL trace logs + // - we have relaxed the checks on the input parameters. 
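As a concrete example of the trace-log line format parsed by the fscanf above, the input file bench/inputgemmpackcompute.txt added later in this patch contains lines such as:

sgemm_ S N N P U 100 100 100 1 0 100 100 1 0 100

which reads as: API name sgemm_, datatype S (single precision), transA = N, transB = N, A packed (P), B unpacked (U), m = n = k = 100, alpha = 1 + 0i, lda = ldb = 100, beta = 1 + 0i, ldc = 100. In real AOCL trace logs a line may continue with the thread count, execution time, and GFLOPS; those trailing fields are discarded by the fgets() into dummy_buffer.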
+ + // if A is transpose - A(lda x m), lda >= max(1,k) + // if A is non-transpose - A (lda x k), lda >= max(1,m) + // if B is transpose - B (ldb x k), ldb >= max(1,n) + // if B is non-transpose - B (ldb x n), ldb >= max(1,k) + // C is ldc x n - ldc >= max(1, m) + //if(transa) lda = k; // We will end up overwriting lda + bli_set_dims_with_trans( transa, m, k, &m_trans, &n_trans); + bli_obj_create( dt, m_trans, n_trans, 1, lda, &a); + + //if(transb) ldb = n; // we will end up overwriting ldb, ldb >= n + bli_set_dims_with_trans( transb, k, n, &m_trans, &n_trans); + bli_obj_create( dt, m_trans, n_trans, 1, ldb, &b); + + bli_obj_create( dt, m, n, 1, ldc, &c); + bli_obj_create( dt, m, n, 1, ldc, &c_save ); + } + else if( (stor_scheme == 'r') || (stor_scheme == 'R') ) + { + //leading dimension should be greater than number of columns + //if ((k > lda) || (n > ldb) || (n > ldc)) continue; + // Since this bench app is run on logs generated by AOCL trace logs + // - we have relaxed the checks on the input parameters. + + // if A is transpose - A(k x lda), lda >= max(1,m) + // if A is non-transpose - A (m x lda), lda >= max(1,k) + // if B is transpose - B (n x ldb), ldb >= max(1,k) + // if B is non-transpose - B (k x ldb ), ldb >= max(1,n) + // C is m x ldc - ldc >= max(1, n) + + //if(transa) lda = m; // this will overwrite lda + bli_set_dims_with_trans(transa, m, k, &m_trans, &n_trans); + bli_obj_create( dt, m_trans, n_trans, lda, 1, &a); + + //if(transb) ldb = k; // this will overwrite ldb + bli_set_dims_with_trans(transb, k, n, &m_trans, &n_trans); + bli_obj_create( dt, m_trans, n_trans, ldb, 1, &b); + + bli_obj_create( dt, m, n, ldc, 1, &c); + bli_obj_create( dt, m, n, ldc, 1, &c_save ); + } + else + { + printf("Invalid storage scheme\n"); + continue; + } +#ifndef BLIS // Incase if we are using blis interface we don't have to check for col-storage. + #ifndef CBLAS + if( ( stor_scheme == 'R' ) || ( stor_scheme == 'r' ) ) + { + printf("BLAS APIs doesn't support row-storage: Enable CBLAS\n"); + continue; + } + #endif +#endif + +#ifdef AOCL_MATRIX_INITIALISATION + bli_randm( &a ); + bli_randm( &b ); + bli_randm( &c ); +#endif + bli_copym( &c, &c_save ); + + bli_obj_set_conjtrans( transa, &a); + bli_obj_set_conjtrans( transb, &b); + + bli_setsc( 1.0, 1.0, &alpha_one ); + bli_setsc( alpha_r, alpha_i, &alpha ); + bli_setsc( beta_r, beta_i, &beta ); + + dtime_save = DBL_MAX; + + for ( r = 0; r < n_repeats; ++r ) + { + bli_copym( &c_save, &c ); +#ifdef PRINT + bli_printm( "a", &a, "%4.6f", "" ); + bli_printm( "b", &b, "%4.6f", "" ); + bli_printm( "c", &c, "%4.6f", "" ); +#endif + +#ifdef BLIS + + printf( "BLAS Extension APIs don't have a BLIS interface." 
+ "Enable CBLAS or BLAS interface!\n" ); + +#else + +#ifdef CBLAS + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_transa; + enum CBLAS_TRANSPOSE cblas_transb; + enum CBLAS_IDENTIFIER cblas_identifierA; + enum CBLAS_IDENTIFIER cblas_identifierB; + + size_t bufSizeA; + size_t bufSizeB; + + if ( ( stor_scheme == 'C' ) || ( stor_scheme == 'c' ) ) + cblas_order = CblasColMajor; + else + cblas_order = CblasRowMajor; + + if( bli_is_trans( transa ) ) + cblas_transa = CblasTrans; + else if( bli_is_conjtrans( transa ) ) + cblas_transa = CblasConjTrans; + else + cblas_transa = CblasNoTrans; + + if( bli_is_trans( transb ) ) + cblas_transb = CblasTrans; + else if( bli_is_conjtrans( transb ) ) + cblas_transb = CblasConjTrans; + else + cblas_transb = CblasNoTrans; + + if ( packA ) + cblas_identifierA = CblasAMatrix; + + if ( packB ) + cblas_identifierB = CblasBMatrix; +#else + f77_char f77_transa; + f77_char f77_transb; + f77_char f77_identifierA; + f77_char f77_identifierB; + f77_int f77_bufSizeA; + f77_int f77_bufSizeB; + + f77_char f77_packed = 'P'; + f77_identifierA = 'A'; + f77_identifierB = 'B'; + bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); + bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); + + err_t err = BLIS_SUCCESS; + +#endif + if ( bli_is_float( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + + float* alphaonep = bli_obj_buffer( &alpha_one ); + float* alphap = bli_obj_buffer( &alpha ); + float* ap = bli_obj_buffer( &a ); + float* bp = bli_obj_buffer( &b ); + float* betap = bli_obj_buffer( &beta ); + float* cp = bli_obj_buffer( &c ); + +#ifdef CBLAS + float* aBuffer; + float* bBuffer; + + if ( packA && !packB ) + { + // Only A is pre-packed. + bufSizeA = cblas_sgemm_pack_get_size( CblasAMatrix, + mm, + nn, + kk ); + aBuffer = (float*) bli_malloc_user( bufSizeA, &err ); + + cblas_sgemm_pack( cblas_order, + CblasAMatrix, + cblas_transa, + mm, + nn, + kk, + *alphap, + ap, lda, + aBuffer ); + + dtime = bli_clock(); + + cblas_sgemm_compute( cblas_order, + CblasPacked, + cblas_transb, + mm, + nn, + kk, + aBuffer, lda, + bp, ldb, + *betap, + cp, ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + + bli_free_user(aBuffer); + } + else if ( !packA && packB ) + { + // Only B is pre-packed. + bufSizeB = cblas_sgemm_pack_get_size( CblasBMatrix, + mm, + nn, + kk ); + bBuffer = (float*) bli_malloc_user( bufSizeB, &err ); + + cblas_sgemm_pack( cblas_order, + CblasBMatrix, + cblas_transb, + mm, + nn, + kk, + *alphap, + bp, ldb, + bBuffer ); + + dtime = bli_clock(); + + cblas_sgemm_compute( cblas_order, + cblas_transa, + CblasPacked, + mm, + nn, + kk, + ap, lda, + bBuffer, ldb, + *betap, + cp, ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + + + bli_free_user(bBuffer); + } + else if ( packA && packB ) + { + // Both A & B are pre-packed. 
+ bufSizeA = cblas_sgemm_pack_get_size( CblasAMatrix, + mm, + nn, + kk ); + aBuffer = (float*) bli_malloc_user( bufSizeA, &err ); + + bufSizeB = cblas_sgemm_pack_get_size( CblasBMatrix, + mm, + nn, + kk ); + bBuffer = (float*) bli_malloc_user( bufSizeB, &err ); + + cblas_sgemm_pack( cblas_order, + CblasAMatrix, + cblas_transa, + mm, + nn, + kk, + *alphap, + ap, lda, + aBuffer ); + + cblas_sgemm_pack( cblas_order, + CblasBMatrix, + cblas_transb, + mm, + nn, + kk, + *alphaonep, + bp, ldb, + bBuffer ); + + dtime = bli_clock(); + + cblas_sgemm_compute( cblas_order, + CblasPacked, + CblasPacked, + mm, + nn, + kk, + aBuffer, lda, + bBuffer, ldb, + *betap, + cp, ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + + bli_free_user(aBuffer); + bli_free_user(bBuffer); + } + else + { + // Neither A nor B is pre-packed. + + dtime = bli_clock(); + + cblas_sgemm_compute( cblas_order, + cblas_transa, + cblas_transb, + mm, + nn, + kk, + ap, lda, + bp, ldb, + *betap, + cp, ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } +#else // -- BLAS API -- + float* aBuffer; + float* bBuffer; + + if ( packA && !packB ) + { + // Only A is pre-packed. + f77_bufSizeA = sgemm_pack_get_size_( &f77_identifierA, + &mm, + &nn, + &kk ); + aBuffer = (float*) bli_malloc_user( f77_bufSizeA, &err ); + + sgemm_pack_( &f77_identifierA, + &f77_transa, + &mm, + &nn, + &kk, + alphap, + ap, + (f77_int*)&lda, + aBuffer ); + + dtime = bli_clock(); + + sgemm_compute_( &f77_packed, + &f77_transb, + &mm, + &nn, + &kk, + aBuffer, (f77_int*)&lda, + bp, (f77_int*)&ldb, + betap, + cp, (f77_int*)&ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + + bli_free_user( aBuffer ); + } + else if ( !packA && packB ) + { + // Only B is pre-packed. + f77_bufSizeB = sgemm_pack_get_size_( &f77_identifierB, + &mm, + &nn, + &kk ); + bBuffer = (float*) bli_malloc_user( f77_bufSizeB, &err ); + + sgemm_pack_( &f77_identifierB, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + bp, + (f77_int*)&ldb, + bBuffer ); + + dtime = bli_clock(); + + sgemm_compute_( &f77_transa, + &f77_packed, + &mm, + &nn, + &kk, + ap, (f77_int*)&lda, + bBuffer, (f77_int*)&ldb, + betap, + cp, (f77_int*)&ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + + bli_free_user( bBuffer ); + } + else if ( packA && packB ) + { + // Both A & B are pre-packed. + f77_bufSizeB = sgemm_pack_get_size_( &f77_identifierB, + &mm, + &nn, + &kk ); + + bBuffer = (float*) bli_malloc_user( f77_bufSizeB, &err ); + + f77_bufSizeA = sgemm_pack_get_size_( &f77_identifierA, + &mm, + &nn, + &kk ); + + aBuffer = (float*) bli_malloc_user( f77_bufSizeA, &err ); + + sgemm_pack_( &f77_identifierA, + &f77_transa, + &mm, + &nn, + &kk, + alphap, + ap, + (f77_int*)&lda, + aBuffer ); + + sgemm_pack_( &f77_identifierB, + &f77_transb, + &mm, + &nn, + &kk, + alphaonep, + bp, + (f77_int*)&ldb, + bBuffer ); + + dtime = bli_clock(); + + sgemm_compute_( &f77_packed, + &f77_packed, + &mm, + &nn, + &kk, + aBuffer, (f77_int*)&lda, + bBuffer, (f77_int*)&ldb, + betap, + cp, (f77_int*)&ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + + bli_free_user(aBuffer); + bli_free_user(bBuffer); + } + else + { + // Neither A nor B is reordered. 
+ + dtime = bli_clock(); + + sgemm_compute_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + ap, (f77_int*)&lda, + bp, (f77_int*)&ldb, + betap, + cp, (f77_int*)&ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } +#endif + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &c ); + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + + double* alphap = bli_obj_buffer( &alpha ); + double* alphaonep = bli_obj_buffer( &alpha_one ); + double* ap = bli_obj_buffer( &a ); + double* bp = bli_obj_buffer( &b ); + double* betap = bli_obj_buffer( &beta ); + double* cp = bli_obj_buffer( &c ); + +#ifdef CBLAS + double* aBuffer; + double* bBuffer; + + if ( packA && !packB ) + { + // Only A is pre-packed. + bufSizeA = cblas_dgemm_pack_get_size( CblasAMatrix, + mm, + nn, + kk ); + aBuffer = (double*) bli_malloc_user( bufSizeA, &err ); + + cblas_dgemm_pack( cblas_order, + CblasAMatrix, + cblas_transa, + mm, + nn, + kk, + *alphap, + ap, lda, + aBuffer ); + + dtime = bli_clock(); + + cblas_dgemm_compute( cblas_order, + CblasPacked, + cblas_transb, + mm, + nn, + kk, + aBuffer, lda, + bp, ldb, + *betap, + cp, ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + + bli_free_user(aBuffer); + } + else if ( !packA && packB ) + { + // Only B is pre-packed. + bufSizeB = cblas_dgemm_pack_get_size( CblasBMatrix, + mm, + nn, + kk ); + + cblas_dgemm_pack( cblas_order, + CblasBMatrix, + cblas_transb, + mm, + nn, + kk, + *alphap, + bp, ldb, + bBuffer ); + + dtime = bli_clock(); + + cblas_dgemm_compute( cblas_order, + cblas_transa, + CblasPacked, + mm, + nn, + kk, + ap, lda, + bBuffer, ldb, + *betap, + cp, ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + + bli_free_user(bBuffer); + } + else if ( packA && packB ) + { + // Both A & B are pre-packed. + bufSizeA = cblas_dgemm_pack_get_size( CblasAMatrix, + mm, + nn, + kk ); + aBuffer = (double*) bli_malloc_user( bufSizeA, &err ); + + bufSizeB = cblas_dgemm_pack_get_size( CblasBMatrix, + mm, + nn, + kk ); + bBuffer = (double*) bli_malloc_user( bufSizeB, &err ); + + cblas_dgemm_pack( cblas_order, + CblasAMatrix, + cblas_transa, + mm, + nn, + kk, + *alphap, + ap, lda, + aBuffer ); + + cblas_dgemm_pack( cblas_order, + CblasBMatrix, + cblas_transb, + mm, + nn, + kk, + *alphap, + bp, ldb, + bBuffer ); + + dtime = bli_clock(); + + cblas_dgemm_compute( cblas_order, + CblasPacked, + CblasPacked, + mm, + nn, + kk, + aBuffer, lda, + bBuffer, ldb, + *betap, + cp, ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + + bli_free_user(aBuffer); + bli_free_user(bBuffer); + } + else + { + // Neither A nor B is pre-packed. + + dtime = bli_clock(); + + cblas_dgemm_compute( cblas_order, + cblas_transa, + cblas_transb, + mm, + nn, + kk, + ap, lda, + bp, ldb, + *betap, + cp, ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + +#else // -- BLAS API -- + double* aBuffer; + double* bBuffer; + + if ( packA && !packB ) + { + // Only A is pre-packed. 
+ f77_bufSizeA = dgemm_pack_get_size_( &f77_identifierA, + &mm, + &nn, + &kk ); + aBuffer = (double*) bli_malloc_user( f77_bufSizeA, &err ); + + dgemm_pack_( &f77_identifierA, + &f77_transa, + &mm, + &nn, + &kk, + alphap, + ap, + (f77_int*)&lda, + aBuffer ); + + dtime = bli_clock(); + + dgemm_compute_( &f77_packed, + &f77_transb, + &mm, + &nn, + &kk, + aBuffer, (f77_int*)&lda, + bp, (f77_int*)&ldb, + betap, + cp, (f77_int*)&ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + + bli_free_user( aBuffer ); + } + else if ( !packA && packB ) + { + // Only B is pre-packed. + f77_bufSizeB = dgemm_pack_get_size_( &f77_identifierB, + &mm, + &nn, + &kk ); + bBuffer = (double*) bli_malloc_user( f77_bufSizeB, &err ); + + dgemm_pack_( &f77_identifierB, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + bp, + (f77_int*)&ldb, + bBuffer ); + + dtime = bli_clock(); + + dgemm_compute_( &f77_transa, + &f77_packed, + &mm, + &nn, + &kk, + ap, (f77_int*)&lda, + bBuffer, (f77_int*)&ldb, + betap, + cp, (f77_int*)&ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + + bli_free_user( bBuffer ); + } + else if ( packA && packB ) + { + // Both A & B are pre-packed. + f77_bufSizeA = dgemm_pack_get_size_( &f77_identifierA, + &mm, + &nn, + &kk ); + aBuffer = (double*) bli_malloc_user( f77_bufSizeA, &err ); + + f77_bufSizeB = dgemm_pack_get_size_( &f77_identifierB, + &mm, + &nn, + &kk ); + bBuffer = (double*) bli_malloc_user( f77_bufSizeB, &err ); + + dgemm_pack_( &f77_identifierA, + &f77_transa, + &mm, + &nn, + &kk, + alphap, + ap, + (f77_int*)&lda, + aBuffer ); + + dgemm_pack_( &f77_identifierB, + &f77_transb, + &mm, + &nn, + &kk, + alphaonep, + bp, + (f77_int*)&ldb, + bBuffer ); + + dtime = bli_clock(); + + dgemm_compute_( &f77_packed, + &f77_packed, + &mm, + &nn, + &kk, + aBuffer, (f77_int*)&lda, + bBuffer, (f77_int*)&ldb, + betap, + cp, (f77_int*)&ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + + bli_free_user(aBuffer); + bli_free_user(bBuffer); + } + else + { + // Neither A nor B is reordered. + + dtime = bli_clock(); + + dgemm_compute_( &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + ap, (f77_int*)&lda, + bp, (f77_int*)&ldb, + betap, + cp, (f77_int*)&ldc ); + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } +#endif + } +#endif + +#ifdef PRINT + bli_printm( "c compute", &c, "%4.6f", "" ); +#endif + } + + gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 ); + + if ( bli_is_complex( dt ) ) gflops *= 4.0; + + printf( "data_%cgemm_%s", dt_ch, BLAS ); + + p_inc++; + printf("( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + (unsigned long)(p_inc), + (unsigned long)m, + (unsigned long)n, + (unsigned long)k, gflops); + + fprintf (fout, "%c %c %c %c %c %ld %ld %ld %lf %lf %ld %ld %lf %lf %ld %6.3f\n", \ + dt_ch, transA_c, transB_c, packA_c, packB_c, m, n, k, alpha_r, alpha_i, lda, ldb, beta_r, beta_i, ldc, gflops); + + fflush(fout); + + bli_obj_free( &alpha ); + bli_obj_free( &beta ); + + bli_obj_free( &a ); + bli_obj_free( &b ); + bli_obj_free( &c ); + bli_obj_free( &c_save ); + } + + //bli_finalize(); + fclose(fin); + fclose(fout); + + return 0; +} diff --git a/bench/bench_gemmt.c b/bench/bench_gemmt.c index ad24593747..cd2e5bf9b8 100644 --- a/bench/bench_gemmt.c +++ b/bench/bench_gemmt.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
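To make the pack-and-compute flow exercised by bench_gemm_pack_compute.c easier to follow outside the driver, here is a minimal sketch of the pack-once/compute-many pattern using the CBLAS-style extension entry points as they are called in that file. It assumes column-major storage, no transposition, a CBLAS-enabled build, and that alpha is folded in at pack time (as the benchmark does); the function name dgemm_pack_compute_sketch, the batch loop, and the operand layout are illustrative, not part of the patch.

// Sketch of the pack-once / compute-many pattern, assuming column-major
// operands and CBLAS enabled. Alpha is applied while packing A, so only
// beta is passed to the compute call.
#include "blis.h"

void dgemm_pack_compute_sketch( const double* a, const double* b0, double* c0,
                                f77_int m, f77_int n, f77_int k, int batch )
{
    double alpha = 1.0, beta = 0.0;
    err_t  err   = BLIS_SUCCESS;

    // Query the packed-buffer size for A and pack A once.
    size_t  bufSizeA = cblas_dgemm_pack_get_size( CblasAMatrix, m, n, k );
    double* aBuffer  = ( double* )bli_malloc_user( bufSizeA, &err );

    cblas_dgemm_pack( CblasColMajor, CblasAMatrix, CblasNoTrans,
                      m, n, k, alpha, a, m, aBuffer );

    // Reuse the packed A across many GEMMs with different B and C operands.
    for ( int i = 0; i < batch; ++i )
    {
        const double* b = b0 + ( size_t )i * k * n;
        double*       c = c0 + ( size_t )i * m * n;

        cblas_dgemm_compute( CblasColMajor, CblasPacked, CblasNoTrans,
                             m, n, k, aBuffer, m, b, k, beta, c, m );
    }

    bli_free_user( aBuffer );
}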
modification, are permitted provided that the following conditions are met: diff --git a/bench/bench_gemv.c b/bench/bench_gemv.c index 9f06bf8efb..dd77a0539c 100755 --- a/bench/bench_gemv.c +++ b/bench/bench_gemv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/bench/bench_ger.c b/bench/bench_ger.c index 2c8981a682..b4ee38a799 100644 --- a/bench/bench_ger.c +++ b/bench/bench_ger.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/bench/bench_scalv.c b/bench/bench_scalv.c index b8cd6241c1..80b3762ea2 100644 --- a/bench/bench_scalv.c +++ b/bench/bench_scalv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/bench/bench_swapv.c b/bench/bench_swapv.c index 6f2c8fd90e..3040d7b582 100644 --- a/bench/bench_swapv.c +++ b/bench/bench_swapv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/bench/bench_syrk.c b/bench/bench_syrk.c index b65db83aa5..5bcc20e060 100644 --- a/bench/bench_syrk.c +++ b/bench/bench_syrk.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. modification, are permitted provided that the following conditions are met: diff --git a/bench/bench_trsm.c b/bench/bench_trsm.c index 7014bd4753..87dd677a4d 100644 --- a/bench/bench_trsm.c +++ b/bench/bench_trsm.c @@ -3,8 +3,10 @@ BLIS An object-based framework for developing high-performance BLAS-like libraries. + Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/bench/bench_trsv.c b/bench/bench_trsv.c index 425f61f1d0..4714f813d4 100644 --- a/bench/bench_trsv.c +++ b/bench/bench_trsv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2022, Advanced Micro Devices, Inc. 
All rights reserved. + Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/bench/inputgemmpackcompute.txt b/bench/inputgemmpackcompute.txt new file mode 100644 index 0000000000..3afff8baf0 --- /dev/null +++ b/bench/inputgemmpackcompute.txt @@ -0,0 +1,92 @@ +sgemm_ S N N P U 1 1 1 1 0 1 1 1 0 1 +sgemm_ S N N P U 2 2 2 1 0 2 2 1 0 2 +sgemm_ S N N P U 3 3 3 1 0 3 3 1 0 3 +sgemm_ S N N P U 4 4 4 1 0 4 4 1 0 4 +sgemm_ S N N P U 5 5 5 1 0 5 5 1 0 5 +sgemm_ S N N P U 6 6 6 1 0 6 6 1 0 6 +sgemm_ S N N P U 7 7 7 1 0 7 7 1 0 7 +sgemm_ S N N P U 8 8 8 1 0 8 8 1 0 8 +sgemm_ S N N P U 9 9 9 1 0 9 9 1 0 9 +sgemm_ S N N P U 10 10 10 1 0 10 10 1 0 10 +sgemm_ S N N P U 20 20 20 1 0 20 20 1 0 20 +sgemm_ S N N P U 30 30 30 1 0 30 30 1 0 30 +sgemm_ S N N P U 40 40 40 1 0 40 40 1 0 40 +sgemm_ S N N P U 50 50 50 1 0 50 50 1 0 50 +sgemm_ S N N P U 60 60 60 1 0 60 60 1 0 60 +sgemm_ S N N P U 70 70 70 1 0 70 70 1 0 70 +sgemm_ S N N P U 80 80 80 1 0 80 80 1 0 80 +sgemm_ S N N P U 90 90 90 1 0 90 90 1 0 90 +sgemm_ S N N P U 100 100 100 1 0 100 100 1 0 100 +sgemm_ S N N P U 200 200 200 1 0 200 200 1 0 200 +sgemm_ S N N P U 300 300 300 1 0 300 300 1 0 300 +sgemm_ S N N P U 400 400 400 1 0 400 400 1 0 400 +sgemm_ S N N P U 500 500 500 1 0 500 500 1 0 500 +dgemm_ D N N P U 1 1 1 1 0 1 1 1 0 1 +dgemm_ D N N P U 2 2 2 1 0 2 2 1 0 2 +dgemm_ D N N P U 3 3 3 1 0 3 3 1 0 3 +dgemm_ D N N P U 4 4 4 1 0 4 4 1 0 4 +dgemm_ D N N P U 5 5 5 1 0 5 5 1 0 5 +dgemm_ D N N P U 6 6 6 1 0 6 6 1 0 6 +dgemm_ D N N P U 7 7 7 1 0 7 7 1 0 7 +dgemm_ D N N P U 8 8 8 1 0 8 8 1 0 8 +dgemm_ D N N P U 9 9 9 1 0 9 9 1 0 9 +dgemm_ D N N P U 10 10 10 1 0 10 10 1 0 10 +dgemm_ D N N P U 20 20 20 1 0 20 20 1 0 20 +dgemm_ D N N P U 30 30 30 1 0 30 30 1 0 30 +dgemm_ D N N P U 40 40 40 1 0 40 40 1 0 40 +dgemm_ D N N P U 50 50 50 1 0 50 50 1 0 50 +dgemm_ D N N P U 60 60 60 1 0 60 60 1 0 60 +dgemm_ D N N P U 70 70 70 1 0 70 70 1 0 70 +dgemm_ D N N P U 80 80 80 1 0 80 80 1 0 80 +dgemm_ D N N P U 90 90 90 1 0 90 90 1 0 90 +dgemm_ D N N P U 100 100 100 1 0 100 100 1 0 100 +dgemm_ D N N P U 200 200 200 1 0 200 200 1 0 200 +dgemm_ D N N P U 300 300 300 1 0 300 300 1 0 300 +dgemm_ D N N P U 400 400 400 1 0 400 400 1 0 400 +dgemm_ D N N P U 500 500 500 1 0 500 500 1 0 500 +sgemm_ S N N U P 1 1 1 1 0 1 1 1 0 1 +sgemm_ S N N U P 2 2 2 1 0 2 2 1 0 2 +sgemm_ S N N U P 3 3 3 1 0 3 3 1 0 3 +sgemm_ S N N U P 4 4 4 1 0 4 4 1 0 4 +sgemm_ S N N U P 5 5 5 1 0 5 5 1 0 5 +sgemm_ S N N U P 6 6 6 1 0 6 6 1 0 6 +sgemm_ S N N U P 7 7 7 1 0 7 7 1 0 7 +sgemm_ S N N U P 8 8 8 1 0 8 8 1 0 8 +sgemm_ S N N U P 9 9 9 1 0 9 9 1 0 9 +sgemm_ S N N U P 10 10 10 1 0 10 10 1 0 10 +sgemm_ S N N U P 20 20 20 1 0 20 20 1 0 20 +sgemm_ S N N U P 30 30 30 1 0 30 30 1 0 30 +sgemm_ S N N U P 40 40 40 1 0 40 40 1 0 40 +sgemm_ S N N U P 50 50 50 1 0 50 50 1 0 50 +sgemm_ S N N U P 60 60 60 1 0 60 60 1 0 60 +sgemm_ S N N U P 70 70 70 1 0 70 70 1 0 70 +sgemm_ S N N U P 80 80 80 1 0 80 80 1 0 80 +sgemm_ S N N U P 90 90 90 1 0 90 90 1 0 90 +sgemm_ S N N U P 100 100 100 1 0 100 100 1 0 100 +sgemm_ S N N U P 200 200 200 1 0 200 200 1 0 200 +sgemm_ S N N U P 300 300 300 1 0 300 300 1 0 300 +sgemm_ S N N U P 400 400 400 1 0 400 400 1 0 400 +sgemm_ S N N U P 500 500 500 1 0 500 500 1 0 500 +dgemm_ D N N U P 1 1 1 1 0 1 1 1 0 1 +dgemm_ D N N U P 2 2 2 1 0 2 2 1 0 2 +dgemm_ D N N U P 3 3 3 1 0 3 3 1 0 3 +dgemm_ D N N U P 4 4 4 1 0 4 4 1 0 4 
+dgemm_ D N N U P 5 5 5 1 0 5 5 1 0 5 +dgemm_ D N N U P 6 6 6 1 0 6 6 1 0 6 +dgemm_ D N N U P 7 7 7 1 0 7 7 1 0 7 +dgemm_ D N N U P 8 8 8 1 0 8 8 1 0 8 +dgemm_ D N N U P 9 9 9 1 0 9 9 1 0 9 +dgemm_ D N N U P 10 10 10 1 0 10 10 1 0 10 +dgemm_ D N N U P 20 20 20 1 0 20 20 1 0 20 +dgemm_ D N N U P 30 30 30 1 0 30 30 1 0 30 +dgemm_ D N N U P 40 40 40 1 0 40 40 1 0 40 +dgemm_ D N N U P 50 50 50 1 0 50 50 1 0 50 +dgemm_ D N N U P 60 60 60 1 0 60 60 1 0 60 +dgemm_ D N N U P 70 70 70 1 0 70 70 1 0 70 +dgemm_ D N N U P 80 80 80 1 0 80 80 1 0 80 +dgemm_ D N N U P 90 90 90 1 0 90 90 1 0 90 +dgemm_ D N N U P 100 100 100 1 0 100 100 1 0 100 +dgemm_ D N N U P 200 200 200 1 0 200 200 1 0 200 +dgemm_ D N N U P 300 300 300 1 0 300 300 1 0 300 +dgemm_ D N N U P 400 400 400 1 0 400 400 1 0 400 +dgemm_ D N N U P 500 500 500 1 0 500 500 1 0 500 diff --git a/blastest/CMakeLists.txt b/blastest/CMakeLists.txt index d35629f53a..c8a653c2fa 100644 --- a/blastest/CMakeLists.txt +++ b/blastest/CMakeLists.txt @@ -1,13 +1,133 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved.## -set(F2C_LIB "libf2c") +# Comments: +# - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. +# - We must use recursively expanded assignment for LIB_PATH and INC_PATH in +# the second case because CONFIG_NAME is not yet set. +if(NOT DEFINED BLIS_INSTALL_PATH) + set(DIST_PATH ${CMAKE_BINARY_DIR}) + set(LIB_PATH ${DIST_PATH}/lib/${BLIS_CONFIG_FAMILY}) + set(INC_PATH ${DIST_PATH}/include/${BLIS_CONFIG_FAMILY}) +else() + set(LIB_PATH ${BLIS_INSTALL_PATH}/lib) + set(INC_PATH ${BLIS_INSTALL_PATH}/include/blis) +endif() -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/f2c) +# Include the corresponding make_defs.cmake that holds the required compiler options. +include(${CMAKE_SOURCE_DIR}/config/${BLIS_CONFIG_FAMILY}/make_defs.cmake) -# Generate F2C library -add_library("${F2C_LIB}" STATIC ) -set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C) +# Create a static library using the sources in f2c directory. +file(GLOB f2c_sources LIST_DIRECTORIES false ${CMAKE_CURRENT_SOURCE_DIR}/f2c/*.c) +add_library(f2c STATIC ${f2c_sources}) +target_compile_options(f2c + PRIVATE + # load-var-for,COPTFLAGS + ${COPTFLAGS} + # get-noopt-cflags-for + ${CDBGFLAGS} + ${CWARNFLAGS} + ${CPICFLAGS} + ${CMISCFLAGS} + ${CLANGFLAGS} + # Suppress warnings about uninitialized functions + -Wno-maybe-uninitialized -Wno-parentheses -Wfatal-errors + ) +target_compile_definitions(f2c + PRIVATE + # in get-noopt-cflags-for + ${VERS_DEF} + ${CPPROCFLAGS} + -DHAVE_BLIS_H + ) +target_include_directories(f2c + BEFORE + PRIVATE + # Add local header paths + ${CMAKE_CURRENT_SOURCE_DIR}/f2c + # and the path to blis.h + ${INC_PATH} + ) +target_link_libraries(f2c PRIVATE ${LDFLAGS}) +if(THREADING_MODEL STREQUAL "openmp") + target_link_libraries(f2c PRIVATE OpenMP::OpenMP_C) +endif() +# Put all those targets under blastest-targets-targets folder name so that they appear all together in IDE. +set_target_properties(f2c PROPERTIES FOLDER blastest-targets) +add_dependencies(f2c flat-header) +# Gather all local source files. +file(GLOB blastest_sources LIST_DIRECTORIES false ${CMAKE_CURRENT_SOURCE_DIR}/src/*.c) +list(TRANSFORM blastest_sources REPLACE ${CMAKE_CURRENT_SOURCE_DIR}/src/ "") -add_subdirectory(f2c) -add_subdirectory(src) +# Create one executable for each of the sources. 
+foreach(source ${blastest_sources}) + string(REPLACE .c "" exec_name ${source}) + add_executable(${exec_name}.x src/${source}) + target_compile_options(${exec_name}.x + PRIVATE + # load-var-for,COPTFLAGS + ${COPTFLAGS} + # get-noopt-cflags-for + ${CDBGFLAGS} + ${CWARNFLAGS} + ${CPICFLAGS} + ${CMISCFLAGS} + ${CLANGFLAGS} + # Suppress warnings about uninitialized functions + -Wno-parentheses -Wno-maybe-uninitialized + ) + target_compile_definitions(${exec_name}.x + PRIVATE + # in get-noopt-cflags-for + ${VERS_DEF} + ${CPPROCFLAGS} + -DHAVE_BLIS_H + ) + target_include_directories(${exec_name}.x + BEFORE + PRIVATE + # Add local header paths + ${CMAKE_CURRENT_SOURCE_DIR}/f2c + # and the path to blis.h + ${INC_PATH} + ) + target_link_libraries(${exec_name}.x PRIVATE f2c libblis ${LDFLAGS}) + if(THREADING_MODEL STREQUAL "openmp") + target_link_libraries(${exec_name}.x PRIVATE OpenMP::OpenMP_C) + endif() + set_target_properties(${exec_name}.x PROPERTIES CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + # Put all those targets under blastest-targets-targets folder name so that they appear all together in IDE. + set_target_properties(${exec_name}.x PROPERTIES FOLDER blastest-targets) + # Add a target for running the tests. Rules are different for level-1 APIs, compared to levels 2 and 3. + if(${exec_name} MATCHES 1) + add_custom_target(run-${exec_name} + COMMAND ${exec_name}.x > out.${exec_name} + COMMENT "Running ${exec_name}.x with output redirected to out.${exec_name}" + DEPENDS ${exec_name}.x + BYPRODUCTS ${CMAKE_BINARY_DIR}/out.${exec_name} + WORKING_DIRECTORY $ + VERBATIM + ) + else()# name has 2 or 3 + add_custom_target(run-${exec_name} + COMMAND ${exec_name}.x < ${CMAKE_CURRENT_SOURCE_DIR}/input/${exec_name}.in + COMMENT "Running ${exec_name}.x with input ${CMAKE_CURRENT_SOURCE_DIR}/input/${exec_name}.in and output saved to out.${exec_name}" + DEPENDS ${exec_name}.x + BYPRODUCTS ${CMAKE_BINARY_DIR}/out.${exec_name} + WORKING_DIRECTORY $ + VERBATIM + ) + endif() + # Put all those targets under blastest-targets-targets folder name so that they appear all together in IDE. + set_target_properties(run-${exec_name} PROPERTIES FOLDER blastest-targets) + list(APPEND test_executables "run-${exec_name}") +endforeach() + +add_custom_target(testblas DEPENDS ${test_executables}) +add_custom_target(checkblas + COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/build/cmake/check-blastest.py "." + DEPENDS testblas + WORKING_DIRECTORY $ + ) +# Put all those targets under blastest-targets-targets folder name so that they appear all together in IDE. +set_target_properties(testblas checkblas PROPERTIES FOLDER blastest-targets) \ No newline at end of file diff --git a/blastest/f2c/CMakeLists.txt b/blastest/f2c/CMakeLists.txt deleted file mode 100644 index 87ec3b6a5b..0000000000 --- a/blastest/f2c/CMakeLists.txt +++ /dev/null @@ -1,59 +0,0 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. 
All rights reserved.## - -target_sources("${F2C_LIB}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/abs.c - ${CMAKE_CURRENT_SOURCE_DIR}/acos.c - ${CMAKE_CURRENT_SOURCE_DIR}/asin.c - ${CMAKE_CURRENT_SOURCE_DIR}/atan.c - ${CMAKE_CURRENT_SOURCE_DIR}/atn2.c - ${CMAKE_CURRENT_SOURCE_DIR}/close.c - ${CMAKE_CURRENT_SOURCE_DIR}/cnjg.c - ${CMAKE_CURRENT_SOURCE_DIR}/cos.c - ${CMAKE_CURRENT_SOURCE_DIR}/cosh.c - ${CMAKE_CURRENT_SOURCE_DIR}/dim.c - ${CMAKE_CURRENT_SOURCE_DIR}/div.c - ${CMAKE_CURRENT_SOURCE_DIR}/dolio.c - ${CMAKE_CURRENT_SOURCE_DIR}/endfile.c - ${CMAKE_CURRENT_SOURCE_DIR}/epsilon.c - ${CMAKE_CURRENT_SOURCE_DIR}/err.c - ${CMAKE_CURRENT_SOURCE_DIR}/exit_.c - ${CMAKE_CURRENT_SOURCE_DIR}/exp.c - ${CMAKE_CURRENT_SOURCE_DIR}/fmt.c - ${CMAKE_CURRENT_SOURCE_DIR}/fmtlib.c - ${CMAKE_CURRENT_SOURCE_DIR}/h_dnnt.c - ${CMAKE_CURRENT_SOURCE_DIR}/hl_cmp.c - ${CMAKE_CURRENT_SOURCE_DIR}/i_dnnt.c - ${CMAKE_CURRENT_SOURCE_DIR}/i_len.c - ${CMAKE_CURRENT_SOURCE_DIR}/imag.c - ${CMAKE_CURRENT_SOURCE_DIR}/int.c - ${CMAKE_CURRENT_SOURCE_DIR}/l_cmp.c - ${CMAKE_CURRENT_SOURCE_DIR}/lg10.c - ${CMAKE_CURRENT_SOURCE_DIR}/log.c - ${CMAKE_CURRENT_SOURCE_DIR}/lread.c - ${CMAKE_CURRENT_SOURCE_DIR}/lwrite.c - ${CMAKE_CURRENT_SOURCE_DIR}/mod.c - ${CMAKE_CURRENT_SOURCE_DIR}/nint.c - ${CMAKE_CURRENT_SOURCE_DIR}/open.c - ${CMAKE_CURRENT_SOURCE_DIR}/pow.c - ${CMAKE_CURRENT_SOURCE_DIR}/prod.c - ${CMAKE_CURRENT_SOURCE_DIR}/rdfmt.c - ${CMAKE_CURRENT_SOURCE_DIR}/rewind.c - ${CMAKE_CURRENT_SOURCE_DIR}/rsfe.c - ${CMAKE_CURRENT_SOURCE_DIR}/s_cmp.c - ${CMAKE_CURRENT_SOURCE_DIR}/s_copy.c - ${CMAKE_CURRENT_SOURCE_DIR}/s_stop.c - ${CMAKE_CURRENT_SOURCE_DIR}/sfe.c - ${CMAKE_CURRENT_SOURCE_DIR}/sig_die.c - ${CMAKE_CURRENT_SOURCE_DIR}/sign.c - ${CMAKE_CURRENT_SOURCE_DIR}/sin.c - ${CMAKE_CURRENT_SOURCE_DIR}/sinh.c - ${CMAKE_CURRENT_SOURCE_DIR}/sqrt.c - ${CMAKE_CURRENT_SOURCE_DIR}/tan.c - ${CMAKE_CURRENT_SOURCE_DIR}/tanh.c - ${CMAKE_CURRENT_SOURCE_DIR}/util.c - ${CMAKE_CURRENT_SOURCE_DIR}/wref.c - ${CMAKE_CURRENT_SOURCE_DIR}/wrtfmt.c - ${CMAKE_CURRENT_SOURCE_DIR}/wsfe.c - ${CMAKE_CURRENT_SOURCE_DIR}/wsle.c - ) diff --git a/blastest/f2c/open.c b/blastest/f2c/open.c index 2834fd9463..12e5f02b21 100644 --- a/blastest/f2c/open.c +++ b/blastest/f2c/open.c @@ -28,6 +28,7 @@ use or performance of this software. #include #endif #ifdef _MSC_VER +#include #define access _access #endif #include "f2c.h" diff --git a/blastest/src/CMakeLists.txt b/blastest/src/CMakeLists.txt deleted file mode 100644 index 69274a5547..0000000000 --- a/blastest/src/CMakeLists.txt +++ /dev/null @@ -1,37 +0,0 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. 
All rights reserved.## - -add_executable(cblat1 cblat1.c) -target_link_libraries(cblat1 PRIVATE "${F2C_LIB}" "${LIB_NAME}.lib" ) - -add_executable(cblat2 cblat2.c) -target_link_libraries(cblat2 PRIVATE "${F2C_LIB}" "${LIB_NAME}.lib" ) - -add_executable(cblat3 cblat3.c) -target_link_libraries(cblat3 PRIVATE "${F2C_LIB}" "${LIB_NAME}.lib" ) - -add_executable(dblat1 dblat1.c) -target_link_libraries(dblat1 PRIVATE "${F2C_LIB}" "${LIB_NAME}.lib" ) - -add_executable(dblat2 dblat2.c) -target_link_libraries(dblat2 PRIVATE "${F2C_LIB}" "${LIB_NAME}.lib" ) - -add_executable(dblat3 dblat3.c) -target_link_libraries(dblat3 PRIVATE "${F2C_LIB}" "${LIB_NAME}.lib" ) - -add_executable(sblat1 sblat1.c) -target_link_libraries(sblat1 PRIVATE "${F2C_LIB}" "${LIB_NAME}.lib" ) - -add_executable(sblat2 sblat2.c) -target_link_libraries(sblat2 PRIVATE "${F2C_LIB}" "${LIB_NAME}.lib" ) - -add_executable(sblat3 sblat3.c) -target_link_libraries(sblat3 PRIVATE "${F2C_LIB}" "${LIB_NAME}.lib" ) - -add_executable(zblat1 zblat1.c) -target_link_libraries(zblat1 PRIVATE "${F2C_LIB}" "${LIB_NAME}.lib" ) - -add_executable(zblat2 zblat2.c) -target_link_libraries(zblat2 PRIVATE "${F2C_LIB}" "${LIB_NAME}.lib" ) - -add_executable(zblat3 zblat3.c) -target_link_libraries(zblat3 PRIVATE "${F2C_LIB}" "${LIB_NAME}.lib" ) diff --git a/build/auto_config.py b/build/auto_config.py index 1ce3989e4e..8b39944899 100644 --- a/build/auto_config.py +++ b/build/auto_config.py @@ -1,4 +1,4 @@ -"""Copyright (C) 2020, Advanced Micro Devices, Inc. All Rights Reserved""" +"""Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved.""" import subprocess import sys diff --git a/build/bli_config.h.in b/build/bli_config.h.in index ba0c16100b..1e10616246 100644 --- a/build/bli_config.h.in +++ b/build/bli_config.h.in @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/build/bli_win_config.h.in b/build/bli_win_config.h.in deleted file mode 100644 index 4645b5cf95..0000000000 --- a/build/bli_win_config.h.in +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved. 
- */ - -#ifndef BLIS_CONFIG_H -#define BLIS_CONFIG_H - -#cmakedefine AOCL_DYNAMIC - -#cmakedefine AOCL_BLIS_ZEN - -#cmakedefine BLIS_ENABLE_OPENMP - -#cmakedefine BLIS_ENABLE_JRIR_SLAB - -#cmakedefine BLIS_ENABLE_JRIR_RR - -#cmakedefine BLIS_ENABLE_PBA_POOLS - -#cmakedefine BLIS_ENABLE_SBA_POOLS - -#cmakedefine BLIS_ENABLE_MEM_TRACING - -#cmakedefine BLIS_INT_TYPE_SIZE @INT_TYPE_SIZE@ - -#cmakedefine BLIS_BLAS_INT_TYPE_SIZE @BLAS_INT_TYPE_SIZE@ - -#cmakedefine BLIS_ENABLE_BLAS - -#cmakedefine BLIS_ENABLE_CBLAS - -#cmakedefine BLIS_ENABLE_MIXED_DT - -#cmakedefine BLIS_ENABLE_MIXED_DT_EXTRA_MEM - -#cmakedefine BLIS_ENABLE_SUP_HANDLING - -#cmakedefine BLIS_ENABLE_MEMKIND - -#cmakedefine BLIS_ENABLE_TRSM_PREINVERSION - -#cmakedefine BLIS_ENABLE_PRAGMA_OMP_SIMD - -#cmakedefine BLIS_ENABLE_SANDBOX - -#cmakedefine BLIS_ENABLE_SHARED - -#cmakedefine BLIS_ENABLE_COMPLEX_RETURN_INTEL - -#cmakedefine DISABLE_BLIS_ARCH_TYPE - -#cmakedefine DISABLE_BLIS_MODEL_TYPE - -#cmakedefine __blis_arch_type_name "@rename_blis_arch_type@" - -#cmakedefine __blis_model_type_name "@rename_blis_model_type@" - -#endif diff --git a/build/blis_ref_kernel_mirror.py b/build/blis_ref_kernel_mirror.py index f49d101ae7..2f28a4c088 100644 --- a/build/blis_ref_kernel_mirror.py +++ b/build/blis_ref_kernel_mirror.py @@ -1,4 +1,4 @@ -"""Copyright (C) 2021-2023, Advanced Micro Devices, Inc. All Rights Reserved""" +"""Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved.""" ################################################################################ # This file is used to mirroring the refkernels folder data into to zen, zen2, # diff --git a/build/cmake/bli_addon.h.in b/build/cmake/bli_addon.h.in new file mode 100644 index 0000000000..b002b43619 --- /dev/null +++ b/build/cmake/bli_addon.h.in @@ -0,0 +1,17 @@ +/* + * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + */ + +#ifndef BLIS_ADDON_H +#define BLIS_ADDON_H + +#if ${ENABLE_ADDONS_01} +#define BLIS_ENABLE_ADDONS +#else +#define BLIS_DISABLE_ADDONS +#endif + +// Enabled addons +${ADDON_LIST_INCLUDES} + +#endif diff --git a/build/cmake/bli_config.h.in b/build/cmake/bli_config.h.in new file mode 100644 index 0000000000..aed543b868 --- /dev/null +++ b/build/cmake/bli_config.h.in @@ -0,0 +1,183 @@ +/* + * Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + */ + +#ifndef BLIS_CONFIG_H +#define BLIS_CONFIG_H + +// Enabled configuration "family" (config_name) +${CONFIG_NAME_DEFINE} + +// Enabled sub-configurations (config_list) +${CONFIG_LIST_DEFINES} + +// Enabled kernel sets (kernel_list) +${KERNEL_LIST_DEFINES} + +//This macro is enabled only for ZEN family configurations. +//This enables us to use different cache-blocking sizes for TRSM instead of common level-3 cache-block sizes. 
+#if ${ENABLE_AOCL_ZEN_01} +#define AOCL_BLIS_ZEN +#endif + +#if ${ENABLE_AOCL_DYNAMIC_01} +#define AOCL_DYNAMIC +#endif + +#if ${ENABLE_SYSTEM_01} +#define BLIS_ENABLE_SYSTEM +#else +#define BLIS_DISABLE_SYSTEM +#endif + +#if ${ENABLE_OPENMP_01} +#define BLIS_ENABLE_OPENMP +#endif + +#if ${ENABLE_PTHREADS_01} +#define BLIS_ENABLE_PTHREADS +#endif + +#if ${ENABLE_JRIR_SLAB_01} +#define BLIS_ENABLE_JRIR_SLAB +#endif + +#if ${ENABLE_JRIR_RR_01} +#define BLIS_ENABLE_JRIR_RR +#endif + +#if ${ENABLE_PBA_POOLS_01} +#define BLIS_ENABLE_PBA_POOLS +#else +#define BLIS_DISABLE_PBA_POOLS +#endif + +#if ${ENABLE_SBA_POOLS_01} +#define BLIS_ENABLE_SBA_POOLS +#else +#define BLIS_DISABLE_SBA_POOLS +#endif + +#if ${ENABLE_MEM_TRACING_01} +#define BLIS_ENABLE_MEM_TRACING +#else +#define BLIS_DISABLE_MEM_TRACING +#endif + +#if ${INT_TYPE_SIZE} == 64 +#define BLIS_INT_TYPE_SIZE 64 +#elif ${INT_TYPE_SIZE} == 32 +#define BLIS_INT_TYPE_SIZE 32 +#else +// determine automatically +#endif + +#if ${BLAS_INT_TYPE_SIZE} == 64 +#define BLIS_BLAS_INT_TYPE_SIZE 64 +#elif ${BLAS_INT_TYPE_SIZE} == 32 +#define BLIS_BLAS_INT_TYPE_SIZE 32 +#else +// determine automatically +#endif + +#ifndef BLIS_ENABLE_BLAS +#ifndef BLIS_DISABLE_BLAS +#if ${ENABLE_BLAS_01} +#define BLIS_ENABLE_BLAS +#else +#define BLIS_DISABLE_BLAS +#endif +#endif +#endif + +#ifndef BLIS_ENABLE_CBLAS +#ifndef BLIS_DISABLE_CBLAS +#if ${ENABLE_CBLAS_01} +#define BLIS_ENABLE_CBLAS +#else +#define BLIS_DISABLE_CBLAS +#endif +#endif +#endif + +// If the CBLAS compatibility layer was enabled while the BLAS layer +// was not enabled, we must enable the BLAS layer here. Also undefine +// BLIS_DISABLE_BLAS to ensure consistency. +#ifdef BLIS_ENABLE_CBLAS +#ifndef BLIS_ENABLE_BLAS +#define BLIS_ENABLE_BLAS +#endif +#undef BLIS_DISABLE_BLAS +#endif // BLIS_ENABLE_CBLAS + +#ifndef BLIS_ENABLE_MIXED_DT +#ifndef BLIS_DISABLE_MIXED_DT +#if ${ENABLE_MIXED_DT_01} +#define BLIS_ENABLE_MIXED_DT +#else +#define BLIS_DISABLE_MIXED_DT +#endif +#endif +#endif + +#ifndef BLIS_ENABLE_MIXED_DT_EXTRA_MEM +#ifndef BLIS_DISABLE_MIXED_DT_EXTRA_MEM +#if ${ENABLE_MIXED_DT_EXTRA_MEM_01} +#define BLIS_ENABLE_MIXED_DT_EXTRA_MEM +#else +#define BLIS_DISABLE_MIXED_DT_EXTRA_MEM +#endif +#endif +#endif + +#if ${ENABLE_SUP_HANDLING_01} +#define BLIS_ENABLE_SUP_HANDLING +#else +#define BLIS_DISABLE_SUP_HANDLING +#endif + +#if ${ENABLE_MEMKIND_01} +#define BLIS_ENABLE_MEMKIND +#else +#define BLIS_DISABLE_MEMKIND +#endif + +#if ${ENABLE_TRSM_PREINVERSION_01} +#define BLIS_ENABLE_TRSM_PREINVERSION +#else +#define BLIS_DISABLE_TRSM_PREINVERSION +#endif + +#if ${ENABLE_PRAGMA_OMP_SIMD_01} +#define BLIS_ENABLE_PRAGMA_OMP_SIMD +#else +#define BLIS_DISABLE_PRAGMA_OMP_SIMD +#endif + +#if ${ENABLE_SANDBOX_01} +#define BLIS_ENABLE_SANDBOX +#else +#define BLIS_DISABLE_SANDBOX +#endif + +#if ${ENABLE_SHARED_01} +#define BLIS_ENABLE_SHARED +#else +#define BLIS_DISABLE_SHARED +#endif + +#if ${COMPLEX_RETURN_INTEL_01} +#define BLIS_ENABLE_COMPLEX_RETURN_INTEL +#else +#define BLIS_DISABLE_COMPLEX_RETURN_INTEL +#endif + +#if ${DISABLE_BLIS_ARCH_TYPE_01} +#define DISABLE_BLIS_ARCH_TYPE +#define DISABLE_BLIS_MODEL_TYPE +#endif + +#define __blis_arch_type_name "${RENAME_BLIS_ARCH_TYPE}" +#define __blis_model_type_name "${RENAME_BLIS_MODEL_TYPE}" + +#endif diff --git a/build/cmake/check-blastest.py b/build/cmake/check-blastest.py new file mode 100644 index 0000000000..8e1123cf80 --- /dev/null +++ b/build/cmake/check-blastest.py @@ -0,0 +1,31 @@ +##Copyright (C) 2023, Advanced Micro Devices, Inc. 
All rights reserved.## + +# Import modules +import os +import sys + +def check_blastest(): + results_file_path = sys.argv[1] + results_directory = os.listdir(results_file_path) + has_failure = False + is_empty = False + for fname in results_directory: + if os.path.isfile(results_file_path + os.sep + fname) and "out" in fname: + file = open(results_file_path + os.sep + fname, 'r') + # read all content of a file + content = file.read() + if content == "": + is_empty = True + # check if string present in a file + if "*****" in content: + has_failure = True + if has_failure: + print("\033[0;31m At least one BLAS test failed. :( \033[0m") + print("\033[0;31m Please see the corresponding out.* for details. \033[0m") + elif is_empty: + print("\033[0;31m At least one BLAS test resulted without a PASS. :( \033[0m") + print("\033[0;31m Please ensure that the corresponding out.* was generated correctly. \033[0m") + else: + print("\033[0;32m All BLAS tests passed! \033[0m") + +check_blastest() diff --git a/build/cmake/check-blistest.py b/build/cmake/check-blistest.py new file mode 100644 index 0000000000..983f8e8241 --- /dev/null +++ b/build/cmake/check-blistest.py @@ -0,0 +1,22 @@ +##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.## + +# Import modules +import os +import sys + +def check_blistest(): + results_file = sys.argv[1] + with open(results_file, 'r') as file: + # read all content of a file + content = file.read() + # check if string present in a file + if "FAILURE" in content: + print("\033[0;31m At least one BLIS test failed. :( \033[0m") + print("\033[0;31m Please see the corresponding output.testsuite* for details. \033[0m") + elif not "PASS" in content: + print("\033[0;31m No BLIS test resulted in PASS. :( \033[0m") + print("\033[0;31m Please ensure that the corresponding output.testsuite* was generated correctly. \033[0m") + else: + print("\033[0;32m All BLIS tests passed! \033[0m") + +check_blistest() diff --git a/build/cmake/config_print.py b/build/cmake/config_print.py new file mode 100644 index 0000000000..f5fc767711 --- /dev/null +++ b/build/cmake/config_print.py @@ -0,0 +1,306 @@ +##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.## + +# Import modules +import os +import sys + +def main(): + # Obtain the script name. + path, script_name = os.path.split(sys.argv[0]) + print( " " ) + print( " %s" % script_name ) + print( " " ) + print( " Configure BLIS's CMake system for compilation using a specified" ) + print( " configuration directory." ) + print( " " ) + print( " Usage:" ) + print( " " ) + print( " cmake .. [Options] -DBLIS_CONFIG_FAMILY=confname" ) + print( " " ) + print(" Arguments:") + print(" ") + print(" confname The name of the sub-directory inside of the 'config'") + print(" directory containing the desired BLIS configuration.") + print(" Currently, only amdzen, zen, zen2, zen3, zen4 and generic") + print(" configuration options are supported.") + print(" Note that confname MUST be specified; if it is not,") + print(" configure will complain. To build a completely generic") + print(" implementation, use the 'generic' configuration.") + print(" ") + print( " Options:" ) + print( " " ) + print( " -DCMAKE_INSTALL_PREFIX=PREFIX" ) + print( " " ) + print( " The common installation prefix for all files." ) + print( " If this option is not given, PREFIX defaults to '/usr/local/'." ) + print( " on UNIX and c:/Program Files/${PROJECT_NAME} on Windows." 
) + print( " " ) + print( " -DENABLE_DEBUG=DEBUG" ) + print( " " ) + print( " Enable debugging symbols in the library." ) + print( " DEBUG is 'off' by default. If argument" ) + print( " DEBUG is given as 'opt', then optimization flags are" ) + print( " kept in the framework, otherwise optimization is" ) + print( " turned off. Available options are 'opt', 'noopt' and 'off'." ) + print( " " ) + print( " --disable-static, --enable-static" ) + print( " " ) + print( " Disable (enabled by default) building BLIS as a static" ) + print( " library. If the static library build is disabled, the" ) + print( " shared library build must remain enabled." ) + print( " " ) + print( " --disable-shared, --enable-shared" ) + print( " " ) + print( " Disable (enabled by default) building BLIS as a shared" ) + print( " library. If the shared library build is disabled, the" ) + print( " static library build must remain enabled." ) + print( " " ) + print( " -DEXPORT_SHARED=[SYMBOLS]" ) + print( " " ) + print( " Specify the subset of library symbols that are exported" ) + print( " within a shared library. Valid values for SYMBOLS are:" ) + print( " 'public' (the default) and 'all'. By default, only" ) + print( " functions and variables that belong to public APIs are" ) + print( " exported in shared libraries. However, the user may" ) + print( " instead export all symbols in BLIS, even those that were" ) + print( " intended for internal use only. Note that the public APIs" ) + print( " encompass all functions that almost any user would ever" ) + print( " want to call, including the BLAS/CBLAS compatibility APIs" ) + print( " as well as the basic and expert interfaces to the typed" ) + print( " and object APIs that are unique to BLIS. Also note that" ) + print( " changing this option to 'all' will have no effect in some" ) + print( " environments, such as when compiling with clang on" ) + print( " Windows." ) + print( " " ) + print( " -DENABLE_THREADING=MODEL" ) + print( " " ) + print( " Enable threading in the library, using threading model" ) + print( " MODEL={openmp, pthreads, no}. If MODEL=no threading will be" ) + print( " disabled. The default is 'no'." ) + print( " " ) + print( " -DENABLE_SYSTEM=ON or -DENABLE_SYSTEM=OFF") + print( " " ) + print( " Enable conventional operating system support, such as" ) + print( " pthreads for thread-safety. The default state is enabled." ) + print( " However, in rare circumstances you may wish to configure" ) + print( " BLIS for use with a minimal or nonexistent operating" ) + print( " system (e.g. hardware simulators). In these situations," ) + print( " -DENABLE_SYSTEM=OFF may be used to jettison all compile-time" ) + print( " and link-time dependencies outside of the standard C" ) + print( " library. When disabled, this option also forces the use" ) + print( " of -DENABLE_THREADING=no." ) + print( " " ) + print( " -DENABLE_PBA_POOLS=ON or -DENABLE_PBA_POOLS=OFF" ) + print( " -DENABLE_SBA_POOLS=ON or -DENABLE_SBA_POOLS=OFF" ) + print( " " ) + print( " Disable (enabled by default) use of internal memory pools" ) + print( " within the packing block allocator (pba) and/or the small" ) + print( " block allocator (sba). The former is used to allocate" ) + print( " memory used to pack submatrices while the latter is used" ) + print( " to allocate control/thread tree nodes and thread" ) + print( " communicators. Both allocations take place in the context" ) + print( " of level-3 operations. 
When the pba is disabled, the" ) + print( " malloc()-like function specified by BLIS_MALLOC_POOL is" ) + print( " called on-demand whenever a packing block is needed, and" ) + print( " when the sba is disabled, the malloc()-like function" ) + print( " specified by BLIS_MALLOC_INTL is called whenever a small" ) + print( " block is needed, with the two allocators calling free()-" ) + print( " like functions BLIS_FREE_POOL and BLIS_FREE_INTL," ) + print( " respectively when blocks are released. When enabled," ) + print( " either or both pools are populated via the same functions" ) + print( " mentioned previously, and henceforth blocks are checked" ) + print( " out and in. The library quickly reaches a state in which" ) + print( " it no longer needs to call malloc() or free(), even" ) + print( " across many separate level-3 operation invocations." ) + print( " " ) + print( " -DENABLE_MEM_TRACING=ON or -DENABLE_MEM_TRACING=OFF" ) + print( " " ) + print( " Enable (disable by default) output to stdout that traces" ) + print( " the allocation and freeing of memory, including the names" ) + print( " of the functions that triggered the allocation/freeing." ) + print( " Enabling this option WILL NEGATIVELY IMPACT PERFORMANCE." ) + print( " Please use only for informational/debugging purposes." ) + print( " " ) + print( " -DINT_SIZE=SIZE" ) + print( " " ) + print( " Set the size (in bits) of internal BLIS integers and" ) + print( " integer types used in native BLIS interfaces. The" ) + print( " default integer type size is architecture dependent." ) + print( " (Hint: You can always find this value printed at the" ) + print( " beginning of the testsuite output.)" ) + print( " " ) + print( " -DBLAS_TYPE_SIZE=SIZE" ) + print( " " ) + print( " Set the size (in bits) of integer types in external" ) + print( " BLAS and CBLAS interfaces, if enabled. The default" ) + print( " integer type size used in BLAS/CBLAS is 32 bits." ) + print( " " ) + print( " -DENABLE_BLAS=ON or -DENABLE_BLAS=OFF" ) + print( " " ) + print( " Disable (enabled by default) building the BLAS" ) + print( " compatibility layer." ) + print( " " ) + print( " -DENABLE_CBLAS=ON or -DENABLE_CBLAS=OFF" ) + print( " " ) + print( " Enable (disabled by default) building the CBLAS" ) + print( " compatibility layer. This automatically enables the" ) + print( " BLAS compatibility layer as well." ) + print( " " ) + print( " -DENABLE_MIXED_DT=ON or -DENABLE_MIXED_DT=OFF" ) + print( " " ) + print( " Disable (enabled by default) support for mixing the" ) + print( " storage domain and/or storage precision of matrix" ) + print( " operands for the gemm operation, as well as support" ) + print( " for computing in a precision different from one or" ) + print( " both of matrices A and B." ) + print( " " ) + print( " -DENABLE_MIXED_DT_EXTRA_MEM=ON or -DENABLE_MIXED_DT_EXTRA_MEM=OFF" ) + print( " " ) + print( " Disable (enabled by default) support for additional" ) + print( " mixed datatype optimizations that require temporarily" ) + print( " allocating extra memory--specifically, a single m x n" ) + print( " matrix (per application thread) whose storage datatype" ) + print( " is equal to the computation datatype. This option may" ) + print( " only be enabled when mixed domain/precision support is" ) + print( " enabled." ) + print( " " ) + print( " -DENABLE_SUP_HANDLING=ON or -DENABLE_SUP_HANDLING=OFF" ) + print( " " ) + print( " Disable (enabled by default) handling of small/skinny" ) + print( " matrix problems via separate code branches. 
When disabled," ) + print( " these small/skinny level-3 operations will be performed by" ) + print( " the conventional implementation, which is optimized for" ) + print( " medium and large problems. Note that what qualifies as" ) + print( " \"small\" depends on thresholds that may vary by sub-" ) + print( " configuration." ) + print( " " ) + print( " -DENABLE_ADDON=\"NAME1[;NAME2;...]\" (Linux only)") + print( " " ) + print( " Enable the code provided by an addon. An addon consists" ) + print( " of a separate directory of code that provides additional" ) + print( " APIs, implementations, and/or operations that would" ) + print( " otherwise not be present within a build of BLIS." ) + print( " To enable a single addon named NAME1, set -DENABLE_ADDON=NAME1." ) + print( " To enable multiple addons, a ';'-separated list enclosed in \"\"") + print( " needs to be provided. For example, -DENABLE_ADDON=\"NAME1;NAME2\".") + print(" By default, no addons are enabled.") + print( " " ) + # Sandbox functionality is currently disabled in CMake. + #print( " -DENABLE_SANDBOX=NAME" ) + #print( " " ) + #print( " Enable a separate sandbox implementation of gemm. This" ) + #print( " option disables BLIS's conventional gemm implementation" ) + #print( " (which shares common infrastructure with other level-3" ) + #print( " operations) and instead compiles and uses the code in" ) + #print( " the NAME directory, which is expected to be a sub-" ) + #print( " directory of 'sandbox'. By default, no sandboxes are" ) + #print( " enabled." ) + #print( " " ) + print( " -DENABLE_MEMKIND=ON or -DENABLE_MEMKIND=OFF" ) + print( " " ) + print( " Forcibly enable or disable the use of libmemkind's" ) + print( " hbw_malloc() and hbw_free() as substitutes for malloc()" ) + print( " and free(), respectively, when allocating memory for" ) + print( " BLIS's memory pools, which are used to manage buffers" ) + print( " into which matrices are packed. The default behavior" ) + print( " for this option is environment-dependent; if configure" ) + print( " detects the presence of libmemkind, libmemkind is used" ) + print( " by default, and otherwise it is not used by default." ) + print( " " ) + print( " -DTHREAD_PART_JRIR=METHOD" ) + print( " " ) + print( " Request a method of assigning micropanels to threads in" ) + print( " the JR and IR loops. Valid values for METHOD are 'slab'" ) + print( " and 'rr'. Using 'slab' assigns (as much as possible)" ) + print( " contiguous regions of micropanels to each thread while" ) + print( " using 'rr' assigns micropanels to threads in a round-" ) + print( " robin fashion. The chosen method also applies during" ) + print( " the packing of A and B. The default method is 'slab'." ) + print( " NOTE: Specifying this option constitutes a request," ) + print( " which may be ignored in select situations if the" ) + print( " implementation has a good reason to do so." ) + print( " " ) + print( " -DENABLE_TRSM_PREINVERSION=ON or -DENABLE_TRSM_PREINVERSION=OFF" ) + print( " " ) + print( " Disable (enabled by default) pre-inversion of triangular" ) + print( " matrix diagonals when performing trsm. When pre-inversion" ) + print( " is enabled, diagonal elements are inverted outside of the" ) + print( " microkernel (e.g. during packing) so that the microkernel" ) + print( " can use multiply instructions. When disabled, division" ) + print( " instructions are used within the microkernel. 
Executing" ) + print( " these division instructions within the microkernel will" ) + print( " incur a performance penalty, but numerical robustness will" ) + print( " improve for certain cases involving denormal numbers that" ) + print( " would otherwise result in overflow in the pre-inverted" ) + print( " values." ) + print( " " ) + print( " -DFORCE_VERSION_STRING=STRING" ) + print( " " ) + print( " Force configure to use an arbitrary version string" ) + print( " STRING. This option may be useful when repackaging" ) + print( " custom versions of BLIS by outside organizations." ) + print( " " ) + print( " -DCOMPLEX_RETURN=gnu or -DCOMPLEX_RETURN=intel or -DCOMPLEX_RETURN=default" ) + print( " " ) + print( " Specify the way in which complex numbers are returned" ) + print( " from Fortran functions, either \"gnu\" (return in" ) + print( " registers) or \"intel\" (return via hidden argument)." ) + print( " By default COMPLEX_RETURNis set to 'default' and we" ) + print( " attempt to determine the return type from the compiler." ) + print( " Otherwise, the default is \"gnu\"." ) + print( " " ) + print( " -DENABLE_AOCL_DYNAMIC=ON or -DENABLE_AOCL_DYNAMIC=OFF" ) + print( " " ) + print( " Disable (Enabled by default) dynamic selection of number of" ) + print( " threads used to solve the given problem." ) + print( " Range of optimum number of threads will be [1, num_threads]," ) + print( " where \"num_threads\" is number of threads set by the application." ) + print( " Num_threads is derived from either environment variable" ) + print( " OMP_NUM_THREADS or BLIS_NUM_THREADS' or bli_set_num_threads() API." ) + print( " " ) + print( " -DDISABLE_BLIS_ARCH_TYPE=ON or -DDISABLE_BLIS_ARCH_TYPE=OFF" ) + print( " " ) + print( " Disable support for AOCL_ENABLE_INSTRUCTIONS, BLIS_ARCH_TYPE and" ) + print( " BLIS_MODEL_TYPE environment variables, which allows user to select" ) + print( " architecture specific code path and optimizations at runtime." ) + print( " If disabled, in builds with multiple code paths, BLIS" ) + print( " will still select path and optimizations automatically." ) + print( " Default: Enabled in builds with multiple code paths, else disabled." ) + print( " " ) + print( " -DRENAME_BLIS_ARCH_TYPE=STRING" ) + print( " " ) + print( " Change environment variable used to select architecture specific" ) + print( " code path from BLIS_ARCH_TYPE to STRING" ) + print( " " ) + print( " -DRENAME_BLIS_MODEL_TYPE=STRING" ) + print( " " ) + print( " Change environment variable used to select architecture model specific" ) + print( " optimizations from BLIS_MODEL_TYPE to STRING" ) + print( " " ) + print( " -DENABLE_NO_UNDERSCORE_API=OFF" ) + print( " " ) + print( " Export APIs without underscore" ) + print( " " ) + print( " -DENABLE_UPPERCASE_API=OFF" ) + print( " " ) + print( " Export APIs with uppercase" ) + print( " " ) + print( " " ) + print( " Additional CMake Variables:" ) + print( " " ) + print( " CMAKE_C_COMPILER Specifies the C compiler to use." ) + print( " CMAKE_CXX_COMPILER Specifies the C++ compiler to use (sandbox only)." ) + print( " CMAKE_Fortran_COMPILER Specifies the Fortran compiler to use (only to determine --complex-return)." ) + print( " COMPILE_OPTIONS Specifies additional compiler flags to use." ) + print( " COMPILE_DEFINITIONS Specifies additional preprocessor definitions to use." ) + print( " LINK_OPTIONS Specifies additional linker flags to use." ) + print( " " ) + print( " Note that not all compilers are compatible with a given" ) + print( " configuration." 
) + + # Return from main(). + return 0 + + +if __name__ == "__main__": + main() diff --git a/build/cmake/read_registry.py b/build/cmake/read_registry.py new file mode 100644 index 0000000000..f8baf66378 --- /dev/null +++ b/build/cmake/read_registry.py @@ -0,0 +1,409 @@ +##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.## + +# Import modules +import os +import sys +import re + +def canonicalize_ws(str): + # Remove leading and trailing whitespace. + str = str.strip() + # Remove duplicate spaces between words. + res = " ".join(str.split()) + # Update the input argument. + return res + + +def is_singleton(str): + rval = False + count_str = " " + for item in str.split(): + count_str = count_str + "x" + if count_str == "x": + rval = True + return rval + + +def is_singleton_family(familyname, memberlist): + rval = False + if is_singleton(memberlist): + if memberlist == familyname: + rval = True + return rval + + +def is_in_list(word, str): + rval = False + for item in str.split(): + if item == word: + rval = True + break + return rval + + +def assign_key_value(array, key, value): + array.update({key: value}) + + +def query_array(array, key): + value = array.get(key) + return value + + +def remove_from_list(strike_words, list): + flist = "" + for item in list.split(): + # Filter out any list item that matches any of the strike words. + if not is_in_list(item, strike_words): + flist = " ".join([flist, item]) + flist = canonicalize_ws(flist) + # Return the filtered list. + return flist + +def replace_curconfig_configset(klisttmp, curconfig, configset): + tmplist = list(klisttmp.split(" ")) + ind = tmplist.index(curconfig) + tmplist.remove(curconfig) + tmplist.insert(ind, configset) + newlist = " ".join(map(str, tmplist)) + return newlist + +def rm_duplicate_words(str): + res = " ".join(str.split()[::-1]) + res = " ".join(dict.fromkeys(res.split())) + str = " ".join(res.split()[::-1]) + return str + +def pass_config_kernel_registries(filename, passnum): + global config_blist + global indirect_blist + global config_registry + global kernel_registry + # first argument: the file containing the configuration registry. + # second argument: the pass number: 0 or 1. Pass 0 builds the + # indirect config blacklist (indirect_blist) ONLY. Pass 1 actually + # begins populating the config and kernel registries, and assumes + # the indirect_blist has already been created. + # Initialize a list of indirect blacklisted configurations for the + # current iteration. These are configurations that are invalidated by + # the removal of blacklisted configurations. For example, if haswell + # is registered as needing the 'haswell' and 'zen' kernel sets: + # haswell: haswell/haswell/zen + # and 'zen' was blacklisted because of the compiler version, then the + # 'haswell' configuration must be omitted from the registry, as it no + # longer has all of the kernel sets it was expecting. + if passnum == 0: + indirect_blist = "" + # For convenience, merge the original and indirect blacklists. + # NOTE: During pass 0, all_blist is equal to config_blist, since + # indirect_blist is still empty. + all_blist = config_blist + indirect_blist + # Disable support for indirect blacklisting by returning early during + # pass 0. See issue #214 for details [1]. Basically, I realized that + # indirect blacklisting is not needed in the use case that I envisioned + # in the real-life example above. 
If a subconfiguration such as haswell + # is defined to require the zen kernel set, it implies that the zen + # kernels can be compiled with haswell compiler flags. That is, just + # because the zen subconfig (and its compiler flags) is blacklisted + # does not mean that the haswell subconfig cannot compile the zen + # kernels with haswell-specific flags. + # [1] https://github.com/flame/blis/issues/214 + if passnum == 0: + return + + cfg = open(filename, "r+") + while True: + line = cfg.readline() + if not line: + break + + # We've stripped out leading whitespace and trailing comments. If + # the line is now empty, then we can skip it altogether. + if re.match(r'\n', line) or re.match(r'#', line): + continue + + # Read the config name and config list for the current line. + cname, list = line.split(':') + cname = cname.strip() + list = list.strip() + # If we encounter a slash, it means the name of the configuration + # and the kernel set needed by that configuration are different. + if list.find("/") != -1: + clist = "" + klist = "" + # The sub-configuration name is always the first sub-word in + # the slash-separated compound word. + # Delete the sub-configuration name from the front of the + # string, leaving the slash-separated kernel names (or just + # the kernel name, if there is only one). + # Replace the slashes with spaces to transform the string + # into a space-separated list of kernel names. + list = list.replace("/", " ") + config, kernels = list.split(" ", 1) + + clist = clist + config + klist = klist + kernels + else: + clist = list + klist = list + + # Strip out whitespace from the config name and config/kernel list + # on each line. + cname = canonicalize_ws(cname) + clist = canonicalize_ws(clist) + klist = canonicalize_ws(klist) + # Next, we prepare to: + # - pass 0: inspect klist for blacklisted configurations, which may + # reveal configurations as needing to be indirectly blacklisted. + # - pass 1: compare cname to the blacklists and commit clist/klist + # to their respective registries, as appropriate. + # Handle singleton and umbrella configuration entries separately. + if is_singleton_family(cname, clist): + # Singleton configurations/families. + # Note: for singleton families, clist contains one item, which + # always equals cname, but klist could contain more than one + # item. + # Only consider updating the indirect blacklist (pass 0) or + # committing clist and klist to the registries (pass 1) if the + # configuration name (cname) is not blacklisted. + if not is_in_list(cname, all_blist): + if passnum == 0: + # Even if the cname isn't blacklisted, one of the requisite + # kernels might be, so we need to check klist for blacklisted + # items. If we find one, we must assume that the entire entry + # must be thrown out. (Ideally, we would simply fall back to + # reference code for the blacklisted kernels, but that is not + # at all straightforward under the current configuration + # system architecture.) Thus, we add cname to the indirect + # blacklist. + for item in klist.split(): + if is_in_list(item, config_blist): + indirect_blist = indirect_blist + cname + break + if passnum == 1: + # Store the clist to the cname key of the config registry. + # config_registry[${cname}]=${clist} + assign_key_value(config_registry, cname, clist) + if passnum == 1: + # Store the klist to the cname key of the kernel registry. + # kernel_registry[${cname}]=${klist} + assign_key_value(kernel_registry, cname, klist) + else: + # Umbrella configurations/families. 
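+ # Illustrative (hypothetical) registry lines and the values they produce:
+ # a singleton entry such as "haswell: haswell/haswell/zen" yields
+ # clist = "haswell" and klist = "haswell zen", whereas an umbrella entry
+ # such as "amdzen: zen zen2 zen3 zen4 generic" yields
+ # clist = klist = "zen zen2 zen3 zen4 generic".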
+ # First we check cname, which should generally not be blacklisted + # for umbrella families, but we check anyway just to be safe. + if not is_in_list(cname, all_blist): + if passnum == 1: + # Check each item in the clist and klist. (At this point, + # clist == klist.) If any sub-config is blacklisted, we + # omit it from clist and klist. + for item in clist.split(): + if is_in_list(item, all_blist): + clist = remove_from_list(item, clist) + klist = remove_from_list(item, klist) + # Store the config and kernel lists to entries that + # corresponds to the config name. + assign_key_value(config_registry, cname, clist) + assign_key_value(kernel_registry, cname, klist) + cfg.close() + if passnum == 0: + # Assign the final indirect blacklist (with whitespace removed). + indirect_blist = canonicalize_ws(indirect_blist) + + +def read_registry_file(filename): + global config_registry + global kernel_registry + # Execute an initial pass through the config_registry file so that + # we can accumulate a list of indirectly blacklisted configurations, + # if any. + pass_config_kernel_registries(filename, 0) + # Now that the indirect_blist has been created, make a second pass + # through the 'config_registry' file, this time creating the actual + # config and kernel registry data structures. + pass_config_kernel_registries(filename, 1) + # Now we must go back through the config_registry and subsitute any + # configuration families with their constituents' members. Each time + # one of these substitutions occurs, we set a flag that causes us to + # make one more pass. (Subsituting a singleton definition does not + # prompt additional iterations.) This process stops when a full pass + # does not result in any subsitution. + + iterate_again = 1 + while iterate_again == 1: + iterate_again = 0 + for cr_var in config_registry: + config = cr_var + clist = query_array(config_registry, config) + # The entries that define singleton families should never need any substitution. + if is_singleton_family(config, clist): + continue + for mem in clist.split(): + mems_mem = query_array(config_registry, mem) + # If mems_mem is empty string, then mem was not found as a key + # in the config list associative array. In that case, we continue + # and will echo an error later in the script. + if not (mems_mem and mems_mem.strip()): + continue + if mem != mems_mem: + clist = query_array(config_registry, config) + # Replace the current config with its constituent config set, + # canonicalize whitespace, and then remove duplicate config + # set names, if they exist. Finally, update the config registry + # with the new config list. + #newclist = replace_curconfig_configset(clist, mem, mems_mem) + newclist = re.sub(r"\b{}\b".format(mem), mems_mem, clist) + newclist = canonicalize_ws(newclist) + newclist = rm_duplicate_words(newclist) + assign_key_value(config_registry, config, newclist) + # Since we performed a substitution and changed the config + # list, mark the iteration flag to continue another round, + # but only if the config (mem) value is NOT present + # in the list of sub-configs. If it is present, then further + # substitution may not necessarily be needed this round. + if not is_in_list(mem, mems_mem): + iterate_again = 1 + # Similar to what we just did for the config_registry, we now iterate + # through the kernel_registry and substitute any configuration families + # in the kernel list (right side of ':') with the members of that + # family's kernel set. 
This process continues iteratively, as before, + # until all families have been replaced with singleton configurations' + # kernel sets. + iterate_again = 1 + while iterate_again == 1: + iterate_again = 0 + for kr_var in kernel_registry: + config = kr_var + klist = query_array(kernel_registry, config) + # The entries that define singleton families should never need + # any substitution. In the kernel registry, we know it's a + # singleton entry when the cname occurs somewhere in the klist. + # (This is slightly different than the same test in the config + # registry, where we test that clist is one word and that + # clist == cname.) + if is_in_list(config, klist): + # echo "debug: '${config}' not found in '${klist}'; skipping." + continue + for ker in klist.split(): + kers_ker = query_array(kernel_registry, ker) + # If kers_ker is empty string, then ker was not found as a key + # in the kernel registry. While not common, this can happen + # when ker identifies a kernel set that does not correspond to + # any configuration. (Example: armv7a and armv8a kernel sets are + # used by cortexa* configurations, but do not correspond to their + # own configurations.) + if not (kers_ker and kers_ker.strip()): + continue + # If the current config/kernel (ker) differs from its singleton kernel + # entry (kers_ker), then that singleton entry was specified to use + # a different configuration's kernel set. Thus, we need to replace the + # occurrence in the current config/kernel name with that of the kernel + # set it needs. + if ker != kers_ker: + klisttmp = query_array(kernel_registry, config) + # Replace the current config with its requisite kernels, + # canonicalize whitespace, and then remove duplicate kernel + # set names, if they exist. Finally, update the kernel registry + # with the new kernel list. + #newklist = replace_curconfig_configset(klisttmp, ker, kers_ker) + newklist = re.sub(r"\b{}\b".format(ker), kers_ker, klisttmp) + newklist = canonicalize_ws(newklist) + newklist = rm_duplicate_words(newklist) + assign_key_value(kernel_registry, config, newklist) + # Since we performed a substitution and changed the kernel + # list, mark the iteration flag to continue another round, + # unless we just substituted using a singleton family + # definition, in which case we don't necessarily need to + # iterate further this round. + if not is_in_list(ker, kers_ker): + iterate_again = 1 + + +def build_kconfig_registry(familyname): + global config_registry + global kernel_registry + global kconfig_registry + clist = query_array(config_registry, familyname) + for config in clist.split(): + # Look up the kernels for the current sub-configuration. + kernels = query_array(kernel_registry, config) + for kernel in kernels.split(): + # Add the sub-configuration to the list associated with the kernel. + # Query the current sub-configs for the current ${kernel}. + cur_configs = query_array(kconfig_registry, kernel) + # Add the current sub-configuration to the list of sub-configs we just queried. + if cur_configs and cur_configs.strip(): + cur_configs = " ".join([cur_configs, config]) + cur_configs = cur_configs.strip() + else: + cur_configs = config + newvalue = canonicalize_ws(cur_configs) + # Update the array. 
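+ # (A sketch with hypothetical names: if the family expands to the
+ # sub-configs "zen zen2" and zen2 also requires the "zen" kernel set,
+ # these loops leave kconfig_registry mapping zen -> "zen zen2" and
+ # zen2 -> "zen2".)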
+ assign_key_value(kconfig_registry, kernel, newvalue) + + +def lastWord(string): + # finding the index of last space + index = string.rfind(" ") + # last word + return string[index + 1:] + + + +config_blist = "" +indirect_blist = "" +config_registry = {} +kernel_registry = {} +kconfig_registry = {} + +def process_config(): + # Obtain the script name. + cwd = os.getcwd() + path, arch = os.path.split(sys.argv[1]) + target_file = os.path.join(sys.argv[2], 'config_registry') + + read_registry_file(target_file) + + config_list = query_array(config_registry, arch) + kernel_list = query_array(kernel_registry, arch) + + build_kconfig_registry(arch) + + config_list = " ".join(config_list.split()) + kernel_list = " ".join(kernel_list.split()) + + # We use a sorted version of kernel_list so that it ends up matching the + # display order of the kconfig_registry above. + kernel_list_sort = kernel_list + + kconfig_map = "" + for kernel in kernel_list_sort.split(): + configs = query_array(kconfig_registry, kernel) + + has_one_kernel = is_singleton(configs) + contains_kernel = is_in_list(kernel, configs) + + # Check if the list is a singleton. + if has_one_kernel: + reducedclist = configs + # Check if the list contains a sub-config name that matches the kernel. + elif contains_kernel: + reducedclist = kernel + # Otherwise, use the last name. + else: + last_config = lastWord(configs) + reducedclist = last_config + + # Create a new "kernel:subconfig" pair and add it to the kconfig_map + # list, removing whitespace. + new_pair = kernel+':'+reducedclist + kconfig_map = " ".join([kconfig_map, new_pair]) + kconfig_map = canonicalize_ws(kconfig_map) + + config = " ; ".join([config_list, kernel_list, kconfig_map]) + return config + + +# Function call for config family names +CONFIG = process_config() +print(CONFIG) diff --git a/build/cmake/subdir_helper_functions.cmake b/build/cmake/subdir_helper_functions.cmake new file mode 100644 index 0000000000..ad41a3001c --- /dev/null +++ b/build/cmake/subdir_helper_functions.cmake @@ -0,0 +1,122 @@ +##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.## + +# Create a list of keywords for files that need to be ignored by the system. +file(READ ${CMAKE_SOURCE_DIR}/build/gen-make-frags/ignore_list IGNORE_LIST) +string(REPLACE "\n" ";" IGNORE_LIST ${IGNORE_LIST}) + +# Create a list of suffixes for files that need to be compiled to create the library. +file(READ ${CMAKE_SOURCE_DIR}/build/gen-make-frags/suffix_list SUFFIX_LIST) +string(REPLACE "\n" ";" SUFFIX_LIST ${SUFFIX_LIST}) + +#-------------------------------------------- +# SUFFIX LISTS +#-------------------------------------------- +# Source suffixes. +set(CONFIG_SRC_SUFS "c") +set(KERNELS_SRC_SUFS "c;s;S") +set(FRAME_SRC_SUFS "c") + +set(AOCLDTL_SRC_SUFS "c") +set(ADDON_C99_SUFS "c") +set(ADDON_CXX_SUFS "cc;cpp;cxx") +set(ADDON_SRC_SUFS "${ADDON_C99_SUFS};${ADDON_CXX_SUFS}") + +set(SANDBOX_C99_SUFS "c") +set(SANDBOX_CXX_SUFS "cc;cpp;cxx") +set(SANDBOX_SRC_SUFS "${SANDBOX_C99_SUFS};${SANDBOX_CXX_SUFS}") + +# Header suffixes. +set(FRAME_HDR_SUFS "h") + +set(AOCLDTL_HDR_SUFS "h") +set(ADDON_H99_SUFS "h") +set(ADDON_HXX_SUFS "hh;hpp;hxx") +set(ADDON_HDR_SUFS "${ADDON_H99_SUFS};${ADDON_HXX_SUFS}") + +set(SANDBOX_H99_SUFS "h") +set(SANDBOX_HXX_SUFS "hh;hpp;hxx") +set(SANDBOX_HDR_SUFS "$(SANDBOX_H99_SUFS);$(SANDBOX_HXX_SUFS)") + +# Combine all header suffixes and remove duplicates. 
+set(ALL_HDR_SUFS "${FRAME_HDR_SUFS};${ADDON_HDR_SUFS};${SANDBOX_HDR_SUFS};${AOCLDTL_HDR_SUFS}") +list(REMOVE_DUPLICATES ALL_HDR_SUFS) + +set(ALL_H99_SUFS "${FRAME_HDR_SUFS};${ADDON_HDR_SUFS};${SANDBOX_H99_SUFS};${AOCLDTL_HDR_SUFS}") +list(REMOVE_DUPLICATES ALL_H99_SUFS) + +#-------------------------------------------- +# Important sets of header files and paths +#-------------------------------------------- +# Get a list of all sub-directories of a given directory +macro(get_dirpaths_with_suffixes result curdir sufflist) + set(dirlist "") + # dirlist will have all files which are below this directory. + file(GLOB_RECURSE children LIST_DIRECTORIES true ${curdir}/*) + # Adding current directory in the list. + list(PREPEND children ${curdir}) + # Filter out anything that is not a directory. + foreach(child ${children}) + if(IS_DIRECTORY ${child}) + set(HAS_SUFF_FILE "false") + foreach(suff ${sufflist}) + file(GLOB suff_files LIST_DIRECTORIES false ${child}/*\.${suff}) + list(LENGTH suff_files list_size) + if(NOT (${list_size} STREQUAL 0)) + set(HAS_SUFF_FILE "true") + # If there is at least one file with a specific suffix break from for-loop. + break() + endif() + endforeach() + # If there is at least one *.suff file, add directory path in the list. + if(HAS_SUFF_FILE STREQUAL "true") + list(APPEND dirlist "${child}/") + endif() + endif() + endforeach() + # Get the name of the current directory, after removing the source directory + # from the name, so that we can exclude the files that are part of the ignore + # list even if the blis directory is located in a directory with a name that + # would be ignored. + string(REPLACE "${CMAKE_SOURCE_DIR}/" "" curdirsimple ${curdir}) + # Filter out anything that is part of the IGNORE_LIST. + foreach(item ${IGNORE_LIST}) + list(FILTER dirlist EXCLUDE REGEX ${curdirsimple}.*/${item}/) + endforeach() + list(APPEND ${result} ${dirlist}) +endmacro() + +# Get a list of all source files of a given directory based on the suffix list. +# Returns a list which can be transfored to a string when needed +# from high level CMake. +macro(get_filepaths_with_suffixes result curdir sufflist) + set(sourcelist "") + # Get the name of the current directory, after removing the source directory + # from the name, so that we can exclude the files that are part of the ignore + # list even if the blis directory is located in a directory with a name that + # would be ignored. + string(REPLACE "${CMAKE_SOURCE_DIR}/" "" curdirsimple ${curdir}) + foreach(suff ${sufflist}) + # dirlist will have all files which are below this directory. + file(GLOB_RECURSE suff_files LIST_DIRECTORIES false ${curdir}/*\.${suff}) + # Filter out anything that is part of the IGNORE_LIST. + foreach(item ${IGNORE_LIST}) + list(FILTER suff_files EXCLUDE REGEX ${curdirsimple}.*/${item}/) + endforeach() + list(APPEND sourcelist "${suff_files}") + endforeach() + list(APPEND ${result} ${sourcelist}) +endmacro() + +# Choose correct sub-configurarion name for the given kernel set. +# Behaves similary to get-config-for-kset. +macro(get_config_for_kernel_from_kconfig_map config kernel kconfig_map) + set(conf ${kconfig_map}) + # Since kconfig_map has as elements pairs of the form kernel:config, + # to find the element with the corresponding config we need to filter + # with respect to the kernel first. + list(FILTER conf INCLUDE REGEX ${kernel}:) + # Now that the list has only one element, we can remove the part + # of kernel: and then we will be left with config. 
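+ # For example (hypothetical map contents): with kconfig_map holding
+ # "zen:zen;haswell:haswell" and kernel set to "haswell", the filter above
+ # keeps only "haswell:haswell", and the REPLACE below reduces it to
+ # "haswell".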
+ list(TRANSFORM conf REPLACE ${kernel}: "") + list(APPEND ${config} ${conf}) +endmacro() diff --git a/build/detect/config/config_detect.c b/build/detect/config/config_detect.c index 2b59f78bf9..03dc9ce877 100644 --- a/build/detect/config/config_detect.c +++ b/build/detect/config/config_detect.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -33,12 +33,39 @@ */ -#define BLIS_INLINE static -#define BLIS_EXPORT_BLIS -#include "bli_system.h" -#include "bli_type_defs.h" -#include "bli_arch.h" -#include "bli_cpuid.h" +// NOTE: This file will likely only ever get compiled as part of the BLIS +// configure script, and therefore BLIS_CONFIGURETIME_CPUID is guaranteed to +// be #defined. However, we preserve the cpp conditional for consistency with +// the other three files mentioned above. +#ifdef BLIS_CONFIGURETIME_CPUID + + // NOTE: If you need to make any changes to this cpp branch, it's probably + // the case that you also need to modify bli_arch.c, bli_cpuid.c, and + // bli_env.c. Don't forget to update these other files as needed! + + // The BLIS_ENABLE_SYSTEM macro must be defined so that the correct cpp + // branch in bli_system.h is processed. (This macro is normally defined in + // bli_config.h.) + #define BLIS_ENABLE_SYSTEM + + // Use C-style static inline functions for any static inline functions that + // happen to be defined by the headers below. (This macro is normally defined + // in bli_config_macro_defs.h.) + #define BLIS_INLINE static + + // Since we're not building a shared library, we can forgo the use of the + // BLIS_EXPORT_BLIS annotations by #defining them to be nothing. (This macro + // is normally defined in bli_config_macro_defs.h.) + #define BLIS_EXPORT_BLIS + + #include "bli_system.h" + #include "bli_type_defs.h" + #include "bli_arch.h" + #include "bli_cpuid.h" + //#include "bli_env.h" +#else + #include "blis.h" +#endif int main( int argc, char** argv ) { diff --git a/build/detect/config/old/cpuid_x86.c b/build/detect/config/old/cpuid_x86.c index f4985e3914..3167b727a2 100644 --- a/build/detect/config/old/cpuid_x86.c +++ b/build/detect/config/old/cpuid_x86.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2015, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/build/gen-make-frags/ignore_list b/build/gen-make-frags/ignore_list index ccdd18f644..3a7afbd8bc 100644 --- a/build/gen-make-frags/ignore_list +++ b/build/gen-make-frags/ignore_list @@ -5,3 +5,4 @@ other temp tmp test +p10_testsuite diff --git a/build/irun.py b/build/irun.py index 429981603c..767011f272 100755 --- a/build/irun.py +++ b/build/irun.py @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2018, The University of Texas at Austin -# Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. +# Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
# # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/build/libblis-symbols.def b/build/libblis-symbols.def index e1bfce807e..97146a7861 100644 --- a/build/libblis-symbols.def +++ b/build/libblis-symbols.def @@ -1297,17 +1297,17 @@ bli_malloc_user bli_mbool_create bli_mbool_free bli_mbool_init -bli_membrk_acquire_m -bli_membrk_compute_pool_block_sizes -bli_membrk_compute_pool_block_sizes_dt -bli_membrk_finalize -bli_membrk_finalize_pools -bli_membrk_init -bli_membrk_init_pools -bli_membrk_pool_size -bli_membrk_query -bli_membrk_release -bli_membrk_rntm_set_membrk +bli_pba_acquire_m +bli_pba_compute_pool_block_sizes +bli_pba_compute_pool_block_sizes_dt +bli_pba_finalize +bli_pba_finalize_pools +bli_pba_init +bli_pba_init_pools +bli_pba_pool_size +bli_pba_query +bli_pba_release +bli_pba_rntm_set_pba bli_memsys_finalize bli_memsys_init bli_mkherm diff --git a/build/templates/license.c b/build/templates/license.c index 6505a70ffd..b076cb49e0 100644 --- a/build/templates/license.c +++ b/build/templates/license.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2019, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/build/templates/license.h b/build/templates/license.h index 6505a70ffd..b076cb49e0 100644 --- a/build/templates/license.h +++ b/build/templates/license.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2019, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/build/templates/license.sh b/build/templates/license.sh index b9c51e2892..087da58353 100644 --- a/build/templates/license.sh +++ b/build/templates/license.sh @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2019, The University of Texas at Austin -# Copyright (C) 2018, Advanced Micro Devices, Inc. +# Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/common.mk b/common.mk index 220e8ccaa8..7f200545ed 100644 --- a/common.mk +++ b/common.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -245,7 +245,6 @@ files-that-dont-contain = $(strip $(foreach f, $(1), $(if $(findstring $(2),$(f) # function. 
rm-dups = $(if $1,$(firstword $1) $(call rm-dups,$(filter-out $(firstword $1),$1))) - # # --- Include makefile configuration file -------------------------------------- # @@ -600,27 +599,35 @@ SOFLAGS += -Wl,-soname,$(LIBBLIS_SONAME) endif endif +# Decide whether to use static or shared library on Linux and OS X +MK_USE_LIB=static +ifeq ($(MK_ENABLE_STATIC),no) + MK_USE_LIB=shared +endif +ifeq ($(USE_SHARED),yes) + MK_USE_LIB=shared +endif + # Decide which library to link to for things like the testsuite and BLIS test # drivers. We default to the static library, unless only the shared library was # enabled, in which case we use the shared library. LIBBLIS_L := $(LIBBLIS_A) LIBBLIS_LINK := $(LIBBLIS_A_PATH) ifeq ($(MK_ENABLE_SHARED),yes) -ifeq ($(MK_ENABLE_STATIC),no) -LIBBLIS_L := $(LIBBLIS_SO) -LIBBLIS_LINK := $(LIBBLIS_SO_PATH) -ifeq ($(IS_WIN),no) -# For Linux and OS X: set rpath property of shared object. -LDFLAGS += -Wl,-rpath,$(BASE_LIB_PATH) + ifeq ($(MK_USE_LIB),shared) + LIBBLIS_L := $(LIBBLIS_SO) + LIBBLIS_LINK := $(LIBBLIS_SO_PATH) + ifeq ($(IS_WIN),no) + # For Linux and OS X: set rpath property of shared object. + LDFLAGS += -Wl,-rpath,$(BASE_LIB_PATH) + endif + endif + # On windows, use the shared library even if static is created. + ifeq ($(IS_WIN),yes) + LIBBLIS_L := $(LIBBLIS_SO) + LIBBLIS_LINK := $(LIBBLIS_SO_PATH) + endif endif -endif -# On windows, use the shared library even if static is created. -ifeq ($(IS_WIN),yes) -LIBBLIS_L := $(LIBBLIS_SO) -LIBBLIS_LINK := $(LIBBLIS_SO_PATH) -endif -endif - # # --- Include makefile definitions file ---------------------------------------- @@ -692,7 +699,7 @@ endif # Disable tautological comparision warnings in clang. ifeq ($(CC_VENDOR),clang) -CWARNFLAGS += -Wno-tautological-compare +CWARNFLAGS += -Wno-tautological-compare -Wno-pass-failed endif $(foreach c, $(CONFIG_LIST_FAM), $(eval $(call append-var-for,CWARNFLAGS,$(c)))) diff --git a/config/CMakeLists.txt b/config/CMakeLists.txt index 3a5925a306..b23fb85a4e 100644 --- a/config/CMakeLists.txt +++ b/config/CMakeLists.txt @@ -1,28 +1,187 @@ -##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. ## +##Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc ## -if(${TARGET_ARCH} STREQUAL zen4) -message("The configuration is : ${TARGET_ARCH}") -add_subdirectory(zen4) -elseif(${TARGET_ARCH} STREQUAL zen3) -message("The configuration is : ${TARGET_ARCH}") -add_subdirectory(zen3) -elseif(${TARGET_ARCH} STREQUAL zen2) -message("The configuration is : ${TARGET_ARCH}") -add_subdirectory(zen2) -elseif(${TARGET_ARCH} STREQUAL zen) -message("The configuration is : ${TARGET_ARCH}") -add_subdirectory(zen) -elseif(${TARGET_ARCH} STREQUAL amdzen) -message("The configuration is : ${TARGET_ARCH}") -add_subdirectory(generic) -add_subdirectory(zen) -add_subdirectory(zen2) -add_subdirectory(zen3) -add_subdirectory(zen4) -elseif(${TARGET_ARCH} STREQUAL haswell) -message("The configuration is : ${TARGET_ARCH}") -add_subdirectory(haswell) -else(${TARGET_ARCH} STREQUAL generic) -message("The configuration is : ${TARGET_ARCH}") -add_subdirectory(generic) -endif() +# Writing a function that will be used to generate the required object +# libraries for the required configs. +function(generate_config_targets config_target) + # Collect all subdirectory paths that have at least one file with suffix in CONFIG_SRC_SUFS list. 
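+ # (Illustrative only, assuming a hypothetical config_target of "zen3" and
+ # CONFIG_SRC_SUFS of "c": the call below would gather sources such as
+ # config/zen3/bli_cntx_init_zen3.c into LOCAL_SOURCE_FILES.)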
+ get_filepaths_with_suffixes(LOCAL_SOURCE_FILES "${CMAKE_CURRENT_SOURCE_DIR}/${config_target}" "${CONFIG_SRC_SUFS}") + + # Create an object library using the source file list above. + add_library(${config_target}_CONFIG + OBJECT + ${LOCAL_SOURCE_FILES} + ) + # Include the corresponding make_defs.cmake that holds the required compiler options. + include(${CMAKE_SOURCE_DIR}/config/${config_target}/make_defs.cmake) + # Use PRIVATE keyword for option setting since we do not want the properties to propagate in other targets. + # mimicing get-config-cflags-for + target_compile_options(${config_target}_CONFIG + PRIVATE + # load-var-for,COPTFLAGS + ${COPTFLAGS} + # get-noopt-cflags-for + ${CDBGFLAGS} + # get-noopt-cflags-for + ${CWARNFLAGS} + # get-noopt-cflags-for + ${CMISCFLAGS} + # get-noopt-cflags-for + ${CLANGFLAGS} + # in get-config-cflags-for + ${BUILD_SYMFLAGS} + ) + target_compile_definitions(${config_target}_CONFIG + PRIVATE + # in get-noopt-cflags-for + ${CPPROCFLAGS} + # in get-noopt-cflags-for + ${VERS_DEF} + # in get-config-cflags-for + ${BUILD_CPPFLAGS} + ) + target_include_directories(${config_target}_CONFIG + BEFORE + PRIVATE + # in get-noopt-cflags-for + ${CINFLAGS} + ) + if(THREADING_MODEL STREQUAL "openmp") + # Equivalent to CTHREADFLAGS in get-noopt-cflags-for + target_link_libraries(${config_target}_CONFIG PRIVATE OpenMP::OpenMP_C) + elseif(THREADING_MODEL STREQUAL "pthreads") + # in get-noopt-cflags-for + target_compile_options(${config_target}_CONFIG PRIVATE ${CTHREADFLAGS}) + endif() + if(BUILD_SHARED_LIBS) + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(${config_target}_CONFIG PROPERTIES POSITION_INDEPENDENT_CODE ON) + endif() + add_dependencies(${config_target}_CONFIG flat-header) + # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. + set_target_properties(${config_target}_CONFIG PROPERTIES FOLDER object-libs-targets) + + # Create on object library using the corresponding reference kernel initialization file. + add_library(${config_target}_REFINIT + OBJECT + ${CMAKE_SOURCE_DIR}/ref_kernels/bli_cntx_ref.c + ) + # Use PRIVATE keyword for option setting since we do not want the properties to propagate in other targets. 
+ # mimicing get-refinit-cflags-for + target_compile_options(${config_target}_REFINIT + PRIVATE + # load-var-for,COPTFLAGS + ${COPTFLAGS} + # get-noopt-cflags-for + ${CDBGFLAGS} + # get-noopt-cflags-for + ${CWARNFLAGS} + # get-noopt-cflags-for + ${CMISCFLAGS} + # get-noopt-cflags-for + ${CLANGFLAGS} + # in get-refinit-cflags-for + ${BUILD_SYMFLAGS} + ) + target_compile_definitions(${config_target}_REFINIT + PRIVATE + # get-noopt-cflags-for + ${CPPROCFLAGS} + # in get-noopt-cflags-for + ${VERS_DEF} + # in get-refinit-cflags-for + ${BUILD_CPPFLAGS} + # get-noopt-cflags-for + ${CPPROCFLAGS} + # in get-refinit-cflags-for + -DBLIS_CNAME=${config_target} + ) + target_include_directories(${config_target}_REFINIT + BEFORE + PRIVATE + # in get-noopt-cflags-for + ${CINFLAGS} + ) + if(THREADING_MODEL STREQUAL "openmp") + # Equivalent to CTHREADFLAGS in get-noopt-cflags-for + target_link_libraries(${config_target}_REFINIT PRIVATE OpenMP::OpenMP_C) + elseif(THREADING_MODEL STREQUAL "pthreads") + # in get-noopt-cflags-for + target_compile_options(${config_target}_REFINIT PRIVATE ${CTHREADFLAGS}) + endif() + if(BUILD_SHARED_LIBS) + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(${config_target}_REFINIT PROPERTIES POSITION_INDEPENDENT_CODE ON) + endif() + add_dependencies(${config_target}_REFINIT flat-header) + # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. + set_target_properties(${config_target}_REFINIT PROPERTIES FOLDER object-libs-targets) + + # Collect all subdirectory paths that have at least one file with suffix in KERNELS_SRC_SUFS list. + set(REFKERN_PATH ${CMAKE_SOURCE_DIR}/ref_kernels) + get_filepaths_with_suffixes(LOCAL_REFKERN_FILES ${REFKERN_PATH} ${KERNELS_SRC_SUFS}) + # Remove bli_cntx_ref.c from source list. + list(FILTER LOCAL_REFKERN_FILES EXCLUDE REGEX bli_cntx_ref.c) + + # Create on object library using the corresponding reference implementations being targeted. + add_library(${config_target}_REFKERN + OBJECT + ${LOCAL_REFKERN_FILES} + ) + # Use PRIVATE keyword for option setting since we do not want the properties to propagate in other targets. 
+ # mimicing get-refkern-cflags-for + target_compile_options(${config_target}_REFKERN + PRIVATE + # load-var-for,CROPTFLAGS + ${CROPTFLAGS} + # load-var-for,CRVECFLAGS + ${CRVECFLAGS} + # get-noopt-cflags-for + ${CDBGFLAGS} + # get-noopt-cflags-for + ${CWARNFLAGS} + # get-noopt-cflags-for + ${CMISCFLAGS} + # get-noopt-cflags-for + ${CLANGFLAGS} + # in get-refkernel-cflags-for + ${COMPSIMDFLAGS} + # in get-refkern-cflags-for + ${BUILD_SYMFLAGS} + ) + target_compile_definitions(${config_target}_REFKERN + PRIVATE + # in get-noopt-cflags-for + ${CPPROCFLAGS} + # in get-noopt-cflags-for + ${VERS_DEF} + # in get-refkern-cflags-for + -DBLIS_CNAME=${config_target} + # in get-refkern-cflags-for + ${BUILD_CPPFLAGS} + ) + target_include_directories(${config_target}_REFKERN + BEFORE + PRIVATE + # in get-noopt-cflags-for + ${CINFLAGS} + ) + if(THREADING_MODEL STREQUAL "openmp") + # Equivalent to CTHREADFLAGS in get-noopt-cflags-for + target_link_libraries(${config_target}_REFKERN PRIVATE OpenMP::OpenMP_C) + elseif(THREADING_MODEL STREQUAL "pthreads") + # in get-noopt-cflags-for + target_compile_options(${config_target}_REFKERN PRIVATE ${CTHREADFLAGS}) + endif() + if(BUILD_SHARED_LIBS) + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(${config_target}_REFKERN PROPERTIES POSITION_INDEPENDENT_CODE ON) + endif() + add_dependencies(${config_target}_REFKERN flat-header) + # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. + set_target_properties(${config_target}_REFKERN PROPERTIES FOLDER object-libs-targets) +endfunction() + +# Generate targets for each of the configs. +foreach(CONF ${CONFIG_LIST}) + generate_config_targets(${CONF}) +endforeach() diff --git a/config/a64fx/bli_a64fx_sector_cache.h b/config/a64fx/bli_a64fx_sector_cache.h new file mode 100644 index 0000000000..a81d04caca --- /dev/null +++ b/config/a64fx/bli_a64fx_sector_cache.h @@ -0,0 +1,117 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + // A64FX: set up cache sizes + // + // Reference: A64FX (TM) specification Fujitsu HPC Extension + // Link: https://github.com/fujitsu/A64FX/blob/master/doc/A64FX_Specification_HPC_Extension_v1_EN.pdf + // + // 63:15 | 14:12 | 11 | 10:08 | 07 | 06:04 | 03 | 02:00 | + // RES0 | l1_sec3_max | RES0 | l1_sec2_max | RES0 | l1_sec1_max | RES0 | l1_sec0_max | + // + // the bits set number of maximum sectors from 0-7 + // 000 - 0 + // 001 - 1 + // 010 - 2 + // 011 - 3 + // 100 - 4 + // 101 - 5 + // 110 - 6 + // 111 - 7 + // + // For L1 we want to maximize the number of sectors for B + // Configuration 1: 1 sector for C (sector 3) + // 1 sector for A (sector 1) + // 6 sectors for B (sector 2) + // 0 sectors for the rest (sector 0) + // + // 16b bitfield conf. 1: 0b0 001 0 110 0 001 0 000 + // + // Configuration 2: 1 sector for C (sector 3) + // 1 sector for A (sector 1) + // 5 sectors for B (sector 2) + // 1 sectors for the rest (sector 0) + // + // 16b bitfield conf. 2: 0b0 001 0 101 0 001 0 001 + // + // accessing the control register: + // + // MRS , S3_3_C11_C8_2 + // MSR S3_3_C11_C8_2, + // + // TODO: First tests showed no change in performance, a deeper investigation + // is necessary +#define A64FX_SETUP_SECTOR_CACHE_SIZES(config_bitfield)\ +{\ + uint64_t sector_cache_config = config_bitfield;\ + __asm__ volatile(\ + "msr s3_3_c11_c8_2,%[sector_cache_config]"\ + :\ + : [sector_cache_config] "r" (sector_cache_config)\ + :\ + );\ +} + +#define A64FX_SETUP_SECTOR_CACHE_SIZES_L2(config_bitfield)\ +{\ + uint64_t sector_cache_config = config_bitfield;\ + __asm__ volatile(\ + "msr s3_3_c15_c8_2,%[sector_cache_config]"\ + :\ + : [sector_cache_config] "r" (sector_cache_config)\ + :\ + );\ +} + + +#define A64FX_SET_CACHE_SECTOR(areg, tag, sparereg)\ +" mov "#sparereg", "#tag" \n\t"\ +" lsl "#sparereg", "#sparereg", 56 \n\t"\ +" orr "#areg", "#areg", "#sparereg" \n\t" + +#define A64FX_READ_SECTOR_CACHE_SIZES(output_uint64)\ +__asm__ volatile(\ + "mrs %["#output_uint64"],s3_3_c11_c8_2"\ + : [output_uint64] "=r" (output_uint64)\ + : \ + :\ + ); + +#define A64FX_SCC(sec0,sec1,sec2,sec3)\ + (uint64_t)((sec0 & 0x7LU) | ((sec1 & 0x7LU) << 4) | ((sec2 & 0x7LU) << 8) | ((sec3 & 0x7LU) << 12)) + +#define A64FX_SCC_L2(sec02,sec13)\ + (uint64_t)((sec02 & 0x1FLU) | ((sec13 & 0x1FLU) << 8)) + diff --git a/config/a64fx/bli_cntx_init_a64fx.c b/config/a64fx/bli_cntx_init_a64fx.c new file mode 100644 index 0000000000..5061570f80 --- /dev/null +++ b/config/a64fx/bli_cntx_init_a64fx.c @@ -0,0 +1,151 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "bli_a64fx_sector_cache.h" + +void bli_cntx_init_a64fx( cntx_t* cntx ) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; + + // Set default kernel blocksizes and functions. + bli_cntx_init_a64fx_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( + 2, + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_armsve_asm_2vx10_unindexed, FALSE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armsve_asm_2vx10_unindexed, FALSE, + cntx + ); + + // Set SVE-512 packing routine. + bli_cntx_set_packm_kers + ( + 3, + BLIS_PACKM_10XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_10xk, + BLIS_PACKM_12XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_12xk, + BLIS_PACKM_16XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_16xk, + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 16, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 10, 10, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 256, 128, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 2048, 2048, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 23040, 26880, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 5, + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + cntx + ); + +#if 0 + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], -1, 65, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], -1, 65, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], -1, 65, -1, -1 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. 
+ bli_cntx_set_l3_sup_kers + ( + 4, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], -1, 10, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], -1, 16, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], -1, 120, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], -1, 256, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], -1, 4080, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); +#endif + + // Set A64FX cache sector sizes for each PE/CMG + // SC Fugaku might disable users' setting cache sizes. +#if !defined(CACHE_SECTOR_SIZE_READONLY) +#pragma omp parallel + { + A64FX_SETUP_SECTOR_CACHE_SIZES(A64FX_SCC(0,1,3,0)) + A64FX_SETUP_SECTOR_CACHE_SIZES_L2(A64FX_SCC_L2(9,28)) + } +#endif + +} + diff --git a/sandbox/power10/generic_gemm.h b/config/a64fx/bli_family_a64fx.h similarity index 67% rename from sandbox/power10/generic_gemm.h rename to config/a64fx/bli_family_a64fx.h index 8b1a16dc9f..5e3f29fd4b 100644 --- a/sandbox/power10/generic_gemm.h +++ b/config/a64fx/bli_family_a64fx.h @@ -32,27 +32,15 @@ */ -// Prototypes and template for the 5-loop gemm algorithm - -#include "bli_sandbox.h" - -#define GEMM_PASTEMAC_(ch) bli_ ## ch ## gemm_ -#define GEMM_PASTEMAC(ch) GEMM_PASTEMAC_(ch) - -#define GENERIC_GEMM_PROTO(ch, DTYPE_IN, DTYPE_OUT) \ -void GEMM_PASTEMAC(ch) \ - ( \ - dim_t MR, dim_t NR, dim_t KC, dim_t NC, dim_t MC, \ - int m, int n, int k, \ - DTYPE_IN* restrict A, int rs_a, int cs_a, int A_align, \ - DTYPE_IN* restrict B, int rs_b, int cs_b, int B_align, \ - DTYPE_OUT* restrict C, int rs_c, int cs_c, \ - DTYPE_OUT* alpha, DTYPE_OUT* beta \ - ) - -GENERIC_GEMM_PROTO( sb, bfloat16, float); -GENERIC_GEMM_PROTO( sh, float16, float); -GENERIC_GEMM_PROTO(i16, int16_t, int32_t); -GENERIC_GEMM_PROTO( i8, int8_t, int32_t); -GENERIC_GEMM_PROTO( i4, nibbles, int32_t); +//#ifndef BLIS_FAMILY_H +//#define BLIS_FAMILY_H + + +// -- MEMORY ALLOCATION -------------------------------------------------------- + +#define BLIS_SIMD_ALIGN_SIZE 256 +#define BLIS_SIMD_NUM_REGISTERS 32 + + +//#endif diff --git a/config/a64fx/make_defs.mk b/config/a64fx/make_defs.mk new file mode 100644 index 0000000000..d6871fac31 --- /dev/null +++ b/config/a64fx/make_defs.mk @@ -0,0 +1,82 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
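For reference, the sector-cache values that bli_cntx_init_a64fx() actually programs above decode as follows; this is pure arithmetic from the A64FX_SCC and A64FX_SCC_L2 definitions, expressed here as C11 compile-time checks for illustration only (not part of the patch).

#include <stdint.h>
#include "bli_a64fx_sector_cache.h" // macros introduced earlier in this patch

// L1: sector0 max = 0, sector1 (A) max = 1, sector2 (B) max = 3, sector3 (C) max = 0
_Static_assert( A64FX_SCC( 0, 1, 3, 0 ) == 0x0310, "L1 sector-cache config word" );
// L2: sectors 0/2 max = 9, sectors 1/3 max = 28
_Static_assert( A64FX_SCC_L2( 9, 28 ) == 0x1c09, "L2 sector-cache config word" );

The setup is wrapped in "#pragma omp parallel" so that, in an OpenMP build, every thread issues the MSR writes and the configuration is applied on each core (and hence each CMG) the threads run on, per the "for each PE/CMG" comment in the source.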
+# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := a64fx +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := -D_GNU_SOURCE -D_A64FX +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O3 -ftree-vectorize -march=armv8-a+sve +endif + +# Flags specific to optimized kernels. +CKOPTFLAGS := $(COPTFLAGS) +CKVECFLAGS := + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +CRVECFLAGS := $(CKVECFLAGS) +endif +endif + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config/amd64_legacy/bli_family_amd64_legacy.h b/config/amd64_legacy/bli_family_amd64_legacy.h index 5629b9a2d3..c13a506346 100644 --- a/config/amd64_legacy/bli_family_amd64_legacy.h +++ b/config/amd64_legacy/bli_family_amd64_legacy.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2021, Advanced Micro Devices, Inc + Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/config/amd64_legacy/make_defs.mk b/config/amd64_legacy/make_defs.mk index 5f0d613cbb..a8344f7072 100644 --- a/config/amd64_legacy/make_defs.mk +++ b/config/amd64_legacy/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc +# Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
# # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/config/amdzen/bli_family_amdzen.h b/config/amdzen/bli_family_amdzen.h index 7e4d460d13..e22cd18ccf 100644 --- a/config/amdzen/bli_family_amdzen.h +++ b/config/amdzen/bli_family_amdzen.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -37,53 +37,23 @@ // By default, it is effective to parallelize the outer loops. // Setting these macros to 1 will force JR and IR inner loops -// to be not paralleized. -// +// to be not parallelized. #define BLIS_THREAD_MAX_IR 1 #define BLIS_THREAD_MAX_JR 1 - #define BLIS_ENABLE_SMALL_MATRIX #define BLIS_ENABLE_SMALL_MATRIX_TRSM - // This will select the threshold below which small matrix code will be called. #define BLIS_SMALL_MATRIX_THRES 700 #define BLIS_SMALL_M_RECT_MATRIX_THRES 160 #define BLIS_SMALL_K_RECT_MATRIX_THRES 128 -#define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 -#define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 +#define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 +#define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 // When running HPL with pure MPI without DGEMM threading (Single-threaded // BLIS), defining this macro as 1 yields better performance. #define AOCL_BLIS_MULTIINSTANCE 0 -/* - * Override the block sizes in the context to the block sizes used - * by AVX2 GEMM+TRSM kernels, this is needed in Zen4 context as default - * GEMM kernels are AVX512 based and uses different block sizes. - * - * This function should be called in TRSM path before performing - * any packing operations. - * - * Also the context must be restored to default values by calling - * bli_zen4_restore_default_blkszs() before exiting TRSM Path - */ -BLIS_EXPORT_BLIS void bli_zen4_override_trsm_blkszs (cntx_t* cntx); - -BLIS_EXPORT_BLIS void bli_zen4_override_gemmt_blkszs (cntx_t* cntx); - -/* - * Restore the block sizes to default values needed for zen4 context. - * - * This function should be called to restore the block sizes to there - * default values if they where overriden by calling - * bli_zen4_override_trsm_blkszs() to enable AVX2 GEMM kernels in the - * TRSM path. - * - */ -BLIS_EXPORT_BLIS void bli_zen4_restore_default_blkszs (cntx_t* cntx); - #endif - diff --git a/config/amdzen/make_defs.cmake b/config/amdzen/make_defs.cmake new file mode 100644 index 0000000000..ac7d1b506e --- /dev/null +++ b/config/amdzen/make_defs.cmake @@ -0,0 +1,24 @@ +##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## + +# For architecture independent files we still need to define +# the required flags. 
+if(MSVC) + if(NOT ("${CMAKE_BUILD_TYPE}" MATCHES "Release")) + set(CDBGFLAGS /Zo) + endif() + if("${CMAKE_BUILD_TYPE}" MATCHES "Debug") + set(COPTFLAGS /Od) + else() # Release or RelWithDebInfo + set(COPTFLAGS /O2) + endif() +else() + if(NOT (DEBUG_TYPE STREQUAL "off")) + set(CDBGFLAGS -g) + endif() + + if(DEBUG_TYPE STREQUAL "noopt") + set(COPTFLAGS -O0) + else() # off or opt + set(COPTFLAGS -O3) + endif() +endif() diff --git a/config/armsve/bli_armsve_config_utils.c b/config/armsve/bli_armsve_config_utils.c new file mode 100644 index 0000000000..fdddeebabe --- /dev/null +++ b/config/armsve/bli_armsve_config_utils.c @@ -0,0 +1,92 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ +#include "blis.h" + +dim_t bli_vl_bits_armsve(void) +{ \ + uint64_t vl = 0; + __asm__ ( + " mov x0, xzr \n\t" + " incb x0 \n\t" + " mov %[vl], x0 \n\t" + : [vl] "=r" (vl) + : + : "x0" + ); + return vl; +} + + +#define EXPANDMAC_BLKSZ_ARMSVE(ch, S_Data) \ +void PASTEMAC(ch, _blksz_armsve) (dim_t *m_r_, dim_t *n_r_, \ + dim_t *k_c_, dim_t *m_c_, dim_t *n_c_) \ +{ \ + dim_t W_L1 = bli_env_get_var("BLIS_SVE_W_L1", W_L1_SVE_DEFAULT); \ + dim_t N_L1 = bli_env_get_var("BLIS_SVE_N_L1", N_L1_SVE_DEFAULT); \ + dim_t C_L1 = bli_env_get_var("BLIS_SVE_C_L1", C_L1_SVE_DEFAULT); \ + dim_t W_L2 = bli_env_get_var("BLIS_SVE_W_L2", W_L2_SVE_DEFAULT); \ + dim_t N_L2 = bli_env_get_var("BLIS_SVE_N_L2", N_L2_SVE_DEFAULT); \ + dim_t C_L2 = bli_env_get_var("BLIS_SVE_C_L2", C_L2_SVE_DEFAULT); \ + dim_t W_L3 = bli_env_get_var("BLIS_SVE_W_L3", W_L3_SVE_DEFAULT); \ + dim_t N_L3 = bli_env_get_var("BLIS_SVE_N_L3", N_L3_SVE_DEFAULT); \ + dim_t C_L3 = bli_env_get_var("BLIS_SVE_C_L3", C_L3_SVE_DEFAULT); \ +\ + dim_t vl_b = bli_vl_bits_armsve(); \ + dim_t vl = vl_b / S_Data; \ + dim_t m_r = 2 * vl; \ + dim_t n_r = 10; \ +\ + dim_t k_c = (dim_t)( floor((W_L1 - 1.0)/(1.0 + (double)n_r/m_r)) * N_L1 * C_L1 ) \ + / (n_r * S_Data); \ +\ + dim_t C_Ac = W_L2 - 1 - ceil( (2.0 * k_c * n_r * S_Data)/(C_L2 * N_L2) ); \ + dim_t m_c = C_Ac * (N_L2 * C_L2)/(k_c * S_Data); \ + m_c -= m_c % m_r; \ +\ + dim_t C_Bc = W_L3 - 1 - ceil( (2.0 * k_c * m_c * S_Data)/(C_L3 * N_L3) ); \ + dim_t n_c = C_Bc * (N_L3 * C_L3)/(k_c * S_Data); \ + n_c -= n_c % n_r; \ +\ + *m_r_ = m_r; \ + *n_r_ = n_r; \ + *k_c_ = k_c; \ + *m_c_ = m_c; \ + *n_c_ = n_c; \ +} + +EXPANDMAC_BLKSZ_ARMSVE( s, 4 ) +EXPANDMAC_BLKSZ_ARMSVE( d, 8 ) + diff --git a/config/armsve/bli_armsve_config_utils.h b/config/armsve/bli_armsve_config_utils.h new file mode 100644 index 0000000000..07aa9ba7d2 --- /dev/null +++ b/config/armsve/bli_armsve_config_utils.h @@ -0,0 +1,42 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
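The armsve configuration sizes its register blocking at run time from the SVE vector length. Note that bli_vl_bits_armsve() above is implemented with "incb", so despite its name it reports the vector length in bytes; the block-size macro then divides by the element size to get elements per vector. A minimal usage sketch (numbers in the comments assume a 512-bit SVE implementation such as A64FX; not part of the patch):

#include "blis.h"
#include "bli_armsve_config_utils.h"

void show_armsve_blocking( void )
{
    dim_t vl_bytes = bli_vl_bits_armsve();          // 64 bytes on a 512-bit SVE part
    dim_t vl_d     = vl_bytes / sizeof( double );   // 8 doubles per vector
    dim_t mr_2v    = 2 * vl_d;                      // 16: the "2v" rows of the 2vx10 kernels

    dim_t m_r, n_r, k_c, m_c, n_c;
    bli_d_blksz_armsve( &m_r, &n_r, &k_c, &m_c, &n_c );
    // Here m_r == mr_2v == 16, which is the value bli_cntx_init_armsve() (further
    // below in this patch) tests when choosing between the armsve512 (m_r == 16)
    // and armsve256 (m_r == 8) packm kernels.
    ( void )mr_2v; ( void )n_r; ( void )k_c; ( void )m_c; ( void )n_c;
}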
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" + +dim_t bli_vl_bits_armsve(void); + +void bli_s_blksz_armsve(dim_t *m_r_, dim_t *n_r_, dim_t *k_c_, dim_t *m_c_, dim_t *n_c_); +void bli_d_blksz_armsve(dim_t *m_r_, dim_t *n_r_, dim_t *k_c_, dim_t *m_c_, dim_t *n_c_); + diff --git a/config/armsve/bli_cntx_init_armsve.c b/config/armsve/bli_cntx_init_armsve.c new file mode 100644 index 0000000000..434979f915 --- /dev/null +++ b/config/armsve/bli_cntx_init_armsve.c @@ -0,0 +1,157 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "bli_armsve_config_utils.h" + +void bli_cntx_init_armsve( cntx_t* cntx ) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; +#if 0 + blksz_t thresh[ BLIS_NUM_THRESH ]; +#endif + + // Set default kernel blocksizes and functions. + bli_cntx_init_armsve_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Block size. + dim_t m_r_s, n_r_s, k_c_s, m_c_s, n_c_s; + dim_t m_r_d, n_r_d, k_c_d, m_c_d, n_c_d; + bli_s_blksz_armsve(&m_r_s, &n_r_s, &k_c_s, &m_c_s, &n_c_s); + bli_d_blksz_armsve(&m_r_d, &n_r_d, &k_c_d, &m_c_d, &n_c_d); + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( + 2, + // These are vector-length agnostic kernels. Yet knowing mr is required at runtime. 
+ BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_armsve_asm_2vx10_unindexed, FALSE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armsve_asm_2vx10_unindexed, FALSE, + cntx + ); + + // Set VL-specific packing routines if applicable. + if (m_r_d==16) + bli_cntx_set_packm_kers + ( + 3, + BLIS_PACKM_10XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_10xk, + BLIS_PACKM_12XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_12xk, + BLIS_PACKM_16XK_KER, BLIS_DOUBLE, bli_dpackm_armsve512_asm_16xk, + cntx + ); + else if (m_r_d==8) + bli_cntx_set_packm_kers + ( + 1, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_armsve256_asm_8xk, + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], m_r_s, m_r_d, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], n_r_s, n_r_d, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], m_c_s, m_c_d, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], k_c_s, k_c_d, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], n_c_s, n_c_d, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 5, + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + cntx + ); + +#if 0 + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], -1, 101, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], -1, 101, -1, -1 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], -1, 101, -1, -1 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. + bli_cntx_set_l3_sup_kers + ( + 4, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_armsve_10x2v_unindexed, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], -1, n_r_d, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], -1, m_r_d, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], -1, 120, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], -1, 256, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], -1, 2048, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); +#endif +} + diff --git a/config/armsve/bli_family_armsve.h b/config/armsve/bli_family_armsve.h new file mode 100644 index 0000000000..b67ae7c606 --- /dev/null +++ b/config/armsve/bli_family_armsve.h @@ -0,0 +1,56 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +//#ifndef BLIS_FAMILY_H +//#define BLIS_FAMILY_H + + +// -- MEMORY ALLOCATION -------------------------------------------------------- + +#define BLIS_SIMD_ALIGN_SIZE 256 +#define BLIS_SIMD_NUM_REGISTERS 32 + +// SVE-specific configs. +#define N_L1_SVE_DEFAULT 64 +#define W_L1_SVE_DEFAULT 4 +#define C_L1_SVE_DEFAULT 256 +#define N_L2_SVE_DEFAULT 2048 +#define W_L2_SVE_DEFAULT 16 +#define C_L2_SVE_DEFAULT 256 +#define N_L3_SVE_DEFAULT 8192 +#define W_L3_SVE_DEFAULT 16 +#define C_L3_SVE_DEFAULT 256 + +//#endif + diff --git a/config/armsve/make_defs.mk b/config/armsve/make_defs.mk new file mode 100644 index 0000000000..d3495efbb8 --- /dev/null +++ b/config/armsve/make_defs.mk @@ -0,0 +1,82 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2014, The University of Texas at Austin +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
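The N/W/C defaults above correspond to the A64FX L1D geometry (64 sets x 4 ways x 256-byte lines = 64 KiB) and feed the analytical block-size model in bli_armsve_config_utils.c; each value can be overridden at run time via the matching BLIS_SVE_* environment variable read there, without rebuilding. A worked example of the L1 stage for double precision on a 512-bit SVE part (m_r = 16, n_r = 10, 8-byte elements):

// Worked arithmetic only, following the k_c expression in bli_armsve_config_utils.c:
//   k_c = ( floor( (W_L1 - 1) / (1 + n_r/m_r) ) * N_L1 * C_L1 ) / ( n_r * 8 )
//       = ( floor( 3 / 1.625 ) * 64 * 256 ) / 80
//       = 16384 / 80
//       = 204    (integer division)

m_c and n_c then follow the same pattern from the L2 and L3 parameters, each rounded down to a multiple of m_r and n_r respectively.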
IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := armsve +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. +CPPROCFLAGS := -D_GNU_SOURCE +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O3 -ftree-vectorize -march=armv8-a+sve +endif + +# Flags specific to optimized kernels. +CKOPTFLAGS := $(COPTFLAGS) +CKVECFLAGS := + +# Flags specific to reference kernels. +CROPTFLAGS := $(CKOPTFLAGS) +ifeq ($(CC_VENDOR),gcc) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +ifeq ($(CC_VENDOR),clang) +CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +else +CRVECFLAGS := $(CKVECFLAGS) +endif +endif + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config/generic/CMakeLists.txt b/config/generic/CMakeLists.txt deleted file mode 100644 index 2fd3855574..0000000000 --- a/config/generic/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. ## - -target_sources("${PROJECT_NAME}" PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_cntx_init_generic.c - ) diff --git a/config/generic/make_defs.cmake b/config/generic/make_defs.cmake new file mode 100644 index 0000000000..d99d08e691 --- /dev/null +++ b/config/generic/make_defs.cmake @@ -0,0 +1,40 @@ +##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## + +if(NOT WIN32) + if(NOT (DEBUG_TYPE STREQUAL "off")) + set(CDBGFLAGS -g) + endif() + + if(DEBUG_TYPE STREQUAL "noopt") + set(COPTFLAGS -O0) + else() # off or opt + set(COPTFLAGS -O3) + endif() +endif() + +# Flags specific to optimized kernels. +if(MSVC) + set(CKOPTFLAGS ${COPTFLAGS}) +else() + set(CKOPTFLAGS ${COPTFLAGS} -O3) +endif() + +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + # Placeholder in case we want to add gcc-specific flags. +elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "icc") + # Placeholder in case we want to add icc-specific flags. +elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + # Placeholder in case we want to add clang-specific flags. +else() + message(FATAL_ERROR "gcc, icc, or clang is required for this configuration.") +endif() + +# Flags specific to reference kernels. 
+set(CROPTFLAGS ${CKOPTFLAGS}) +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + set(CRVECFLAGS ${CKVECFLAGS}) +elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + set(CRVECFLAGS ${CKVECFLAGS}) +else() + set(CRVECFLAGS ${CKVECFLAGS}) +endif() diff --git a/config/haswell/CMakeLists.txt b/config/haswell/CMakeLists.txt deleted file mode 100644 index a43bfe2b23..0000000000 --- a/config/haswell/CMakeLists.txt +++ /dev/null @@ -1,21 +0,0 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. ## - -set(FILES - ${CMAKE_CURRENT_SOURCE_DIR}/bli_cntx_init_haswell.c - ) - -set(SUBDIRECTORIES "") -set(RELATIVE_PATH "haswell") - -#Add all subdirectories -foreach(VAR ${SUBDIRECTORIES}) - add_subdirectory(${VAR}) -endforeach() - -if(FILES) - #Add source files to target - target_sources("${PROJECT_NAME}" PRIVATE ${FILES}) - - #Install our source files - install(FILES ${FILES} DESTINATION ${RELATIVE_PATH}) -endif() diff --git a/config/haswell/bli_cntx_init_haswell.c b/config/haswell/bli_cntx_init_haswell.c index b4d8ba8b50..19608fa74e 100644 --- a/config/haswell/bli_cntx_init_haswell.c +++ b/config/haswell/bli_cntx_init_haswell.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/config/haswell/bli_family_haswell.h b/config/haswell/bli_family_haswell.h index 58154692a7..5be492e562 100644 --- a/config/haswell/bli_family_haswell.h +++ b/config/haswell/bli_family_haswell.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/config/old/haswellbb/bli_cntx_init_haswell.c b/config/old/haswellbb/bli_cntx_init_haswell.c index 9e1d03503a..2de20b96e2 100644 --- a/config/old/haswellbb/bli_cntx_init_haswell.c +++ b/config/old/haswellbb/bli_cntx_init_haswell.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/config/old/haswellbb/bli_family_haswell.h b/config/old/haswellbb/bli_family_haswell.h index 06dfdfcfcc..ed9c344931 100644 --- a/config/old/haswellbb/bli_family_haswell.h +++ b/config/old/haswellbb/bli_family_haswell.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/config/skx/bli_cntx_init_skx.c b/config/skx/bli_cntx_init_skx.c index f18503a7a7..91dd7e444f 100644 --- a/config/skx/bli_cntx_init_skx.c +++ b/config/skx/bli_cntx_init_skx.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2023, Advanced Micro Devices, Inc. 
All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -128,5 +129,69 @@ void bli_cntx_init_skx( cntx_t* cntx ) BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, cntx ); + + bli_cntx_set_l3_sup_kers + ( + 30, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64m_avx512, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x64m_avx512, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64m_avx512, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64m_avx512, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x64n_avx512, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE, + + BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 24, 3, 12, + 6, 9, 3, 12 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 64, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 192, 144, 72, 48 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 480, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8064, 4080, 2040, 1020 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); + } diff --git a/config/skx/make_defs.mk b/config/skx/make_defs.mk index 2db79a1f22..bba0363b72 100644 --- a/config/skx/make_defs.mk +++ b/config/skx/make_defs.mk @@ -72,7 +72,15 @@ ifeq ($(CC_VENDOR),icc) CKVECFLAGS := -xCORE-AVX512 else ifeq ($(CC_VENDOR),clang) +# NOTE: We have to use -march=haswell on Windows because apparently AVX512 +# uses an alternate calling convention where xmm registers are not callee-saved +# on the stack. 
When this is mixed with framework code compiled for general +# x86_64 mode then chaos ensues (e.g. #514). +ifeq ($(IS_WIN),yes) +CKVECFLAGS := -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse -march=haswell +else CKVECFLAGS := -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse -march=skylake-avx512 +endif else $(error gcc, icc, or clang is required for this configuration.) endif @@ -98,7 +106,15 @@ ifeq ($(CC_VENDOR),icc) CRVECFLAGS := -xCORE-AVX2 else ifeq ($(CC_VENDOR),clang) +# NOTE: We have to use -march=haswell on Windows because apparently AVX512 +# uses an alternate calling convention where xmm registers are not callee-saved +# on the stack. When this is mixed with framework code compiled for general +# x86_64 mode then chaos ensues (e.g. #514). +ifeq ($(IS_WIN),yes) +CRVECFLAGS := -march=haswell -funsafe-math-optimizations -ffp-contract=fast +else CRVECFLAGS := -march=skylake-avx512 -mno-avx512f -mno-avx512vl -mno-avx512bw -mno-avx512dq -mno-avx512cd -funsafe-math-optimizations -ffp-contract=fast +endif else $(error gcc, icc, or clang is required for this configuration.) endif diff --git a/config/thunderx2/make_defs.mk b/config/thunderx2/make_defs.mk index 1fd1721c52..b43fea87c5 100644 --- a/config/thunderx2/make_defs.mk +++ b/config/thunderx2/make_defs.mk @@ -65,7 +65,11 @@ CKOPTFLAGS := $(COPTFLAGS) -O3 -ftree-vectorize ifeq ($(CC_VENDOR),gcc) CKVECFLAGS := -mcpu=thunderx2t99 else -$(error gcc is required for this configuration.) +ifeq ($(CC_VENDOR),clang) +CKVECFLAGS := -mcpu=thunderx2t99 +else +$(error gcc or clang is required for this configuration.) +endif endif # Flags specific to reference kernels. diff --git a/config/x86_64/bli_family_x86_64.h b/config/x86_64/bli_family_x86_64.h index 21b44db870..c327a0b19a 100644 --- a/config/x86_64/bli_family_x86_64.h +++ b/config/x86_64/bli_family_x86_64.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,10 +33,30 @@ */ -//#ifndef BLIS_FAMILY_H -//#define BLIS_FAMILY_H +#ifndef BLIS_FAMILY_H +#define BLIS_FAMILY_H +// By default, it is effective to parallelize the outer loops. +// Setting these macros to 1 will force JR and IR inner loops +// to be not parallelized. +// +#define BLIS_THREAD_MAX_IR 1 +#define BLIS_THREAD_MAX_JR 1 +#define BLIS_ENABLE_SMALL_MATRIX +#define BLIS_ENABLE_SMALL_MATRIX_TRSM -//#endif +// This will select the threshold below which small matrix code will be called. +#define BLIS_SMALL_MATRIX_THRES 700 +#define BLIS_SMALL_M_RECT_MATRIX_THRES 160 +#define BLIS_SMALL_K_RECT_MATRIX_THRES 128 + +#define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 +#define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 + +// When running HPL with pure MPI without DGEMM threading (Single-threaded +// BLIS), defining this macro as 1 yields better performance. +#define AOCL_BLIS_MULTIINSTANCE 0 + +#endif diff --git a/config/zen/CMakeLists.txt b/config/zen/CMakeLists.txt deleted file mode 100644 index 371f63b21c..0000000000 --- a/config/zen/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. 
## - -target_sources("${PROJECT_NAME}" PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_cntx_init_zen.c - ) diff --git a/config/zen/amd_config.cmake b/config/zen/amd_config.cmake new file mode 100644 index 0000000000..df3284d8fb --- /dev/null +++ b/config/zen/amd_config.cmake @@ -0,0 +1,49 @@ +##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## + +if(NOT WIN32) + if(NOT (DEBUG_TYPE STREQUAL "off")) + set(CDBGFLAGS -g) + endif() + + if(DEBUG_TYPE STREQUAL "noopt") + set(COPTFLAGS -O0) + else() # off or opt + set(COPTFLAGS -O2 -fomit-frame-pointer) + endif() +endif() + +# Flags specific to optimized kernels. +# NOTE: The -fomit-frame-pointer option is needed for some kernels because +# they make explicit use of the rbp register. +if(MSVC) + set(COPTFLAGS /Oy) + set(CKOPTFLAGS ${COPTFLAGS}) +else() + set(CKOPTFLAGS ${COPTFLAGS} -O3) +endif() + +if(MSVC) + set(CKVECFLAGS -mavx2 -mfma -mno-fma4 -mno-tbm -mno-xop -mno-lwp) +elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + set(CKVECFLAGS -mavx2 -mfpmath=sse -mfma) +elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + set(CKVECFLAGS -mavx2 -mfpmath=sse -mfma -mno-fma4 -mno-tbm -mno-xop -mno-lwp) + execute_process(COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string) + string(REGEX MATCH "^[^\n]*" CLANG_VERSION_STRING "${clang_full_version_string}") + string(REGEX MATCHALL "(AOCC.LLVM)" CLANG_STRING "${CLANG_VERSION_STRING}") + if("${CLANG_STRING}" MATCHES "(AOCC.LLVM)") + list(APPEND CKVECFLAGS -mllvm -disable-licm-vrp) + endif() +else() + message(FATAL_ERROR "gcc or clang are required for this configuration.") +endif() + +# Flags specific to reference kernels. +set(CROPTFLAGS ${CKOPTFLAGS}) +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + set(CRVECFLAGS ${CKVECFLAGS} -funsafe-math-optimizations -ffp-contract=fast) +elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + set(CRVECFLAGS ${CKVECFLAGS} -funsafe-math-optimizations -ffp-contract=fast) +else() + set(CRVECFLAGS ${CKVECFLAGS}) +endif() diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 83ce2cf8b6..d88ea7577e 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -35,283 +35,289 @@ #include "blis.h" -//GEMMSUP_KER_PROT( double, d, gemmsup_r_haswell_ref ) - void bli_cntx_init_zen( cntx_t* cntx ) { - blksz_t blkszs[ BLIS_NUM_BLKSZS ]; - blksz_t thresh[ BLIS_NUM_THRESH ]; - - // Set default kernel blocksizes and functions. - bli_cntx_init_zen_ref( cntx ); - - // ------------------------------------------------------------------------- - - // Update the context with optimized native gemm micro-kernels and - // their storage preferences. 
- bli_cntx_set_l3_nat_ukrs - ( - 8, - // gemm - BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, - BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, - BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, - BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, - // gemmtrsm_l - BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, - BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, - // gemmtrsm_u - BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, - BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, - cntx - ); - - // Update the context with architecture specific threshold functions - bli_cntx_set_l3_thresh_funcs - ( - 2, - // GEMMT - BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen, - // SYRK - BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen, - cntx - ); - - // Update the context with optimized level-1f kernels. - bli_cntx_set_l1f_kers - ( - 12, - // axpyf - BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, - BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, - BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, - // dotxaxpyf - BLIS_DOTXAXPYF_KER, BLIS_SCOMPLEX, bli_cdotxaxpyf_zen_int_8, - BLIS_DOTXAXPYF_KER, BLIS_DCOMPLEX, bli_zdotxaxpyf_zen_int_8, - // dotxf - BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, - BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, - BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, - BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, - //axpy2v - BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, - BLIS_AXPY2V_KER, BLIS_DCOMPLEX, bli_zaxpy2v_zen_int, - cntx - ); - - // Update the context with optimized level-1v kernels. - bli_cntx_set_l1v_kers - ( - 29, - - // amaxv - BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, - BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, - - // axpbyv - BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, - BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, - BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, - BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, - - // axpyv - BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, - BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, - BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, - BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, - - // dotv - BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int, - BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int, - BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, - BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int5, - - // dotxv - BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, - BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, - BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, - BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, - // scalv - - BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, - BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, - BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int, - - // swapv - BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, - BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, - - // copyv - BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, - BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, - BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen_int, - - //set - BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, - BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, - - // scal2v - BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, - cntx - ); - - // Initialize level-3 blocksize objects with architecture-specific values. 
- // s d c z - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; + + // Set default kernel blocksizes and functions. + bli_cntx_init_zen_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( + 8, + // gemm + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, + + // gemmtrsm_l + BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, + // gemmtrsm_u + BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, + cntx + ); + + // Update the context with architecture specific threshold functions + bli_cntx_set_l3_thresh_funcs + ( + 2, + // GEMMT + BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen, + // SYRK + BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen, + cntx + ); + + // Update the context with optimized level-1f kernels. + bli_cntx_set_l1f_kers + ( + 12, + // axpyf + BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, + BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, + BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxaxpyf + BLIS_DOTXAXPYF_KER, BLIS_SCOMPLEX, bli_cdotxaxpyf_zen_int_8, + BLIS_DOTXAXPYF_KER, BLIS_DCOMPLEX, bli_zdotxaxpyf_zen_int_8, + // dotxf + BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, + BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, + // axpy2v + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, + BLIS_AXPY2V_KER, BLIS_DCOMPLEX, bli_zaxpy2v_zen_int, + cntx + ); + + // Update the context with optimized level-1v kernels. 
+ bli_cntx_set_l1v_kers + ( + 29, + // amaxv + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, + + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, + BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, + + // axpyv + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, + BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, + + // dotv + BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int, + BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int, + BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, + BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int5, + + // dotxv + BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, + + // scalv + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int, + + // swapv + BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, + BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, + + // copyv + BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, + BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen_int, + + // setv + BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, + BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + + // scal2v + BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); /* - Multi Instance performance improvement of DGEMM when binded to a CCX - In Multi instance each thread runs a sequential DGEMM. + Multi Instance performance improvement of DGEMM when binded to a CCX + In Multi instance each thread runs a sequential DGEMM. - a) If BLIS is run in a multi-instance mode with - CPU freq 2.6/2.2 Ghz - DDR4 clock frequency 2400Mhz - mc = 240, kc = 512, and nc = 2040 - has better performance on EPYC server, over the default block sizes. + a) If BLIS is run in a multi-instance mode with + CPU freq 2.6/2.2 Ghz + DDR4 clock frequency 2400Mhz + mc = 240, kc = 512, and nc = 2040 + has better performance on EPYC server, over the default block sizes. - b) If BLIS is run in Single Instance mode - mc = 510, kc = 1024 and nc = 4080 + b) If BLIS is run in Single Instance mode + mc = 510, kc = 1024 and nc = 4080 */ + // Initialize level-3 blocksize objects with architecture-specific values. 
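To put numbers on the comment above: for double precision the packed B panel occupies roughly kc x nc x 8 bytes, so the two tuning points correspond to

    multi-instance : kc = 512,  nc = 2040  ->  512 * 2040 * 8 bytes  ~ 7.97 MiB
    single-instance: kc = 1024, nc = 4080  ->  1024 * 4080 * 8 bytes ~ 31.9 MiB

The multi-instance figure sits just under the 8 MiB L3 of a single Zen CCX, which is consistent with the note about binding each instance to a CCX (the cache-fitting reading is an interpretation, not something stated in the source).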
#ifdef BLIS_ENABLE_ZEN_BLOCK_SIZES - #if BLIS_ENABLE_SINGLE_INSTANCE_BLOCK_SIZES + #if BLIS_ENABLE_SINGLE_INSTANCE_BLOCK_SIZES + + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 510, 144, 72 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 1024, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 4080 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 510, 144, 72 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 1024, 256, 256 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 4080 ); + #else - #else - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 240, 144, 72 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 256 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 2040, 4080, 4080 ); + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 240, 144, 72 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 2040, 4080, 4080 ); - #endif + #endif #else - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 72 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 256 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 4080, 3056 ); + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 72 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 4080, 3056 ); #endif - bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); - - // Update the context with the current architecture's register and cache - // blocksizes (and multiples) for native execution. - bli_cntx_set_blkszs - ( - BLIS_NAT, 7, - // level-3 - BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, - BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, - BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, - BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, - BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, - // level-1f - BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, - BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, - cntx - ); - - // Update the context with the current architecture's register and cache - // blocksizes for level-3 TRSM execution. - bli_cntx_set_trsm_blkszs - ( - 5, - // level-3 - BLIS_NC, &blkszs[ BLIS_NC ], - BLIS_KC, &blkszs[ BLIS_KC ], - BLIS_MC, &blkszs[ BLIS_MC ], - BLIS_NR, &blkszs[ BLIS_NR ], - BLIS_MR, &blkszs[ BLIS_MR ], - cntx - ); - - // ------------------------------------------------------------------------- - - // Initialize sup thresholds with architecture-appropriate values. - // s d c z - bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 128 ); - bli_blksz_init_easy( &thresh[ BLIS_NT ], 512, 256, 256, 128 ); - bli_blksz_init_easy( &thresh[ BLIS_KT ], 440, 220, 220, 128 ); - - // Initialize the context with the sup thresholds. - bli_cntx_set_l3_sup_thresh - ( - 3, - BLIS_MT, &thresh[ BLIS_MT ], - BLIS_NT, &thresh[ BLIS_NT ], - BLIS_KT, &thresh[ BLIS_KT ], - cntx - ); - - // Initialize the context with the sup handlers. - bli_cntx_set_l3_sup_handlers - ( - 1, - BLIS_GEMM, bli_gemmsup_ref, - cntx - ); - - // Update the context with optimized small/unpacked gemm kernels. 
- bli_cntx_set_l3_sup_kers - ( - 30, - //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, - BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, - BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, - BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, - BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, - BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, - BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, - BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, - BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, - BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, - BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, - BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, - BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, - BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, - BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4m, TRUE, - BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, - BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4n, TRUE, - BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, - BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, - cntx - ); - - // Initialize level-3 sup blocksize objects with architecture-specific - // values. - // s d c z - bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, - 9, 9, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); - - // Update the context with the current architecture's register and cache - // blocksizes for small/unpacked level-3 problems. - bli_cntx_set_l3_sup_blkszs - ( - 5, - BLIS_NC, &blkszs[ BLIS_NC ], - BLIS_KC, &blkszs[ BLIS_KC ], - BLIS_MC, &blkszs[ BLIS_MC ], - BLIS_NR, &blkszs[ BLIS_NR ], - BLIS_MR, &blkszs[ BLIS_MR ], - cntx - ); -} + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 7, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + // level-1f + BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, + BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, + cntx + ); + + // Update the context with the current architecture's register and cache + // blocksizes for level-3 TRSM problems. 
+ bli_cntx_set_trsm_blkszs + ( + 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); + + // ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 128 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 512, 256, 256, 128 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 440, 220, 220, 128 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Initialize the context with the sup handlers. + bli_cntx_set_l3_sup_handlers + ( + 1, + BLIS_GEMM, bli_gemmsup_ref, + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. + bli_cntx_set_l3_sup_kers + ( + 30, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + + BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4m, TRUE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4n, TRUE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, + 9, 9, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. 
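The MT/NT/KT values registered above are per-datatype dimension thresholds that gate the small/unpacked (sup) path; the exact predicate lives in the threshold functions registered earlier (bli_cntx_gemmtsup_thresh_is_met_zen, bli_cntx_syrksup_thresh_is_met_zen) and is not reproduced in this patch. Purely as a hypothetical illustration of how such per-datatype thresholds are typically consulted (not the actual BLIS dispatch code):

#include <stdbool.h>
#include "blis.h" // for dim_t

// Hypothetical sketch only; the real logic is in the registered threshold functions
// and may differ. For double precision the values above are MT = 256, NT = 256, KT = 220.
bool use_sup_path( dim_t m, dim_t n, dim_t k,
                   dim_t mt, dim_t nt, dim_t kt )
{
    // In this sketch a problem is treated as "small" if any dimension falls
    // below its threshold.
    return ( m < mt ) || ( n < nt ) || ( k < kt );
}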
+ bli_cntx_set_l3_sup_blkszs + ( + 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); +} diff --git a/config/zen/bli_family_zen.h b/config/zen/bli_family_zen.h index 8b31c32ca0..b833a11d1b 100644 --- a/config/zen/bli_family_zen.h +++ b/config/zen/bli_family_zen.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,7 +38,7 @@ // By default, it is effective to parallelize the outer loops. // Setting these macros to 1 will force JR and IR inner loops -// to be not paralleized. +// to be not parallelized. #define BLIS_THREAD_MAX_IR 1 #define BLIS_THREAD_MAX_JR 1 @@ -50,7 +50,7 @@ #define BLIS_SMALL_M_RECT_MATRIX_THRES 160 #define BLIS_SMALL_K_RECT_MATRIX_THRES 128 -#define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 -#define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 +#define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 +#define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 #endif diff --git a/config/zen/make_defs.cmake b/config/zen/make_defs.cmake new file mode 100644 index 0000000000..682434bf52 --- /dev/null +++ b/config/zen/make_defs.cmake @@ -0,0 +1,39 @@ +##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## + +# Include file containing common flags for all AMD architectures +include(${CMAKE_SOURCE_DIR}/config/zen/amd_config.cmake) +if(NOT WIN32) + if(NOT (DEBUG_TYPE STREQUAL "off")) + set(CDBGFLAGS -g) + endif() + + if(DEBUG_TYPE STREQUAL "noopt") + set(COPTFLAGS -O0) + else() # off or opt + set(COPTFLAGS -O3) + endif() +endif() + +# Flags specific to optimized kernels. +# NOTE: The -fomit-frame-pointer option is needed for some kernels because +# they make explicit use of the rbp register. +if(MSVC) + set(CKOPTFLAGS ${COPTFLAGS} /Oy) +else() + set(CKOPTFLAGS ${COPTFLAGS} -fomit-frame-pointer) +endif() + +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + list(APPEND CKVECFLAGS -march=znver1) + if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) + list(APPEND CKOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse) + endif() +endif() + +# Flags specific to reference kernels. +set(CROPTFLAGS ${CKOPTFLAGS}) +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + set(CRVECFLAGS ${CKVECFLAGS}) +else() + set(CRVECFLAGS ${CKVECFLAGS}) +endif() diff --git a/config/zen/make_defs.mk b/config/zen/make_defs.mk index 59fc7b0a67..4e8896bfb2 100644 --- a/config/zen/make_defs.mk +++ b/config/zen/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/config/zen/old/bli_kernel.h b/config/zen/old/bli_kernel.h index cd324fd9a7..ab2656f5a8 100644 --- a/config/zen/old/bli_kernel.h +++ b/config/zen/old/bli_kernel.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017 - 2019, Advanced Micro Devices, Inc. 
+ Copyright (C) 2017 - 2023, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without diff --git a/config/zen2/CMakeLists.txt b/config/zen2/CMakeLists.txt deleted file mode 100644 index c3cdc45c08..0000000000 --- a/config/zen2/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. ## - -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_cntx_init_zen2.c - ) diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 42eae35d95..c7d8137329 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -1,7 +1,9 @@ /* + BLIS An object-based framework for developing high-performance BLAS-like libraries. + Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. @@ -16,6 +18,7 @@ - Neither the name(s) of the copyright holder(s) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR @@ -27,6 +30,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ #include "blis.h" @@ -46,7 +50,6 @@ void bli_cntx_init_zen2( cntx_t* cntx ) bli_cntx_set_l3_nat_ukrs ( 8, - // gemm BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, @@ -56,7 +59,6 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // gemmtrsm_l BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, - // gemmtrsm_u BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, @@ -67,12 +69,12 @@ void bli_cntx_init_zen2( cntx_t* cntx ) bli_cntx_set_l3_thresh_funcs ( 2, - //gemmt + // GEMMT BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen, - //SYRK - BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen, + // SYRK + BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen, cntx - ); + ); // Update the context with optimized packm kernels. 
bli_cntx_set_packm_kers @@ -94,20 +96,20 @@ void bli_cntx_init_zen2( cntx_t* cntx ) ( 12, // axpyf - BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, // dotxaxpyf BLIS_DOTXAXPYF_KER, BLIS_SCOMPLEX, bli_cdotxaxpyf_zen_int_8, BLIS_DOTXAXPYF_KER, BLIS_DCOMPLEX, bli_zdotxaxpyf_zen_int_8, // dotxf - BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, - BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, // axpy2v - BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, BLIS_AXPY2V_KER, BLIS_DCOMPLEX, bli_zaxpy2v_zen_int, cntx ); @@ -116,55 +118,54 @@ void bli_cntx_init_zen2( cntx_t* cntx ) bli_cntx_set_l1v_kers ( 29, - // amaxv - BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, - BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, - // axpbyv - BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, - BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, // axpyv - BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, - BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, - BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, - BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, + BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, // dotv - BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int10, - BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int10, - BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, - BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int5, + BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int10, + BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int10, + BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, + BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int5, // dotxv - BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, - BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, // scalv - BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, - BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int, - //swap - BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, - BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, + // swapv + BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, + BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, - //copy - BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, - BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + // copyv + BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, + BLIS_COPYV_KER, 
BLIS_DOUBLE, bli_dcopyv_zen_int, BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen_int, - //set - BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, - BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + // setv + BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, + BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, // scal2v - BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, + BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, cntx ); @@ -175,11 +176,11 @@ void bli_cntx_init_zen2( cntx_t* cntx ) #if AOCL_BLIS_MULTIINSTANCE bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 240, 144, 18 ); bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 566 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 2040, 4080, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 2040, 4080, 256 ); #else - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 18 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 18 ); bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 566 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 256 ); #endif bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); @@ -204,31 +205,35 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // ------------------------------------------------------------------------- - //Initialize TRSM blocksize objects with architecture-specific values. - //Using different cache block sizes for TRSM instead of common level-3 block sizes. - //Tuning is done for double-precision only. - // s d c z - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 72 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 492, 256, 256 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 1600, 4080, 4080 ); - - // Update the context with the current architecture's register and cache - // blocksizes for level-3 TRSM problems. - bli_cntx_set_trsm_blkszs - ( - 5, - BLIS_NC, &blkszs[ BLIS_NC ], - BLIS_KC, &blkszs[ BLIS_KC ], - BLIS_MC, &blkszs[ BLIS_MC ], - BLIS_NR, &blkszs[ BLIS_NR ], - BLIS_MR, &blkszs[ BLIS_MR ], - cntx - ); - - // Initialize sup thresholds with architecture-appropriate values. s d c z - bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 ); - bli_blksz_init_easy( &thresh[ BLIS_NT ], 200, 256, 256, 128 ); - bli_blksz_init_easy( &thresh[ BLIS_KT ], 240, 220, 220, 110 ); + // Initialize TRSM blocksize objects with architecture-specific values. + // Using different cache block sizes for TRSM instead of common level-3 block sizes. + // Tuning is done for double-precision only. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 72 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 492, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 1600, 4080, 4080 ); + + // Update the context with the current architecture's register and cache + // blocksizes for level-3 TRSM problems. + bli_cntx_set_trsm_blkszs + ( + 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); + + // ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 200, 256, 256, 128 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 240, 220, 220, 110 ); // Initialize the context with the sup thresholds. 
bli_cntx_set_l3_sup_thresh @@ -244,16 +249,15 @@ void bli_cntx_init_zen2( cntx_t* cntx ) bli_cntx_set_l3_sup_handlers ( 2, - BLIS_GEMM, bli_gemmsup_ref, - BLIS_GEMMT, bli_gemmtsup_ref, + BLIS_GEMM, bli_gemmsup_ref, + BLIS_GEMMT, bli_gemmtsup_ref, cntx ); // Update the context with optimized small/unpacked gemm kernels. bli_cntx_set_l3_sup_kers ( - 30, - //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, + 30, BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, @@ -262,6 +266,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, @@ -270,6 +275,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, @@ -291,18 +297,19 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Initialize level-3 sup blocksize objects with architecture-specific // values. // s d c z - bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, - 9, 9, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, + 9, 9, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); // Update the context with the current architecture's register and cache // blocksizes for small/unpacked level-3 problems. bli_cntx_set_l3_sup_blkszs ( 5, + // level-3 BLIS_NC, &blkszs[ BLIS_NC ], BLIS_KC, &blkszs[ BLIS_KC ], BLIS_MC, &blkszs[ BLIS_MC ], @@ -311,4 +318,3 @@ void bli_cntx_init_zen2( cntx_t* cntx ) cntx ); } - diff --git a/config/zen2/bli_family_zen2.h b/config/zen2/bli_family_zen2.h index 16fe50609e..ecff86be2e 100644 --- a/config/zen2/bli_family_zen2.h +++ b/config/zen2/bli_family_zen2.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,7 +38,7 @@ // By default, it is effective to parallelize the outer loops. // Setting these macros to 1 will force JR and IR inner loops -// to be not paralleized. +// to be not parallelized. 
#define BLIS_THREAD_MAX_IR 1 #define BLIS_THREAD_MAX_JR 1 @@ -50,8 +50,8 @@ #define BLIS_SMALL_M_RECT_MATRIX_THRES 160 #define BLIS_SMALL_K_RECT_MATRIX_THRES 128 -#define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 -#define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 +#define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 +#define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 // When running HPL with pure MPI without DGEMM threading (Single-threaded // BLIS), defining this macro as 1 yields better performance. diff --git a/config/zen2/make_defs.cmake b/config/zen2/make_defs.cmake new file mode 100644 index 0000000000..2296a3d2c2 --- /dev/null +++ b/config/zen2/make_defs.cmake @@ -0,0 +1,76 @@ +##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## + +# Include file containing common flags for all AMD architectures +include(${CMAKE_SOURCE_DIR}/config/zen/amd_config.cmake) +if(NOT WIN32) + if(NOT (DEBUG_TYPE STREQUAL "off")) + set(CDBGFLAGS -g) + endif() + + if(DEBUG_TYPE STREQUAL "noopt") + set(COPTFLAGS -O0) + else() # off or opt + set(COPTFLAGS -O3) + endif() +endif() + +# Flags specific to optimized kernels. +# NOTE: The -fomit-frame-pointer option is needed for some kernels because +# they make explicit use of the rbp register. +if(MSVC) + set(CKOPTFLAGS ${COPTFLAGS} /Oy) +else() + set(CKOPTFLAGS ${COPTFLAGS} -fomit-frame-pointer) +endif() + +# gcc or clang version must be at least 4.0 +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) + # gcc 9.0 or later + list(APPEND CKVECFLAGS -march=znver2) + list(APPEND CKOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse) + else() + # If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 + # as the fallback option. 
+        list(APPEND CKVECFLAGS -march=znver1 -mno-avx256-split-unaligned-store) +        list(APPEND CRVECFLAGS -march=znver1 -mno-avx256-split-unaligned-store) +    endif() +endif() # gcc + +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + # AOCC clang has various formats for the version line + # AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) + # AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) + # AMD clang version 10.0.0 (CLANG: AOCC_2.2.0-Build#93 2020_06_25) (based on LLVM Mirror.Version.10.0.0) + # AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) (based on LLVM Mirror.Version.11.0.0) + # AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) + # AMD clang version 14.0.0 (CLANG: AOCC_4.0.0-Build#98 2022_06_15) (based on LLVM Mirror.Version.14.0.0) + + # For our purpose we just want to know if it is version 2x or 3x or 4x + + # But also set these in case we are using upstream LLVM clang + execute_process(COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string) + string(REGEX MATCH "^[^\n]*" CLANG_VERSION_STRING "${clang_full_version_string}") + string(REGEX MATCHALL "(AOCC_2|AOCC_3|AOCC_4|AOCC|LLVM|clang)" CLANG_STRING "${CLANG_VERSION_STRING}") + string(REGEX REPLACE ".*clang version ([0-9]+\\.[0-9]+).*" "\\1" CLANG_VERSION "${CLANG_VERSION_STRING}") + + if("${CLANG_STRING}" MATCHES "AOCC_4") + # AOCC version 4x we will enable znver2 + list(APPEND CKVECFLAGS -march=znver2) + elseif("${CLANG_STRING}" MATCHES "AOCC_3") + # AOCC version 3x we will enable znver2 + list(APPEND CKVECFLAGS -march=znver2) + elseif("${CLANG_STRING}" MATCHES "(AOCC_2|LLVM)") + # AOCC version 2x we will enable znver2 + list(APPEND CKVECFLAGS -march=znver2) + elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) + # LLVM clang 9.0 or later + list(APPEND CKVECFLAGS -march=znver2) + else() + list(APPEND CKVECFLAGS -march=znver1) + endif() +endif() + +# Flags specific to reference kernels. +set(CROPTFLAGS ${CKOPTFLAGS}) +set(CRVECFLAGS ${CKVECFLAGS}) diff --git a/config/zen2/make_defs.mk b/config/zen2/make_defs.mk index 180c201b06..b54ebda881 100644 --- a/config/zen2/make_defs.mk +++ b/config/zen2/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/config/zen3/CMakeLists.txt b/config/zen3/CMakeLists.txt deleted file mode 100644 index d600e43870..0000000000 --- a/config/zen3/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc ## - -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_cntx_init_zen3.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_family_zen3.h - ) diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index 31a9ff5957..b5b99eb609 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -37,272 +37,286 @@ void bli_cntx_init_zen3( cntx_t* cntx ) { - blksz_t blkszs[ BLIS_NUM_BLKSZS ]; - blksz_t thresh[ BLIS_NUM_THRESH ]; - // Set default kernel blocksizes and functions.
- bli_cntx_init_zen3_ref( cntx ); - - // ------------------------------------------------------------------------- - - // Update the context with optimized native gemm micro-kernels and - // their storage preferences. - bli_cntx_set_l3_nat_ukrs - ( - 8, - // gemm - BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, - BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, - BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, - BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, - // gemmtrsm_l - BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, - BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, - // gemmtrsm_u - BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, - BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, - cntx - ); - - // Update the context with architecture specific threshold functions - bli_cntx_set_l3_thresh_funcs - ( - 2, - // GEMMT - BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen, - // SYRK - BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen, - cntx - ); - - // packm kernels - bli_cntx_set_packm_kers - ( - 8, - BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_6xk, - BLIS_PACKM_16XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_16xk, - BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_6xk, - BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_8xk, - BLIS_PACKM_3XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_3xk, - BLIS_PACKM_8XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_8xk, - BLIS_PACKM_3XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_3xk, - BLIS_PACKM_4XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_4xk, - cntx - ); - - // Update the context with optimized level-1f kernels. - bli_cntx_set_l1f_kers - ( - 12, - // axpyf - BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, - // dotxaxpyf - BLIS_DOTXAXPYF_KER, BLIS_SCOMPLEX, bli_cdotxaxpyf_zen_int_8, - BLIS_DOTXAXPYF_KER, BLIS_DCOMPLEX, bli_zdotxaxpyf_zen_int_8, - // dotxf - BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, - BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, - BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, - BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, - // axpy2v - BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, - BLIS_AXPY2V_KER, BLIS_DCOMPLEX, bli_zaxpy2v_zen_int, - cntx - ); - - // Update the context with optimized level-1v kernels. 
- bli_cntx_set_l1v_kers - ( - 29, - - // amaxv - BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, - BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, - - // axpbyv - BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, - BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, - BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, - BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, - - // axpyv - BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, - BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, - BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, - BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, - - // dotv - BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int10, - BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int10, - BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, - BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int5, - - // dotxv - BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, - BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, - BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, - BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, - - // scalv - BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, - BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, - BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int, - - //swap - BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, - BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, - - //copy - BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, - BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, - BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen_int, - - //set - BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, - BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, - - // scal2v - BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, - cntx - ); - - // Initialize level-3 blocksize objects with architecture-specific values. + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; + + // Set default kernel blocksizes and functions. + bli_cntx_init_zen3_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( + 11, + // gemm + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, + + BLIS_GEMM_FOR_TRSM_UKR, BLIS_DCOMPLEX, bli_zgemm_zen_asm_2x6, TRUE, + // gemmtrsm_l + BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DCOMPLEX, bli_zgemmtrsm_l_zen_asm_2x6, TRUE, + // gemmtrsm_u + BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DCOMPLEX, bli_zgemmtrsm_u_zen_asm_2x6, TRUE, + cntx + ); + + // Update the context with architecture specific threshold functions + bli_cntx_set_l3_thresh_funcs + ( + 2, + // GEMMT + BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen, + // SYRK + BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen, + cntx + ); + + // Update the context with optimized packm kernels. 
+ bli_cntx_set_packm_kers + ( + 8, + BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_6xk, + BLIS_PACKM_16XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_16xk, + BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_6xk, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_3xk, + BLIS_PACKM_8XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_3xk, + BLIS_PACKM_4XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_4xk, + cntx + ); + + // Update the context with optimized level-1f kernels. + bli_cntx_set_l1f_kers + ( + 12, + // axpyf + BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxaxpyf + BLIS_DOTXAXPYF_KER, BLIS_SCOMPLEX, bli_cdotxaxpyf_zen_int_8, + BLIS_DOTXAXPYF_KER, BLIS_DCOMPLEX, bli_zdotxaxpyf_zen_int_8, + // dotxf + BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, + BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, + // axpy2v + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, + BLIS_AXPY2V_KER, BLIS_DCOMPLEX, bli_zaxpy2v_zen_int, + cntx + ); + + // Update the context with optimized level-1v kernels. + bli_cntx_set_l1v_kers + ( + 29, + // amaxv + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, + + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, + BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, + + // axpyv + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, + BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, + + // dotv + BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int10, + BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int10, + BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, + BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int5, + + // dotxv + BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, + + // scalv + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int, + + // swapv + BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, + BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, + + // copyv + BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, + BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen_int, + + // setv + BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, + BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + + // scal2v + BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. // // These are reference block sizes and may be overridden based on // number of threads used at runtime. 
- // s d c z - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 18 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 566 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 256 ); - - bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); - - // Update the context with the current architecture's register and cache - // blocksizes (and multiples) for native execution. - bli_cntx_set_blkszs - ( - BLIS_NAT, 7, - // level-3 - BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, - BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, - BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, - BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, - BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, - // level-1f - BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, - BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, - cntx - ); -// ------------------------------------------------------------------------- - - //Initialize TRSM blocksize objects with architecture-specific values. - //Using different cache block sizes for TRSM instead of common level-3 block sizes. - //Tuning is done for double-precision only. - // s d c z - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 72 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 492, 256, 256 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 1600, 4080, 4080 ); - - // Update the context with the current architecture's register and cache - // blocksizes for level-3 TRSM problems. - bli_cntx_set_trsm_blkszs - ( - 5, - BLIS_NC, &blkszs[ BLIS_NC ], - BLIS_KC, &blkszs[ BLIS_KC ], - BLIS_MC, &blkszs[ BLIS_MC ], - BLIS_NR, &blkszs[ BLIS_NR ], - BLIS_MR, &blkszs[ BLIS_MR ], - cntx - ); - - // Initialize sup thresholds with architecture-appropriate values. s d c z - bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 ); - bli_blksz_init_easy( &thresh[ BLIS_NT ], 200, 256, 256, 128 ); - bli_blksz_init_easy( &thresh[ BLIS_KT ], 240, 220, 220, 110 ); - - // Initialize the context with the sup thresholds. - bli_cntx_set_l3_sup_thresh - ( - 3, - BLIS_MT, &thresh[ BLIS_MT ], - BLIS_NT, &thresh[ BLIS_NT ], - BLIS_KT, &thresh[ BLIS_KT ], - cntx - ); - - // Initialize the context with the sup handlers. - bli_cntx_set_l3_sup_handlers - ( - 2, - BLIS_GEMM, bli_gemmsup_ref, - BLIS_GEMMT, bli_gemmtsup_ref, - cntx - ); - - // Update the context with optimized small/unpacked gemm kernels. 
- bli_cntx_set_l3_sup_kers - ( - 30, - //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, - BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, - BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, - BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, - BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, - BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, - BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, - BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, - BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, - BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, - BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, - BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, - BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, - BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, - BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4m, TRUE, - BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, - BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4n, TRUE, - BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, - BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, - cntx - ); - - // Initialize level-3 sup blocksize objects with architecture-specific - // values. - // s d c z - bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, - 9, 9, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); - - // Update the context with the current architecture's register and cache - // blocksizes for small/unpacked level-3 problems. - bli_cntx_set_l3_sup_blkszs - ( - 5, - BLIS_NC, &blkszs[ BLIS_NC ], - BLIS_KC, &blkszs[ BLIS_KC ], - BLIS_MC, &blkszs[ BLIS_MC ], - BLIS_NR, &blkszs[ BLIS_NR ], - BLIS_MR, &blkszs[ BLIS_MR ], - cntx - ); + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 18 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 566 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 256 ); + + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. 
+ bli_cntx_set_blkszs + ( + BLIS_NAT, 7, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + // level-1f + BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, + BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, + cntx + ); + + // ------------------------------------------------------------------------- + + // Initialize TRSM blocksize objects with architecture-specific values. + // Using different cache block sizes for TRSM instead of common level-3 block sizes. + // Tuning is done for double-precision only. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 2 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 6 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 24 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 492, 256, 512 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 1600, 4080, 1536 ); + + // Update the context with the current architecture's register and cache + // blocksizes for level-3 TRSM problems. + bli_cntx_set_trsm_blkszs + ( + 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); + + // ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 200, 256, 256, 128 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 240, 220, 220, 110 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Initialize the context with the sup handlers. + bli_cntx_set_l3_sup_handlers + ( + 2, + BLIS_GEMM, bli_gemmsup_ref, + BLIS_GEMMT, bli_gemmtsup_ref, + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. 
+ bli_cntx_set_l3_sup_kers + ( + 30, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + + BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4m, TRUE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4n, TRUE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, + 9, 9, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); } diff --git a/config/zen3/bli_family_zen3.h b/config/zen3/bli_family_zen3.h index ce84104c52..35ffc9f19d 100644 --- a/config/zen3/bli_family_zen3.h +++ b/config/zen3/bli_family_zen3.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,9 +38,7 @@ // By default, it is effective to parallelize the outer loops. // Setting these macros to 1 will force JR and IR inner loops -// to be not paralleized. -// - +// to be not parallelized. 
 #define BLIS_THREAD_MAX_IR 1 #define BLIS_THREAD_MAX_JR 1 @@ -52,7 +50,7 @@ #define BLIS_SMALL_M_RECT_MATRIX_THRES 160 #define BLIS_SMALL_K_RECT_MATRIX_THRES 128 -#define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 -#define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 +#define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 +#define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 #endif diff --git a/config/zen3/make_defs.cmake b/config/zen3/make_defs.cmake new file mode 100644 index 0000000000..077deb68c3 --- /dev/null +++ b/config/zen3/make_defs.cmake @@ -0,0 +1,90 @@ +##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## + +# FLAGS that are specific to the 'zen3' architecture are added here. +# FLAGS that are common for all the AMD architectures are present in +# config/zen/amd_config.mk. + +# Include file containing common flags for all AMD architectures +include(${CMAKE_SOURCE_DIR}/config/zen/amd_config.cmake) + +# --- Determine the C compiler and related flags --- +if(NOT WIN32) + if(NOT (DEBUG_TYPE STREQUAL "off")) + set(CDBGFLAGS -g) + endif() + + if(DEBUG_TYPE STREQUAL "noopt") + set(COPTFLAGS -O0) + else() # off or opt + set(COPTFLAGS -O3) + endif() +endif() + +# Flags specific to optimized kernels. +# NOTE: The -fomit-frame-pointer option is needed for some kernels because +# they make explicit use of the rbp register. +if(MSVC) + set(CKOPTFLAGS ${COPTFLAGS} /Oy) +else() + set(CKOPTFLAGS ${COPTFLAGS} -fomit-frame-pointer) +endif() + +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11.0.0) + # gcc 11.0 or later + list(APPEND CKVECFLAGS -march=znver3) + # Update CKOPTFLAGS for gcc to use O3 optimization without + # the -ftree-pre and -ftree-partial-pre flags. These flags result + # in suboptimal code generation for intrinsic-based kernels. + # The -ftree-loop-vectorize flag results in inefficient code + # generation for AMD-optimized L1 kernels based on intrinsics. + list(APPEND CKOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse) + elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) + # gcc 9.0 or later + list(APPEND CKVECFLAGS -march=znver2) + list(APPEND CKOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse) + else() + # If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 + # as the fallback option.
+        list(APPEND CKVECFLAGS -march=znver1 -mno-avx256-split-unaligned-store) +        list(APPEND CRVECFLAGS -march=znver1 -mno-avx256-split-unaligned-store) +    endif() +endif() + +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + # AOCC clang has various formats for the version line + # AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) + # AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) + # AMD clang version 10.0.0 (CLANG: AOCC_2.2.0-Build#93 2020_06_25) (based on LLVM Mirror.Version.10.0.0) + # AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) (based on LLVM Mirror.Version.11.0.0) + # AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) + # AMD clang version 14.0.0 (CLANG: AOCC_4.0.0-Build#98 2022_06_15) (based on LLVM Mirror.Version.14.0.0) + + # For our purpose we just want to know if it is version 2x or 3x or 4x + + # But also set these in case we are using upstream LLVM clang + execute_process(COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string) + string(REGEX MATCH "^[^\n]*" CLANG_VERSION_STRING "${clang_full_version_string}") + string(REGEX MATCHALL "(AOCC_2|AOCC_3|AOCC_4|AOCC|LLVM|clang)" CLANG_STRING "${CLANG_VERSION_STRING}") + string(REGEX REPLACE ".*clang version ([0-9]+\\.[0-9]+).*" "\\1" CLANG_VERSION "${CLANG_VERSION_STRING}") + + if("${CLANG_STRING}" MATCHES "AOCC_4") + # AOCC version 4x we will enable znver3 + list(APPEND CKVECFLAGS -march=znver3) + elseif("${CLANG_STRING}" MATCHES "AOCC_3") + # AOCC version 3x we will enable znver3 + list(APPEND CKVECFLAGS -march=znver3) + elseif("${CLANG_STRING}" MATCHES "(AOCC_2|LLVM)") + # AOCC version 2x we will enable znver2 + list(APPEND CKVECFLAGS -march=znver2) + elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) + # LLVM clang 9.0 or later + list(APPEND CKVECFLAGS -march=znver2) + else() + list(APPEND CKVECFLAGS -march=znver1) + endif() +endif() + +# Flags specific to reference kernels. +set(CROPTFLAGS ${CKOPTFLAGS}) +set(CRVECFLAGS ${CKVECFLAGS}) diff --git a/config/zen3/make_defs.mk b/config/zen3/make_defs.mk index 7ec1ee32e9..727be9d603 100644 --- a/config/zen3/make_defs.mk +++ b/config/zen3/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/config/zen4/CMakeLists.txt b/config/zen4/CMakeLists.txt deleted file mode 100644 index ea166b00c7..0000000000 --- a/config/zen4/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -##Copyright (C) 2022, Advanced Micro Devices, Inc ## - -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_cntx_init_zen4.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_family_zen4.h - ) diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index 8dda84ccce..8a79ff8a1f 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc.
+ Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,420 +40,382 @@ */ #define BLI_CNTX_DEFAULT_BLKSZ_LIST_GENOA(blkszs) \ - /* s d c z */ \ - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 32, 3, 12 ); \ - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 6, 8, 4 ); \ - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 128, 144, 60 ); \ - bli_blksz_init ( &blkszs[ BLIS_KC ], 480, 512, 256, 512, \ - 480, 320, 256, 160 ); \ - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 4002, 4080, 2004 ); \ - \ - bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); \ - bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); \ + /* s d c z */ \ + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 32, 3, 12 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 6, 8, 4 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 128, 144, 60 ); \ + bli_blksz_init ( &blkszs[ BLIS_KC ], 480, 512, 256, 512, \ + 480, 320, 256, 160 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 4002, 4080, 2004 ); \ + \ + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); \ #define BLI_CNTX_DEFAULT_BLKSZ_LIST_BERGAMO(blkszs) \ - /* s d c z */ \ - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 32, 3, 12 ); \ - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 6, 8, 4 ); \ - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 64, 144, 60 ); \ - bli_blksz_init ( &blkszs[ BLIS_KC ], 480, 512, 256, 512, \ - 480, 320, 256, 160 ); \ - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 3600, 4080, 2004 ); \ - \ - bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); \ - bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); \ + /* s d c z */ \ + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 32, 3, 12 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 6, 8, 4 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 64, 144, 60 ); \ + bli_blksz_init ( &blkszs[ BLIS_KC ], 480, 512, 256, 512, \ + 480, 320, 256, 160 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 3600, 4080, 2004 ); \ + \ + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); \ void bli_cntx_init_zen4( cntx_t* cntx ) { - blksz_t blkszs[ BLIS_NUM_BLKSZS ]; - blksz_t thresh[ BLIS_NUM_THRESH ]; - // Set default kernel blocksizes and functions. - bli_cntx_init_zen4_ref( cntx ); - - // ------------------------------------------------------------------------- - - // Update the context with optimized native gemm micro-kernels and - // their storage preferences. 
- bli_cntx_set_l3_nat_ukrs - ( - 10, - // gemm - BLIS_GEMM_UKR, BLIS_FLOAT , bli_sgemm_skx_asm_32x12_l2, FALSE, - BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_zen4_asm_32x6, FALSE, - BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, - /*bli_zgemm_zen4_asm_12x4 is a column preferred kernel*/ - BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_zen4_asm_12x4, FALSE, - - // Different GEMM kernels are used for TRSM for zen4 architecture - BLIS_GEMM_FOR_TRSM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, - BLIS_GEMM_FOR_TRSM_UKR, BLIS_DOUBLE, bli_dgemm_zen4_asm_8x24, TRUE, - - // gemmtrsm_l - BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, - BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_zen4_asm_8x24, TRUE, - // gemmtrsm_u - BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, - BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_zen4_asm_8x24, TRUE, - - cntx - ); - - // Update the context with architecture specific threshold functions - bli_cntx_set_l3_thresh_funcs - ( - 3, - // GEMM - BLIS_GEMM, bli_cntx_gemmsup_thresh_is_met_zen4, - // GEMMT - BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen, - // SYRK - BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen, - cntx - ); - - // packm kernels - bli_cntx_set_packm_kers - ( - 11, - BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_6xk, - BLIS_PACKM_16XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_16xk, - BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_6xk, - BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_zen4_asm_8xk, - BLIS_PACKM_24XK_KER, BLIS_DOUBLE, bli_dpackm_zen4_asm_24xk, - BLIS_PACKM_32XK_KER, BLIS_DOUBLE, bli_dpackm_zen4_asm_32xk, - BLIS_PACKM_3XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_3xk, - BLIS_PACKM_8XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_8xk, - BLIS_PACKM_3XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_3xk, - BLIS_PACKM_12XK_KER, BLIS_DCOMPLEX, bli_zpackm_zen4_asm_12xk, - BLIS_PACKM_4XK_KER, BLIS_DCOMPLEX, bli_zpackm_zen4_asm_4xk, - cntx - ); - - // Update the context with optimized level-1f kernels. - bli_cntx_set_l1f_kers - ( - 9, - // axpyf - BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, - // dotxf - BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, - BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, - BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, - BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, - // axpy2v - BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, - cntx - ); - - // Update the context with optimized level-1v kernels. 
- bli_cntx_set_l1v_kers - ( - 28, - - // amaxv - BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int_avx512, - BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, - - // axpbyv - BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, - BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, - BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, - BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, - - // axpyv - BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int_avx512, - BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int_avx512, - BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, - BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, - - // dotv - BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int_avx512, - BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int_avx512, - BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, - BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int5, - - // dotxv - BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, - BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, - BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, - - // scalv - BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int_avx512, - BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int_avx512, - BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int, - - //swap - BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, - BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, - - //copy - BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, - BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, - BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen_int, - - //set - BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, - BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, - - // scal2v - BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, - cntx - ); - - // Initialize level-3 blocksize objects with architecture-specific values. - // - // These are reference block sizes and may be overridden based on - // number of threads used at runtime. - - if ( bli_init_model_query_id() == BLIS_MODEL_BERGAMO ) - { - BLI_CNTX_DEFAULT_BLKSZ_LIST_BERGAMO(blkszs); - } - else // BLIS_MODEL_DEFAULT choice, also currently used for BLIS_MODEL_GENOA and BLIS_MODEL_GENOA_X - { - BLI_CNTX_DEFAULT_BLKSZ_LIST_GENOA(blkszs); - } - - // Update the context with the current architecture's register and cache - // blocksizes (and multiples) for native execution. - bli_cntx_set_blkszs - ( - BLIS_NAT, 7, - // level-3 - BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, - BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, - BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, - BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, - BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, - // level-1f - BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, - BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, - cntx - ); - // ------------------------------------------------------------------------- - - // Initialize sup thresholds with architecture-appropriate values. s d c z - bli_blksz_init_easy( &thresh[ BLIS_MT ], 682, 1000, 380, 110 ); - bli_blksz_init_easy( &thresh[ BLIS_NT ], 512, 1000, 256, 128 ); - bli_blksz_init_easy( &thresh[ BLIS_KT ], 240, 220, 220, 110 ); - - // Initialize the context with the sup thresholds. - bli_cntx_set_l3_sup_thresh - ( - 3, - BLIS_MT, &thresh[ BLIS_MT ], - BLIS_NT, &thresh[ BLIS_NT ], - BLIS_KT, &thresh[ BLIS_KT ], - cntx - ); - - // Initialize the context with the sup handlers. - bli_cntx_set_l3_sup_handlers - ( - 2, - BLIS_GEMM, bli_gemmsup_ref, - BLIS_GEMMT, bli_gemmtsup_ref, - cntx - ); - - // Update the context with optimized small/unpacked gemm kernels. 
- bli_cntx_set_l3_sup_kers - ( - 30, - BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, - BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, - BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, - BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, - BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, - BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, - BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, - BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, - BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64m_avx512, TRUE, - BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x64m_avx512, TRUE, - BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64m_avx512, TRUE, - BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE, - BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64m_avx512, TRUE, - BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x64n_avx512, TRUE, - BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE, - BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE, - BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, - BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, - BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, - BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, - BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, - BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, - BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, - BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, - BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, - BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, - BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, - cntx - ); - - // Initialize level-3 sup blocksize objects with architecture-specific - // values. - // s d c z - bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 24, 3, 12, - 6, 9, 3, 12 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 64, 8, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 192, 144, 72, 48 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 480, 128, 64 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8064, 4080, 2040, 1020 ); - - // Update the context with the current architecture's register and cache - // blocksizes for small/unpacked level-3 problems. - bli_cntx_set_l3_sup_blkszs - ( - 5, - BLIS_NC, &blkszs[ BLIS_NC ], - BLIS_KC, &blkszs[ BLIS_KC ], - BLIS_MC, &blkszs[ BLIS_MC ], - BLIS_NR, &blkszs[ BLIS_NR ], - BLIS_MR, &blkszs[ BLIS_MR ], - cntx - ); -} - -/* - * Override the block sizes in the context to the block sizes used - * by AVX2 GEMM+TRSM kernels, this is needed in Zen4 context as default - * GEMM kernels are AVX512 based and uses different block sizes. - * - * This function should be called in TRSM path before performing - * any packing operations. 
- * - * Also the context must be restored to default values by calling - * bli_zen4_restore_default_blkszs() before exiting TRSM Path - */ -void bli_zen4_override_trsm_blkszs (cntx_t* cntx) -{ - blksz_t blkszs[ BLIS_NUM_BLKSZS ]; - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 8, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 24, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 120, 144, 72 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 256 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4008, 4080, 4080 ); - - - // Update the context with the current architecture's register and cache - // blocksizes (and multiples) for native execution. - bli_cntx_set_blkszs - ( - BLIS_NAT, 5, - // level-3 - BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, - BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, - BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, - BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, - BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, - cntx - ); -} - - -// Since the output of syrk/gemmt is a triangular matrix, -// near-to-square shaped kernel performs better than -// skewed/rectangular shaped kernel. -// Hence we are overriding blocksizes and kernel -// function pointers for gemmt/syrk with avx2 specific ones -void bli_zen4_override_gemmt_blkszs (cntx_t* cntx) -{ - blksz_t blkszs[ BLIS_NUM_BLKSZS ]; - - bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, - 9, 9, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); - // Update the context with the current architecture's register and cache - // blocksizes (and multiples) for native execution. - bli_cntx_set_l3_sup_blkszs - ( - 4, - // level-3 - BLIS_KC, &blkszs[ BLIS_KC ], - BLIS_MC, &blkszs[ BLIS_MC ], - BLIS_NR, &blkszs[ BLIS_NR ], - BLIS_MR, &blkszs[ BLIS_MR ], - cntx - ); - - bli_cntx_set_l3_sup_kers - ( - 24, - BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, - BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, - BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, - BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, - BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, - BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, - BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, - BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, - BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, - BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, - BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4m, TRUE, - BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, - BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4n, TRUE, - BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, - BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, - cntx - ); -} - -/* - * Restore the block sizes to default 
values needed for zen4 context. - * - * This function should be called to restore the block sizes to there - * default values if they where overriden by calling - * bli_zen4_override_trsm_blkszs() to enable AVX2 GEMM kernels in the - * TRSM path. - * - */ -void bli_zen4_restore_default_blkszs (cntx_t* cntx) -{ - blksz_t blkszs[ BLIS_NUM_BLKSZS ]; - - if ( bli_init_model_query_id() == BLIS_MODEL_BERGAMO ) - { - BLI_CNTX_DEFAULT_BLKSZ_LIST_BERGAMO(blkszs); - } - else // BLIS_MODEL_DEFAULT choice, also currently used for BLIS_MODEL_GENOA and BLIS_MODEL_GENOA_X - { - BLI_CNTX_DEFAULT_BLKSZ_LIST_GENOA(blkszs); - } - - // Update the context with the current architecture's register and cache - // blocksizes (and multiples) for native execution. - bli_cntx_set_blkszs - ( - BLIS_NAT, 7, - // level-3 - BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, - BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, - BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, - BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, - BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, - // level-1f - BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, - BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, - cntx - ); + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; + + // Set default kernel blocksizes and functions. + bli_cntx_init_zen4_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. + bli_cntx_set_l3_nat_ukrs + ( + 13, + // gemm + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_skx_asm_32x12_l2, FALSE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_zen4_asm_32x6, FALSE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, + /*bli_zgemm_zen4_asm_12x4 is a column preferred kernel*/ + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_zen4_asm_12x4, FALSE, + + // Different GEMM kernels are used for TRSM for zen4 architecture + BLIS_GEMM_FOR_TRSM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, + BLIS_GEMM_FOR_TRSM_UKR, BLIS_DOUBLE, bli_dgemm_zen4_asm_8x24, TRUE, + BLIS_GEMM_FOR_TRSM_UKR, BLIS_DCOMPLEX, bli_zgemm_zen4_asm_4x12, TRUE, + + // gemmtrsm_l + BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_zen4_asm_8x24, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DCOMPLEX, bli_zgemmtrsm_l_zen4_asm_4x12, TRUE, + // gemmtrsm_u + BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_zen4_asm_8x24, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DCOMPLEX, bli_zgemmtrsm_u_zen4_asm_4x12, TRUE, + cntx + ); + + // Update the context with architecture specific threshold functions + bli_cntx_set_l3_thresh_funcs + ( + 3, + // GEMM + BLIS_GEMM, bli_cntx_gemmsup_thresh_is_met_zen4, + // GEMMT + BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen, + // SYRK + BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen, + cntx + ); + + // Update the context with optimized packm kernels. 
+ bli_cntx_set_packm_kers + ( + 11, + BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_6xk, + BLIS_PACKM_16XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_16xk, + BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_6xk, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_zen4_asm_8xk, + BLIS_PACKM_24XK_KER, BLIS_DOUBLE, bli_dpackm_zen4_asm_24xk, + BLIS_PACKM_32XK_KER, BLIS_DOUBLE, bli_dpackm_zen4_asm_32xk, + BLIS_PACKM_3XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_3xk, + BLIS_PACKM_8XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_3xk, + BLIS_PACKM_12XK_KER, BLIS_DCOMPLEX, bli_zpackm_zen4_asm_12xk, + BLIS_PACKM_4XK_KER, BLIS_DCOMPLEX, bli_zpackm_zen4_asm_4xk, + cntx + ); + + // Update the context with optimized level-1f kernels. + bli_cntx_set_l1f_kers + ( + 9, + // axpyf + BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxf + BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, + BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, + // axpy2v + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, + cntx + ); + + // Update the context with optimized level-1v kernels. + bli_cntx_set_l1v_kers + ( + 28, + // amaxv + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int_avx512, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, + + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, + BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, + + // axpyv + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int_avx512, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int_avx512, + BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, + BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, + + // dotv + BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int_avx512, + BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int_avx512, + BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, + BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int5, + + // dotxv + BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, + + // scalv + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int_avx512, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int_avx512, + BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int, + + // swapv + BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, + BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, + + // copyv + BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, + BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen_int, + + // setv + BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, + BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + + // scal2v + BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. + // + // These are reference block sizes and may be overridden based on + // number of threads used at runtime. 
+ + if ( bli_init_model_query_id() == BLIS_MODEL_BERGAMO ) + { + BLI_CNTX_DEFAULT_BLKSZ_LIST_BERGAMO(blkszs); + } + else // BLIS_MODEL_DEFAULT choice, also currently used for BLIS_MODEL_GENOA and BLIS_MODEL_GENOA_X + { + BLI_CNTX_DEFAULT_BLKSZ_LIST_GENOA(blkszs); + } + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 7, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + // level-1f + BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, + BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, + cntx + ); + + // ------------------------------------------------------------------------- + + // Initialize TRSM blocksize objects with architecture-specific values. + // Using different cache block sizes for TRSM instead of common level-3 block sizes. + // Tuning is done for double-precision only. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 8, 3, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 24, 8, 12 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 120, 144, 40 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 512 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4008, 4080, 2004 ); + + // Update the context with the current architecture's register and cache + // blocksizes for level-3 TRSM problems. + bli_cntx_set_trsm_blkszs + ( + 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); + + // ------------------------------------------------------------------------- + + // Initialize sup thresholds with architecture-appropriate values. + // s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], 682, 1000, 380, 110 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 512, 1000, 256, 128 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 240, 220, 220, 110 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Initialize the context with the sup handlers. + bli_cntx_set_l3_sup_handlers + ( + 2, + BLIS_GEMM, bli_gemmsup_ref, + BLIS_GEMMT, bli_gemmtsup_ref, + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. 
+ bli_cntx_set_l3_sup_kers + ( + 30, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64m_avx512, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x64m_avx512, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64m_avx512, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64m_avx512, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x64n_avx512, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE, + + BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 24, 3, 12, + 6, 9, 3, 12 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 64, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 192, 144, 72, 48 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 480, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8064, 4080, 2040, 1020 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); + + // Initialize level-3 sup blocksize objects for operations dealing with + // triangular objects with architecture-specific values. + // + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, + 9, 9, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. 
+ bli_cntx_set_l3_sup_tri_blkszs + ( + 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); + + bli_cntx_set_l3_sup_tri_kers + ( + 30, + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + + BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4m, TRUE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4n, TRUE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + cntx + ); } diff --git a/config/zen4/bli_family_zen4.h b/config/zen4/bli_family_zen4.h index a1666ea9d3..bacf8b62a4 100644 --- a/config/zen4/bli_family_zen4.h +++ b/config/zen4/bli_family_zen4.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -37,16 +37,15 @@ // By default, it is effective to parallelize the outer loops. // Setting these macros to 1 will force JR and IR inner loops -// to be not paralleized. -// -#define BLIS_THREAD_MAX_IR 1 -#define BLIS_THREAD_MAX_JR 1 +// to be not parallelized. +#define BLIS_THREAD_MAX_IR 1 +#define BLIS_THREAD_MAX_JR 1 #define BLIS_ENABLE_SMALL_MATRIX #define BLIS_ENABLE_SMALL_MATRIX_TRSM // This will select the threshold below which small matrix code will be called. 
-#define BLIS_SMALL_MATRIX_THRES 700 +#define BLIS_SMALL_MATRIX_THRES 700 #define BLIS_SMALL_M_RECT_MATRIX_THRES 160 #define BLIS_SMALL_K_RECT_MATRIX_THRES 128 @@ -60,30 +59,4 @@ #define BLIS_SIMD_SIZE 64 #define BLIS_SIMD_NUM_REGISTERS 32 -/* - * Override the block sizes in the context to the block sizes used - * by AVX2 GEMM+TRSM kernels, this is needed in Zen4 context as default - * GEMM kernels are AVX512 based and uses different block sizes. - * - * This function should be called in TRSM path before performing - * any packing operations. - * - * Also the context must be restored to default values by calling - * bli_zen4_restore_default_blkszs() before exiting TRSM Path - */ -BLIS_EXPORT_BLIS void bli_zen4_override_trsm_blkszs (cntx_t* cntx); - -BLIS_EXPORT_BLIS void bli_zen4_override_gemmt_blkszs (cntx_t* cntx); - -/* - * Restore the block sizes to default values needed for zen4 context. - * - * This function should be called to restore the block sizes to there - * default values if they where overriden by calling - * bli_zen4_override_trsm_blkszs() to enable AVX2 GEMM kernels in the - * TRSM path. - * - */ -BLIS_EXPORT_BLIS void bli_zen4_restore_default_blkszs (cntx_t* cntx); - #endif diff --git a/config/zen4/make_defs.cmake b/config/zen4/make_defs.cmake new file mode 100644 index 0000000000..e5ce4401b7 --- /dev/null +++ b/config/zen4/make_defs.cmake @@ -0,0 +1,112 @@ +##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. ## + +# FLAGS that are specific to the 'zen4' architecture are added here. +# FLAGS that are common for all the AMD architectures are present in +# config/zen/amd_config.mk. + +# Include file containing common flags for all AMD architectures +include(${CMAKE_SOURCE_DIR}/config/zen/amd_config.cmake) +if(NOT WIN32) + if(NOT (DEBUG_TYPE STREQUAL "off")) + set(CDBGFLAGS -g) + endif() + + if(DEBUG_TYPE STREQUAL "noopt") + set(COPTFLAGS -O0) + else() # off or opt + set(COPTFLAGS -O3) + endif() +endif() + +# Flags specific to optimized kernels. +# NOTE: The -fomit-frame-pointer option is needed for some kernels because +# they make explicit use of the rbp register. +if(MSVC) + set(CKOPTFLAGS ${COPTFLAGS} /Oy) +else() + set(CKOPTFLAGS ${COPTFLAGS} -fomit-frame-pointer) +endif() + +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0.0) + # gcc 13.0 or later + list(APPEND CKVECFLAGS -march=znver4) + list(APPEND CRVECFLAGS -march=znver4) + # Update CKOPTFLAGS for gcc to use O3 optimization without + # -ftree-pre and -ftree-partial-pre flag. These flag results + # in suboptimal code generation for instrinsic based kernels. + # The -ftree-loop-vectorize results in inefficient code gen + # for amd optimized l1 kernels based on instrinsics. 
+ list(APPEND CKOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize) + elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11.0.0) + # gcc 11.0 or later + list(APPEND CKVECFLAGS -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16) + list(APPEND CRVECFLAGS -march=znver3) + list(APPEND CKOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize) + elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) + # gcc 9.0 or later + list(APPEND CKVECFLAGS -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni) + list(APPEND CRVECFLAGS -march=znver2) + list(APPEND CKOPTFLAGS -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize) + elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8.0.0) + # gcc 8.0 or later + list(APPEND CKVECFLAGS -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni) + list(APPEND CRVECFLAGS -march=znver1) + else() + # If gcc is older than 8.0.0 but at least 6.1.0, then we can use -march=znver1 + # as the fallback option. + list(APPEND CKVECFLAGS -march=znver1 -mno-avx256-split-unaligned-store) + list(APPEND CRVECFLAGS -march=znver1 -mno-avx256-split-unaligned-store) + endif() +endif() # gcc + +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + # AOCC clang has various formats for the version line + + # AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) + # AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) + # AMD clang version 10.0.0 (CLANG: AOCC_2.2.0-Build#93 2020_06_25) (based on LLVM Mirror.Version.10.0.0) + # AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) (based on LLVM Mirror.Version.11.0.0) + # AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) + # AMD clang version 14.0.0 (CLANG: AOCC_4.0.0-Build#98 2022_06_15) (based on LLVM Mirror.Version.14.0.0) + # For our purpose we just want to know if it version 2x or 3x or 4x + + # But also set these in case we are using upstream LLVM clang + execute_process(COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string) + string(REGEX MATCH "^[^\n]*" CLANG_VERSION_STRING "${clang_full_version_string}") + string(REGEX MATCHALL "(AOCC_2|AOCC_3|AOCC_4|AOCC|LLVM|clang)" CLANG_STRING "${CLANG_VERSION_STRING}") + string(REGEX REPLACE ".*clang version ([0-9]+\\.[0-9]+).*" "\\1" CLANG_VERSION "${CLANG_VERSION_STRING}") + + if("${CLANG_STRING}" MATCHES "AOCC_4") + # AOCC version 4x we will enable znver4 + list(APPEND CKVECFLAGS -march=znver4 -falign-loops=64) + list(APPEND CRVECFLAGS -march=znver4) + elseif("${CLANG_STRING}" MATCHES "AOCC_3") + # AOCC version 3x we will enable znver3 + list(APPEND CKVECFLAGS -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -falign-loops=64) + list(APPEND CRVECFLAGS -march=znver3) + elseif("${CLANG_STRING}" MATCHES "(AOCC_2|LLVM)") + # AOCC version 2x we will enable znver2 + list(APPEND CKVECFLAGS -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni) + list(APPEND CRVECFLAGS -march=znver2) + elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 16.0.0) + # LLVM clang 16.0 or later + list(APPEND CKVECFLAGS -march=znver4 -falign-loops=64) + list(APPEND CRVECFLAGS -march=znver4) + elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0.0) + # LLVM clang 13.0 or later + list(APPEND CKVECFLAGS 
-march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -falign-loops=64) + list(APPEND CRVECFLAGS -march=znver3) + elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 9.0.0) + # LLVM clang 9.0 or later + list(APPEND CKVECFLAGS -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -falign-loops=64) + list(APPEND CRVECFLAGS -march=znver2) + else() + list(APPEND CKVECFLAGS -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -falign-loops=64) + list(APPEND CRVECFLAGS -march=znver1) + endif() +endif() + +# Flags specific to reference kernels. +set(CROPTFLAGS ${CKOPTFLAGS}) +set(CRVECFLAGS ${CKVECFLAGS}) diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index 5a058e2fbc..bca80fcc9f 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -4,7 +4,7 @@ # An object-based framework for developing high-performance BLAS-like # libraries. # -# Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/config_registry b/config_registry index 4e6716dfa1..cd0f9bbb68 100644 --- a/config_registry +++ b/config_registry @@ -8,7 +8,7 @@ # # Processor families. -x86_64: intel64 amd64 amd64_legacy +x86_64: intel64 amdzen amd64_legacy intel64: skx knl haswell sandybridge penryn generic amd64_legacy: excavator steamroller piledriver bulldozer generic amdzen: zen4 zen3 zen2 zen generic @@ -19,7 +19,7 @@ amdzen: zen4 zen3 zen2 zen generic #arm32: cortexa15 cortexa9 generic # Intel architectures. -skx: skx/skx/haswell/zen +skx: skx/skx/haswell/zen/zen4 knl: knl/knl/haswell/zen haswell: haswell/haswell/zen sandybridge: sandybridge @@ -36,6 +36,8 @@ piledriver: piledriver bulldozer: bulldozer # ARM architectures. +armsve: armsve/armsve +a64fx: a64fx/armsve thunderx2: thunderx2/armv8a cortexa57: cortexa57/armv8a cortexa53: cortexa53/armv8a diff --git a/configure b/configure index a165c1ad51..92a34632bb 100755 --- a/configure +++ b/configure @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -355,11 +355,12 @@ print_usage() echo " " echo " --enable-blis-arch-type, --disable-blis-arch-type" echo " " - echo " Disable (Enabled by default) support for BLIS_ARCH_TYPE and BLIS_MODEL_TYPE" - echo " environment variables, which allows user to select" + echo " Disable support for AOCL_ENABLE_INSTRUCTIONS, BLIS_ARCH_TYPE and" + echo " BLIS_MODEL_TYPE environment variables, which allows user to select" echo " architecture specific code path and optimizations at runtime." echo " If disabled, in builds with multiple code paths, BLIS" echo " will still select path and optimizations automatically." + echo " Default: Enabled in builds with multiple code paths, else disabled." echo " " echo " --rename-blis-arch-type=STRING" echo " " @@ -811,7 +812,8 @@ read_registry_file() # canonicalize whitespace, and then remove duplicate kernel # set names, if they exist. Finally, update the kernel registry # with the new kernel list. 
- newklist=$(echo -e "${klisttmp}" | sed -e "s/${ker}/${kers_ker}/g") + #newklist=$(echo -e "${klisttmp}" | sed -e "s/${ker}/${kers_ker}/g") + newklist=$(substitute_words "${ker}" "${kers_ker}" "${klisttmp}") newklist=$(canonicalize_ws "${newklist}") newklist=$(rm_duplicate_words "${newklist}") @@ -833,6 +835,26 @@ read_registry_file() done } +substitute_words() +{ + local word new_words list newlist + + word="$1" + new_words="$2" + list="$3" + + for str in ${list}; do + + if [ "${str}" == "${word}" ]; then + newlist="${newlist} ${new_words}" + else + newlist="${newlist} ${str}" + fi + done + + echo "${newlist}" +} + build_kconfig_registry() { local familyname clist config kernels kernel cur_configs newvalue @@ -1453,7 +1475,7 @@ get_compiler_version() cc_version=$(${cc} -dumpversion) # If compiler is AOCC, first grep for clang and then the version number. elif [ "${cc_vendor}" = "clang" ]; then - cc_version=$(echo "${vendor_string}" | egrep -o 'clang version [0-9]+\.[0-9]+\.?[0-9]*' | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*') + cc_version=$(echo "${vendor_string}" | egrep -o '(clang|LLVM) version [0-9]+\.[0-9]+\.?[0-9]*' | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*') elif [ "${cc_vendor}" = "oneAPI" ]; then # Treat Intel oneAPI's clang as clang, not icc. cc_vendor="clang" @@ -1519,6 +1541,8 @@ check_compiler() # cortexa15: any # cortexa9: any # + # armsve: clang11+, gcc10+ + # # generic: any # # Note: These compiler requirements were originally modeled after similar @@ -1564,6 +1588,9 @@ check_compiler() # gcc 5.x may support POWER9 but it is unverified. blacklistcc_add "power9" fi + if [ ${cc_major} -lt 10 ]; then + blacklistcc_add "armsve" + fi fi # icc @@ -1626,6 +1653,9 @@ check_compiler() #blacklistcc_add "zen" : # explicit no-op since bash can't handle empty loop bodies. fi + if [ ${cc_major} -lt 11 ]; then + blacklistcc_add "armsve" + fi fi fi } @@ -2047,7 +2077,7 @@ main() enable_aocl_dynamic='yes' force_version='no' complex_return='default' - disable_blis_arch_type='no' + disable_blis_arch_type='unset' rename_blis_arch_type='BLIS_ARCH_TYPE' rename_blis_model_type='BLIS_MODEL_TYPE' @@ -2427,6 +2457,11 @@ main() echo "${script_name}: using '${found_cc}' C compiler." + # Also check the compiler to see if we are (cross-)compiling for Windows + if ${found_cc} -dM -E - < /dev/null 2> /dev/null | grep -q _WIN32; then + is_win=yes + fi + # -- Find a C++ compiler --------------------------------------------------- @@ -2776,6 +2811,19 @@ main() fi + # Based on the number of sub-configurations, set default value for disable_blis_arch_type + # (if user hasn't set option). BLIS_ARCH_TYPE functionality only makes sense for use with + # processor families containing multiple sub-configurations, but user can force the + # functionality to be enabled/disabled with --enable-blis-arch-type/--disable-blis-arch-type + # configure options. + if [ "x${disable_blis_arch_type}" = "xunset" ]; then + config_list_count=$(echo ${config_list} |wc -w) + if [ "x${config_list_count}" = "x1" ]; then + disable_blis_arch_type='yes' + else + disable_blis_arch_type='no' + fi + fi echo "${script_name}: checking sub-configurations:" @@ -3267,7 +3315,8 @@ main() fi if [ "x${disable_blis_arch_type}" = "xyes" ]; then - echo "${script_name}: user selection of code path using BLIS_ARCH_TYPE and BLIS_MODEL_TYPE env vars is disabled." + echo "${script_name}: user selection of code path using AOCL_ENABLE_INSTRUCTIONS," + echo "${script_name}: BLIS_ARCH_TYPE and BLIS_MODEL_TYPE env vars is disabled." 
disable_blis_arch_type_01='1' else disable_blis_arch_type_01='0' @@ -3332,10 +3381,11 @@ main() uconf=$(echo ${config_name} | tr '[:lower:]' '[:upper:]') config_name_define="#define BLIS_FAMILY_${uconf}\n" - #create a AOCL specific #define - #This macro is enabled only for zen family configurations. - #This enables us to use different cache block sizes for TRSM instead of common level-3 block sizes. - uconf=$(echo ${config_name} | grep -c 'zen\|amd64' | cut -d. -f1) + # Create a AOCL specific #define + # This macro is enabled only for zen family configurations. + # This enables us to use different cache block sizes for TRSM instead of common level-3 block sizes. + # Note: amd64_legacy is for pre-zen architectures. + uconf=$(echo ${config_name} | grep -v amd64_legacy |grep -c 'zen\|amd64\|x86_64' | cut -d. -f1) if [[ $uconf == 1 ]]; then enable_aocl_zen='yes' enable_aocl_zen_01=1 @@ -3822,6 +3872,23 @@ main() exit 1 fi + # If 'blis.pc.in' symlink does not already exist in the current + # directory, create a symbolic link to it. If one does exist, we + # use -f to force creation of a new link. + if [ ! -e "./blis.pc.in" ]; then + + echo "${script_name}: creating symbolic link to blis.pc.in." + ln -s "${dist_path}/blis.pc.in" + + elif [ -h "./blis.pc.in" ]; then + echo "${script_name}: symbolic link to blis.pc.in already exists; forcing creation of new link." + ln -sf "${dist_path}/blis.pc.in" + else + echo "${script_name}: Non-symbolic link file or directory 'blis.pc.in' blocks creation of symlink." + echo "${script_name}: *** Please remove this entity and re-run configure." + exit 1 + fi + # If 'common.mk' symlink does not already exist in the current # directory, create a symbolic link to it. If one does exist, we # use -f to force creation of a new link. 
diff --git a/docs/BLISObjectAPI.md b/docs/BLISObjectAPI.md index e84703cdcc..9a06e29a49 100644 --- a/docs/BLISObjectAPI.md +++ b/docs/BLISObjectAPI.md @@ -53,7 +53,7 @@ This index provides a quick way to jump directly to the description for each ope * **[Level-3](BLISObjectAPI.md#level-3-operations)**: Operations with matrices that are multiplication-like: * [gemm](BLISObjectAPI.md#gemm), [hemm](BLISObjectAPI.md#hemm), [herk](BLISObjectAPI.md#herk), [her2k](BLISObjectAPI.md#her2k), [symm](BLISObjectAPI.md#symm), [syrk](BLISObjectAPI.md#syrk), [syr2k](BLISObjectAPI.md#syr2k), [trmm](BLISObjectAPI.md#trmm), [trmm3](BLISObjectAPI.md#trmm3), [trsm](BLISObjectAPI.md#trsm) * **[Utility](BLISObjectAPI.md#Utility-operations)**: Miscellaneous operations on matrices and vectors: - * [asumv](BLISObjectAPI.md#asumv), [norm1v](BLISObjectAPI.md#norm1v), [normfv](BLISObjectAPI.md#normfv), [normiv](BLISObjectAPI.md#normiv), [norm1m](BLISObjectAPI.md#norm1m), [normfm](BLISObjectAPI.md#normfm), [normim](BLISObjectAPI.md#normim), [mkherm](BLISObjectAPI.md#mkherm), [mksymm](BLISObjectAPI.md#mksymm), [mktrim](BLISObjectAPI.md#mktrim), [fprintv](BLISObjectAPI.md#fprintv), [fprintm](BLISObjectAPI.md#fprintm),[printv](BLISObjectAPI.md#printv), [printm](BLISObjectAPI.md#printm), [randv](BLISObjectAPI.md#randv), [randm](BLISObjectAPI.md#randm), [sumsqv](BLISObjectAPI.md#sumsqv), [getijm](BLISObjectAPI.md#getijm), [setijm](BLISObjectAPI.md#setijm) + * [asumv](BLISObjectAPI.md#asumv), [norm1v](BLISObjectAPI.md#norm1v), [normfv](BLISObjectAPI.md#normfv), [normiv](BLISObjectAPI.md#normiv), [norm1m](BLISObjectAPI.md#norm1m), [normfm](BLISObjectAPI.md#normfm), [normim](BLISObjectAPI.md#normim), [mkherm](BLISObjectAPI.md#mkherm), [mksymm](BLISObjectAPI.md#mksymm), [mktrim](BLISObjectAPI.md#mktrim), [fprintv](BLISObjectAPI.md#fprintv), [fprintm](BLISObjectAPI.md#fprintm),[printv](BLISObjectAPI.md#printv), [printm](BLISObjectAPI.md#printm), [randv](BLISObjectAPI.md#randv), [randm](BLISObjectAPI.md#randm), [sumsqv](BLISObjectAPI.md#sumsqv), [getsc](BLISObjectAPI.md#getsc), [getijv](BLISObjectAPI.md#getijv), [getijm](BLISObjectAPI.md#getijm), [setsc](BLISObjectAPI.md#setsc), [setijv](BLISObjectAPI.md#setijv), [setijm](BLISObjectAPI.md#setijm), [eqsc](BLISObjectAPI.md#eqsc), [eqv](BLISObjectAPI.md#eqv), [eqm](BLISObjectAPI.md#eqm) @@ -790,6 +790,8 @@ Perform ``` where `x` and `y` are vectors of length _n_. +Observed object properties: `conj?(x)`. + --- #### dotv @@ -807,6 +809,8 @@ Perform ``` where `x` and `y` are vectors of length _n_, and `rho` is a scalar. +Observed object properties: `conj?(x)`, `conj?(y)`. + --- #### dotxv @@ -826,6 +830,8 @@ Perform ``` where `x` and `y` are vectors of length _n_, and `alpha`, `beta`, and `rho` are scalars. +Observed object properties: `conj?(alpha)`, `conj?(beta)`, `conj?(x)`, `conj?(y)`. + --- #### invertv @@ -2125,6 +2131,34 @@ where, on entry, `scale` and `sumsq` contain `scale_old` and `sumsq_old`, respec --- +#### getsc +```c +void bli_getsc + ( + obj_t* chi, + double* zeta_r, + double* zeta_i + ) +``` +Copy the real and imaginary values from the scalar object `chi` to `zeta_r` and `zeta_i`. If `chi` is stored as a real type, then `zeta_i` is set to zero. (If `chi` is stored in single precision, the corresponding elements are typecast/promoted during the copy.) + +--- + +#### getijv +```c +err_t bli_getijv + ( + dim_t i, + obj_t* b, + double* ar, + double* ai + ) +``` +Copy the real and imaginary values at the `i`th element of vector object `x` to `ar` and `ai`. 
If elements of `x` are stored as real types, then only `ar` is overwritten and `ai` is left unchanged. (If `x` contains elements stored in single precision, the corresponding elements are typecast/promoted during the copy.) +If either the element offset `i` is beyond the vector dimension of `x` or less than zero, the function returns `BLIS_FAILURE` without taking any action. Similarly, if `x` is a global scalar constant such as `BLIS_ONE`, the function returns `BLIS_FAILURE`. + +--- + #### getijm ```c err_t bli_getijm @@ -2136,8 +2170,38 @@ err_t bli_getijm double* ai ) ``` -Copy the real and imaginary values at the (`i`,`j`) element of object `b` to `ar` and `ai`. f elements of `b` are stored as real types, then only `ar` is overwritten and `ai` is left unchanged. (If `b` contains elements stored in single precision, the corresponding elements are typecast/promoted during the copy.) -If either the row offset `i` is beyond the _m_ dimension of `b`, or column offset `j` is beyond the _n_ dimension of `b`, the function does not perform any copy and returns `BLIS_FAILURE`. Similarly, if `b` is a global scalar constant such as `BLIS_ONE`, `BLIS_FAILURE` is returned. +Copy the real and imaginary values at the (`i`,`j`) element of object `b` to `ar` and `ai`. If elements of `b` are stored as real types, then only `ar` is overwritten and `ai` is left unchanged. (If `b` contains elements stored in single precision, the corresponding elements are typecast/promoted during the copy.) +If either the row offset `i` is beyond the _m_ dimension of `b` or less than zero, or column offset `j` is beyond the _n_ dimension of `b` or less than zero, the function returns `BLIS_FAILURE` without taking any action. Similarly, if `b` is a global scalar constant such as `BLIS_ONE`, the function returns `BLIS_FAILURE`. + +--- + +#### setsc +```c +void bli_setsc + ( + double* zeta_r, + double* zeta_i, + obj_t* chi + ); +``` +Copy real and imaginary values `zeta_r` and `zeta_i` to the scalar object `chi`. If `chi` is stored as a real type, then `zeta_i` is ignored. (If `chi` is stored in single precision, the contents are typecast/demoted during the copy.) + +--- + +#### setijv +```c +err_t bli_setijv + ( + double ar, + double ai, + dim_t i, + obj_t* x + ); +``` +Copy real and imaginary values `ar` and `ai` to the `i`th element of vector object `x`. If elements of `x` are stored as real types, then only `ar` is copied and `ai` is ignored. (If `x` contains elements stored in single precision, the corresponding elements are typecast/demoted during the copy.) +If the element offset `i` is beyond the vector dimension of `x` or less than zero, the function returns `BLIS_FAILURE` without taking any action. Similarly, if `x` is a global scalar constant such as `BLIS_ONE`, the function returns `BLIS_FAILURE`. + +--- #### setijm ```c @@ -2151,7 +2215,59 @@ err_t bli_setijm ); ``` Copy real and imaginary values `ar` and `ai` to the (`i`,`j`) element of object `b`. If elements of `b` are stored as real types, then only `ar` is copied and `ai` is ignored. (If `b` contains elements stored in single precision, the corresponding elements are typecast/demoted during the copy.) -If either the row offset `i` is beyond the _m_ dimension of `b`, or column offset `j` is beyond the _n_ dimension of `b`, the function does not perform any copy and returns `BLIS_FAILURE`. Similarly, if `b` is a global scalar constant such as `BLIS_ONE`, `BLIS_FAILURE` is returned. 
+If either the row offset `i` is beyond the _m_ dimension of `b` or less than zero, or column offset `j` is beyond the _n_ dimension of `b` or less than zero, the function returns `BLIS_FAILURE` without taking any action. Similarly, if `b` is a global scalar constant such as `BLIS_ONE`, the function returns `BLIS_FAILURE`.
+
+---
+
+#### eqsc
+```c
+void bli_eqsc
+     (
+       obj_t chi,
+       obj_t psi,
+       bool* is_eq
+     );
+```
+Perform an element-wise comparison between scalars `chi` and `psi` and store the boolean result in the `bool` pointed to by `is_eq`.
+If exactly one of `conj(chi)` or `conj(psi)` (but not both) indicates a conjugation, then one of the scalars will be implicitly conjugated for purposes of the comparison.
+
+Observed object properties: `conj?(chi)`, `conj?(psi)`.
+
+---
+
+#### eqv
+```c
+void bli_eqv
+     (
+       obj_t x,
+       obj_t y,
+       bool* is_eq
+     );
+```
+Perform an element-wise comparison between vectors `x` and `y` and store the boolean result in the `bool` pointed to by `is_eq`.
+If exactly one of `conj(x)` or `conj(y)` (but not both) indicates a conjugation, then one of the vectors will be implicitly conjugated for purposes of the comparison.
+
+Observed object properties: `conj?(x)`, `conj?(y)`.
+
+---
+
+#### eqm
+```c
+void bli_eqm
+     (
+       obj_t a,
+       obj_t b,
+       bool* is_eq
+     );
+```
+Perform an element-wise comparison between matrices `A` and `B` and store the boolean result in the `bool` pointed to by `is_eq`.
+Here, `A` is stored as a dense matrix, or lower- or upper-triangular/trapezoidal matrix with arbitrary diagonal offset and unit or non-unit diagonal.
+If `diag(A)` indicates a unit diagonal, the diagonals of both matrices will be ignored for purposes of the comparison.
+If `uplo(A)` indicates lower or upper storage, only that part of both matrices `A` and `B` will be referenced.
+If exactly one of `trans(A)` or `trans(B)` (but not both) indicates a transposition, then one of the matrices will be transposed for purposes of the comparison.
+Similarly, if exactly one of `trans(A)` or `trans(B)` (but not both) indicates a conjugation, then one of the matrices will be implicitly conjugated for purposes of the comparison.
+
+Observed object properties: `diagoff(A)`, `diag(A)`, `uplo(A)`, `trans?(A)`, `trans?(B)`.
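As a quick, hypothetical illustration of the object-based element accessors documented above (not part of the diff itself), the sketch below assumes only the `bli_setijm()`/`bli_getijm()` signatures shown here plus the standard object-management routines `bli_obj_create()` and `bli_obj_free()`:

```c
#include <stdio.h>
#include "blis.h"

int main( void )
{
    obj_t  b;
    double ar = 0.0, ai = 0.0;

    // Create a 3x4 double-precision real matrix; rs = cs = 0 requests
    // BLIS's default storage scheme.
    bli_obj_create( BLIS_DOUBLE, 3, 4, 0, 0, &b );

    // Write 2.5 + 0.0i to element (1,2); ai is ignored for real matrices.
    if ( bli_setijm( 2.5, 0.0, 1, 2, &b ) == BLIS_FAILURE )
        printf( "setijm: offset out of range\n" );

    // Read the element back; for real matrices only ar is overwritten.
    if ( bli_getijm( 1, 2, &b, &ar, &ai ) == BLIS_FAILURE )
        printf( "getijm: offset out of range\n" );
    else
        printf( "b(1,2) = %f\n", ar );

    bli_obj_free( &b );
    return 0;
}
```

The error handling mirrors the `BLIS_FAILURE` convention described in the entries above; an application would typically check the returned `err_t` in the same way.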
diff --git a/docs/BLISTypedAPI.md b/docs/BLISTypedAPI.md index e495aa00a8..a29870169d 100644 --- a/docs/BLISTypedAPI.md +++ b/docs/BLISTypedAPI.md @@ -48,7 +48,7 @@ This index provides a quick way to jump directly to the description for each ope * **[Level-3](BLISTypedAPI.md#level-3-operations)**: Operations with matrices that are multiplication-like: * [gemm](BLISTypedAPI.md#gemm), [hemm](BLISTypedAPI.md#hemm), [herk](BLISTypedAPI.md#herk), [her2k](BLISTypedAPI.md#her2k), [symm](BLISTypedAPI.md#symm), [syrk](BLISTypedAPI.md#syrk), [syr2k](BLISTypedAPI.md#syr2k), [trmm](BLISTypedAPI.md#trmm), [trmm3](BLISTypedAPI.md#trmm3), [trsm](BLISTypedAPI.md#trsm) * **[Utility](BLISTypedAPI.md#Utility-operations)**: Miscellaneous operations on matrices and vectors: - * [asumv](BLISTypedAPI.md#asumv), [norm1v](BLISTypedAPI.md#norm1v), [normfv](BLISTypedAPI.md#normfv), [normiv](BLISTypedAPI.md#normiv), [norm1m](BLISTypedAPI.md#norm1m), [normfm](BLISTypedAPI.md#normfm), [normim](BLISTypedAPI.md#normim), [mkherm](BLISTypedAPI.md#mkherm), [mksymm](BLISTypedAPI.md#mksymm), [mktrim](BLISTypedAPI.md#mktrim), [fprintv](BLISTypedAPI.md#fprintv), [fprintm](BLISTypedAPI.md#fprintm),[printv](BLISTypedAPI.md#printv), [printm](BLISTypedAPI.md#printm), [randv](BLISTypedAPI.md#randv), [randm](BLISTypedAPI.md#randm), [sumsqv](BLISTypedAPI.md#sumsqv) + * [asumv](BLISTypedAPI.md#asumv), [norm1v](BLISTypedAPI.md#norm1v), [normfv](BLISTypedAPI.md#normfv), [normiv](BLISTypedAPI.md#normiv), [norm1m](BLISTypedAPI.md#norm1m), [normfm](BLISTypedAPI.md#normfm), [normim](BLISTypedAPI.md#normim), [mkherm](BLISTypedAPI.md#mkherm), [mksymm](BLISTypedAPI.md#mksymm), [mktrim](BLISTypedAPI.md#mktrim), [fprintv](BLISTypedAPI.md#fprintv), [fprintm](BLISTypedAPI.md#fprintm),[printv](BLISTypedAPI.md#printv), [printm](BLISTypedAPI.md#printm), [randv](BLISTypedAPI.md#randv), [randm](BLISTypedAPI.md#randm), [sumsqv](BLISTypedAPI.md#sumsqv), [getsc](BLISTypedAPI.md#getsc), [getijv](BLISTypedAPI.md#getijv), [getijm](BLISTypedAPI.md#getijm), [setsc](BLISTypedAPI.md#setsc), [setijv](BLISTypedAPI.md#setijv), [setijm](BLISTypedAPI.md#setijm), [eqsc](BLISTypedAPI.md#eqsc), [eqv](BLISTypedAPI.md#eqv), [eqm](BLISTypedAPI.md#eqm) @@ -1695,6 +1695,149 @@ where, on entry, `scale` and `sumsq` contain `scale_old` and `sumsq_old`, respec --- +#### getsc +```c +void bli_getsc + ( + ctype* chi, + double* zeta_r, + double* zeta_i + ) +``` +Copy the real and imaginary values from the scalar object `chi` to `zeta_r` and `zeta_i`. If `chi` is stored as a real type, then `zeta_i` is set to zero. (If `chi` is stored in single precision, the corresponding elements are typecast/promoted during the copy.) + +--- + +#### getijv +```c +err_t bli_?getijv + ( + dim_t i, + ctype* x, incx, + double* ar, + double* ai + ) +``` +Copy the real and imaginary values at the `i`th element of vector `x` to `ar` and `ai`. For real domain invocations, only `ar` is overwritten and `ai` is left unchanged. (If `x` contains elements stored in single precision, the corresponding elements are typecast/promoted during the copy.) +Note that the object-based analogue of [getijv](BLISObjectAPI.md#getijv) does bounds checking of the vector element offset `i` against the vector length while the typed functions specified above do not (since the vector length is not given). 
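For illustration, a minimal sketch of the double-precision instance implied by the interface above (taking the `?` to be `d`, i.e. `bli_dgetijv`); the array contents are arbitrary and the snippet is not part of the patch:

```c
#include <stdio.h>
#include "blis.h"

int main( void )
{
    // A plain C array serves as the vector; incx = 1 means contiguous storage.
    double x[4] = { 1.0, 2.0, 3.0, 4.0 };
    double ar = 0.0, ai = 0.0;

    // Read element i = 2 of x; for real vectors only ar is overwritten.
    bli_dgetijv( 2, x, 1, &ar, &ai );

    printf( "x[2] = %f\n", ar );
    return 0;
}
```

Note that, as stated above, the typed interface performs no bounds checking, so the caller is responsible for keeping `i` within the vector length.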
+
+---
+
+#### getijm
+```c
+err_t bli_?getijm
+     (
+       dim_t i,
+       dim_t j,
+       ctype* b, inc_t rs_b, inc_t cs_b,
+       double* ar,
+       double* ai
+     )
+```
+Copy the real and imaginary values at the (`i`,`j`) element of object `b` to `ar` and `ai`. For real domain invocations, only `ar` is overwritten and `ai` is left unchanged. (If `b` contains elements stored in single precision, the corresponding elements are typecast/promoted during the copy.)
+Note that the object-based analogue of [getijm](BLISObjectAPI.md#getijm) does bounds checking of the matrix element offsets (`i`,`j`) against the matrix dimensions while the typed functions specified above do not (since the matrix dimensions are not given).
+
+---
+
+#### setsc
+```c
+void bli_setsc
+     (
+       double* zeta_r,
+       double* zeta_i,
+       ctype* chi
+     );
+```
+Copy real and imaginary values `zeta_r` and `zeta_i` to the scalar object `chi`. If `chi` is stored as a real type, then `zeta_i` is ignored. (If `chi` is stored in single precision, the contents are typecast/demoted during the copy.)
+
+---
+
+#### setijv
+```c
+err_t bli_?setijv
+     (
+       double ar,
+       double ai,
+       dim_t i,
+       ctype* x, incx
+     );
+```
+Copy real and imaginary values `ar` and `ai` to the `i`th element of vector object `x`. For real domain invocations, only `ar` is copied and `ai` is ignored. (If `x` contains elements stored in single precision, the corresponding elements are typecast/demoted during the copy.)
+Note that the object-based analogue of [setijv](BLISObjectAPI.md#setijv) does bounds checking of the vector element offset `i` against the vector length while the typed functions specified above do not (since the vector length is not given).
+
+---
+
+#### setijm
+```c
+err_t bli_?setijm
+     (
+       double ar,
+       double ai,
+       dim_t i,
+       dim_t j,
+       ctype* b, inc_t rs_b, inc_t cs_b
+     );
+```
+Copy real and imaginary values `ar` and `ai` to the (`i`,`j`) element of object `b`. For real domain invocations, only `ar` is copied and `ai` is ignored. (If `b` contains elements stored in single precision, the corresponding elements are typecast/demoted during the copy.)
+Note that the object-based analogue of [setijm](BLISObjectAPI.md#setijm) does bounds checking of the matrix element offsets (`i`,`j`) against the matrix dimensions while the typed functions specified above do not (since the matrix dimensions are not given).
+
+---
+
+#### eqsc
+```c
+void bli_?eqsc
+     (
+       conj_t conjchi,
+       ctype* chi,
+       ctype* psi,
+       bool* is_eq
+     );
+```
+Perform an element-wise comparison between scalars `chi` and `psi` and store the boolean result in the `bool` pointed to by `is_eq`.
+If `conjchi` indicates a conjugation, `chi` will be implicitly conjugated for purposes of the comparison.
+
+---
+
+#### eqv
+```c
+void bli_?eqv
+     (
+       conj_t conjx,
+       dim_t n,
+       ctype* x, inc_t incx,
+       ctype* y, inc_t incy,
+       bool* is_eq
+     );
+```
+Perform an element-wise comparison between length _n_ vectors `x` and `y` and store the boolean result in the `bool` pointed to by `is_eq`.
+If `conjx` indicates a conjugation, `x` will be implicitly conjugated for purposes of the comparison.
+
+---
+
+#### eqm
+```c
+void bli_?eqm
+     (
+       doff_t diagoffa,
+       diag_t diaga,
+       uplo_t uploa,
+       trans_t transa,
+       dim_t m,
+       dim_t n,
+       ctype* a, inc_t rs_a, inc_t cs_a,
+       ctype* b, inc_t rs_b, inc_t cs_b,
+       bool* is_eq
+     )
+```
+Perform an element-wise comparison between matrices `A` and `B` and store the boolean result in the `bool` pointed to by `is_eq`.
+Here, `B` is an _m x n_ matrix, while `A` is stored as a dense matrix, or lower- or upper-triangular/trapezoidal matrix with arbitrary diagonal offset and unit or non-unit diagonal.
+If `diaga` indicates a unit diagonal, the diagonals of both matrices will be ignored for purposes of the comparison.
+If `uploa` indicates lower or upper storage, only that part of matrix `A` will be referenced in the comparison.
+If `transa` indicates a conjugation and/or transposition, then `A` will be conjugated and/or transposed for purposes of the comparison.
+
+
+
+
 ## Level-3 microkernels
diff --git a/docs/CMakeBuildSystem.md b/docs/CMakeBuildSystem.md
new file mode 100644
index 0000000000..92b85cf432
--- /dev/null
+++ b/docs/CMakeBuildSystem.md
@@ -0,0 +1,225 @@
+## Contents
+
+* **[Contents](CMakeBuildSystem.md#contents)**
+* **[Introduction](CMakeBuildSystem.md#introduction)**
+* **[Step 1: Choose a framework configuration](CMakeBuildSystem.md#step-1-choose-a-framework-configuration)**
+* **[Step 2: Configuring CMake](CMakeBuildSystem.md#step-2-configuring-cmake)**
+* **[Step 3: Compilation](CMakeBuildSystem.md#step-3-compilation)**
+* **[Step 4: Installation](CMakeBuildSystem.md#step-4-installation)**
+* **[Compiling with BLIS](CMakeBuildSystem.md#compiling-with-blis)**
+* **[Uninstalling](CMakeBuildSystem.md#uninstalling)**
+* **[Available targets](CMakeBuildSystem.md#available-targets)**
+* **[Adding configurations](CMakeBuildSystem.md#adding-configurations)**
+* **[Some examples](CMakeBuildSystem.md#some-examples)**
+* **[Final notes](CMakeBuildSystem.md#final-notes)**
+
+## Introduction
+
+This document describes how to use CMake to build and install a BLIS library to your local system.
+
+The BLIS CMake system is based on the [Make build system](BuildSystem.md) and is designed for use with both Linux and Windows. Other requirements are:
+
+ * CMake (3.15.0 or higher)
+ * Python (3.4 or later for python3)
+ * GNU `make` (3.81 or later) on Linux
+ * Visual Studio 17 2022 on Windows
+ * a working C99 compiler (gcc or clang on Linux and **only** clang-cl on Windows)
+
+**_NOTE:_**
+To get clang-cl on Visual Studio, one needs to choose "C++ Clang tools for Windows" when installing "Desktop development with C++" with Visual Studio.
+
+Note that, on Windows, BLIS implements basic pthreads functionality automatically, so a POSIX threads library is not required. On Linux, the implementation is the same as the one used by the Make system.
+
+CMake is used to build out of source, so we need to start by creating a build directory from which we will do the configuration and build. Since there is a directory called blis/build, the build directory must have a different name. Here is an example of creating the directory:
+```
+$ mkdir build_blis
+$ cd build_blis
+```
+
+## Step 1: Choose a framework configuration
+
+The first step is to choose the appropriate BLIS configuration. As with the Make build system, the user must decide which configuration to use or whether automatic hardware detection should be used to determine the configuration. Currently only the following configurations are supported:
+
+ * amdzen
+ * zen
+ * zen2
+ * zen3
+ * zen4
+ * generic
+
+Instructions on how to add a configuration to the CMake system are provided in [Adding configurations](CMakeBuildSystem.md#adding-configurations).
+
+### Multithreading
+
+As in the Make system, multithreading in BLIS is disabled by default. To configure CMake so that OpenMP is used, please use `-DTHREADING_MODEL=openmp`.
All available options can be found if cmake-gui is used, or by running
+```
+cmake .. -DPRINT_CONFIGURE_HELP=ON
+```
+
+## Step 2: Configuring CMake
+
+### Choosing a generator
+
+This is a reminder on how to configure CMake to use a specific generator:
+```
+cmake -G <generator>
+```
+
+On Linux, "Unix Makefiles" is used by default and `-G <generator>` can be omitted.
+
+On Windows, specify the Visual Studio generator using
+```
+cmake -G "Visual Studio 17 2022"
+```
+
+For the rest of this documentation, we will use the platform-agnostic commands to build the libraries, but the usual make commands can be used instead. In the following command snippets we omit specifying the generator, but one can use their preferred way of building using common CMake practices.
+
+### Choosing a configuration
+
+This step is equivalent to running `./configure <configname>` using the Make system. In this case, simply run:
+```
+cmake .. -DBLIS_CONFIG_FAMILY=<configname>
+```
+If the provided configuration is not supported, an error will be thrown and a message with the available configurations will be printed.
+
+To configure based on your hardware, you can configure using
+```
+cmake .. -DBLIS_CONFIG_FAMILY=auto
+```
+Please note that when `auto` is used as a configuration option, the `generic` configuration will be chosen by default on non-AMD hardware.
+
+### Specifying a prefix path for installation
+
+We remind users that to specify the installation prefix in CMake, one needs to configure using the `CMAKE_INSTALL_PREFIX` variable:
+```
+cmake .. -DBLIS_CONFIG_FAMILY=auto -DCMAKE_INSTALL_PREFIX=<install_prefix>
+```
+This will cause libraries to eventually be installed to `<install_prefix>/lib` and headers will be installed to `<install_prefix>/include`.
+
+The option to specify the library install and the header install paths separately, as in the Make system, is not currently supported by the CMake equivalent.
+
+## Step 3: Compilation
+
+Once configuration is finished and the corresponding platform-dependent build files have been generated, you can proceed to building the library.
+To build the library in a platform-agnostic way, use:
+```
+cmake --build . --config Release
+```
+For a verbose build, you can use:
+```
+cmake --build . --verbose --config Release
+```
+To build in parallel on a multicore system, you can use:
+```
+cmake --build . --config Release -j<n>
+```
+where `<n>` is the number of jobs allowed to run simultaneously by this command.
+
+Note that on Linux, if Makefiles are used, the above is equivalent to running
+```
+make -j<n>
+```
+
+## Step 4: Installation
+
+The BLIS library resides in your chosen build directory, say `blis/build_blis`, and the generated header files are in `blis/build_blis/include/`. To install the library and the header files associated with it, you can use:
+```
+cmake --build . --target install
+```
+This will install the libraries and header files and create the corresponding symbolic links of the shared libraries in the path specified in `CMAKE_INSTALL_PREFIX`.
+
+Note that on Linux, if Makefiles are used, the above is equivalent to running
+```
+make install
+```
+
+## Uninstalling
+
+Please note that CMake does not provide functionality to uninstall targets.
+
+## Available targets
+
+The BLIS CMake system aims to be compatible with the current `make` system. For that reason, it implements the same targets for the generation of libraries and the tests. The table of available targets can be found below.
+
+| target | Description |
|:----------------|:---------------------------------------------------|
| `all` | Execute `libs` target.
+| `libs`          | Compile BLIS as a static and/or shared library (depending on CMake options). |
+| `test`          | Execute `checkblis` and `checkblas` targets. |
+| `check`         | Execute `checkblis-fast` and `checkblas` targets. |
+| `checkblis`     | Execute `testblis` and characterize the results to `stdout`. |
+| `checkblis-fast`| Execute `testblis-fast` and characterize the results to `stdout`. |
+| `checkblis-md`  | Execute `testblis-md` and characterize the results to `stdout`. |
+| `checkblis-salt`| Execute `testblis-salt` and characterize the results to `stdout`. |
+| `checkblas`     | Execute `testblas` and characterize the results to `stdout`. |
+| `testblis`      | Run the BLIS testsuite with default parameters (runs for 2-8 minutes). |
+| `testblis-fast` | Run the BLIS testsuite with "fast" parameters (runs for a few seconds). |
+| `testblis-md`   | Run the BLIS testsuite for `gemm` with full mixing of datatypes (runs for 10-30 seconds). |
+| `testblis-salt` | Run the BLIS testsuite while simulating application-level threading (runs for a few seconds). |
+| `testsuite`     | Same as `testblis`. |
+| `testblas`      | Run the BLAS test drivers with default parameters (runs for a few seconds). |
+| `checkbliscpp`  | Run the BLIS C++ tests (runs for a few seconds). |
+
+**_NOTE:_**
+Using those targets sets the environment appropriately, so copying the input files and/or the DLL (in the case of Windows builds) is not required.
+
+### Running the testsuites
+* On Linux, all targets can be built and run in the `build_blis` directory.
+* On Windows, when Visual Studio has been used as the generator, one can build and run the BLIS API tests from the `build_blis/testsuite` directory and the BLAS API tests from the `build_blis/blastest` directory. To build and run the BLIS C++ interface tests, execute the target `checkbliscpp` in the `build_blis/vendor/testcpp` directory. The targets `check` and `test` can be used in the `build_blis` directory.
+* On Windows, if Visual Studio is used to build the library and tests, note that only the high-level targets will appear. All targets are available to build from the command prompt.
+
+## Adding configurations
+
+The CMake system is designed to closely follow the BLIS Make system. Assuming that a user has followed the steps in [Configuration How To](ConfigurationHowTo.md), adding the new configuration to the CMake system requires the following steps:
+* Add a `make_defs.cmake` file which is equivalent to `make_defs.mk`. One can see `blis/config/zen/make_defs.cmake` and `blis/config/zen/make_defs.mk` for an example.
+* Update `blis/CMakeLists.txt` to remove the error for the particular new configuration and to add the option in `set_property()` so that it appears in cmake-gui.
+
+## Some examples
+
+In this section we provide some examples for users who are familiar with the Makefile-based build system and want to try the new CMake system.
+
+**_NOTE:_**
+The CMake system generates shared libraries by default. To build static libraries instead, specify the corresponding CMake variable shown below:
+```
+cmake .. -DBUILD_SHARED_LIBS=OFF -DBLIS_CONFIG_FAMILY=amdzen
+```
+The same generated header `blis.h` can be used in either case.
+
+For shared libraries on Windows, one can easily import the symbols by defining the macro `-DBLIS_EXPORT=__declspec(dllimport)` while building the application, but this is not necessary if static data symbols and objects are not used.
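+As an illustration only, a hypothetical clang-cl invocation for such an application might look like the following; the install prefix `C:\blis`, the import library name `blis.lib`, and the source file `my_app.c` are assumptions that depend on how the library was configured and installed:
+```
+clang-cl "/DBLIS_EXPORT=__declspec(dllimport)" /IC:\blis\include my_app.c C:\blis\lib\blis.lib
+```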
+ +### Example 1: multi-threaded LP64 libraries for amdzen configuration using clang compiler + +* With configure script: +``` +CC=clang ./configure --enable-threading=openmp --int-size=32 --blas-int-size=32 amdzen +``` + +* With CMake on Linux: +``` +cmake .. -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang -DENABLE_THREADING=openmp -DINT_SIZE=32 -DBLAS_INT_SIZE=32 -DBLIS_CONFIG_FAMILY=amdzen +``` + +* With CMake on Windows: +``` +cmake .. -G "Visual Studio 17 2022" -TClangCl -DENABLE_THREADING=openmp -DINT_SIZE=32 -DBLAS_INT_SIZE=32 -DBLIS_CONFIG_FAMILY=amdzen -DOpenMP_libomp_LIBRARY="path_to_openmp_library" +``` + +### Example 2: single-threaded ILP64 libraries for amdzen configuration with aocl_gemm addon enabled and default compiler + +**_NOTE:_** +Addon functionality is currently available only on Linux. + +* With configure script: +``` +./configure --enable-threading=no --int-size=64 --blas-int-size=64 --enable-addon=aocl_gemm amdzen +``` + +* With CMake on Linux: +``` +cmake .. -DENABLE_THREADING=no -DINT_SIZE=64 -DBLAS_INT_SIZE=64 -DENABLE_ADDON=aocl_gemm -DBLIS_CONFIG_FAMILY=amdzen +``` + +## Conclusion + +The BLIS CMake system is developed and maintained by AMD. You can contact us on the email-id toolchainsupport@amd.com. You can also raise any issue/suggestion on the git-hub repository at https://github.com/amd/blis/issues. \ No newline at end of file diff --git a/docs/FAQ.md b/docs/FAQ.md index 423009ae36..3ce8078c11 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -9,6 +9,7 @@ project, as well as those we think a new user or developer might ask. If you do * [Why should I use BLIS instead of GotoBLAS / OpenBLAS / ATLAS / MKL / ESSL / ACML / Accelerate?](FAQ.md#why-should-i-use-blis-instead-of-gotoblas--openblas--atlas--mkl--essl--acml--accelerate) * [How is BLIS related to FLAME / libflame?](FAQ.md#how-is-blis-related-to-flame--libflame) * [What is the difference between BLIS and the AMD fork of BLIS found in AOCL?](FAQ.md#what-is-the-difference-between-blis-and-the-amd-fork-of-blis-found-in-aocl) + * [Who do I contact if I have a question about the AMD version of BLIS?](FAQ.md#who-do-i-contact-if-i-have-a-question-about-the-amd-version-of-blis) * [Does BLIS automatically detect my hardware?](FAQ.md#does-blis-automatically-detect-my-hardware) * [I understand that BLIS is mostly a tool for developers?](FAQ.md#i-understand-that-blis-is-mostly-a-tool-for-developers) * [How do I link against BLIS?](FAQ.md#how-do-i-link-against-blis) @@ -17,6 +18,8 @@ project, as well as those we think a new user or developer might ask. If you do * [What is a macrokernel?](FAQ.md#what-is-a-macrokernel) * [What is a context?](FAQ.md#what-is-a-context) * [I am used to thinking in terms of column-major/row-major storage and leading dimensions. What is a "row stride" / "column stride"?](FAQ.md#im-used-to-thinking-in-terms-of-column-majorrow-major-storage-and-leading-dimensions-what-is-a-row-stride--column-stride) + * [I'm somewhat new to this matrix stuff. 
Can you remind me, what is the difference between a matrix row and a matrix column?](FAQ.md#im-somewhat-new-to-this-matrix-stuff-can-you-remind-me-what-is-the-difference-between-a-matrix-row-and-a-matrix-column) + * [Why does BLIS have vector (level-1v) and matrix (level-1m) variations of most level-1 operations?](FAQ.md#why-does-blis-have-vector-level-1v-and-matrix-level-1m-variations-of-most-level-1-operations) * [What does it mean when a matrix with general stride is column-tilted or row-tilted?](FAQ.md#what-does-it-mean-when-a-matrix-with-general-stride-is-column-tilted-or-row-tilted) * [I am not really interested in all of these newfangled features in BLIS. Can I just use BLIS as a BLAS library?](FAQ.md#im-not-really-interested-in-all-of-these-newfangled-features-in-blis-can-i-just-use-blis-as-a-blas-library) * [What about CBLAS?](FAQ.md#what-about-cblas) @@ -35,8 +38,7 @@ project, as well as those we think a new user or developer might ask. If you do * [Who funded the development of BLIS?](FAQ.md#who-funded-the-development-of-blis) * [I found a bug. How do I report it?](FAQ.md#i-found-a-bug-how-do-i-report-it) * [How do I request a new feature?](FAQ.md#how-do-i-request-a-new-feature) - * [What is the difference between this version of BLIS and the one that AMD maintains?](FAQ.md#what-is-the-difference-between-this-version-of-blis-and-the-one-that-amd-maintains) - * [Who do I contact if I have a question about the AMD version of BLIS?](FAQ.md#who-do-i-contact-if-i-have-a-question-about-the-amd-version-of-blis) + * [I'm a developer and I'd like to study the way matrix multiplication is implemented in BLIS. Where should I start?](FAQ.md#im-a-developer-and-id-like-to-study-the-way-matrix-multiplication-is-implemented-in-blis-where-should-i-start) * [Where did you get the photo for the BLIS logo / mascot?](FAQ.md#where-did-you-get-the-photo-for-the-blis-logo--mascot) ### Why did you create BLIS? @@ -59,7 +61,9 @@ homepage](https://github.com/flame/blis#key-features). But here are a few reason ### How is BLIS related to FLAME / `libflame`? -As explained [above](FAQ.md#why-did-you-create-blis?), BLIS was initially a layer within `libflame` that allowed more convenient interfacing to the BLAS. So in some ways, BLIS is a spin-off project. Prior to developing BLIS, [its author](http://www.cs.utexas.edu/users/field/) worked as the primary maintainer of `libflame`. If you look closely, you can also see that the design of BLIS was influenced by some of the more useful and innovative aspects of `libflame`, such as internal object abstractions and control trees. Also, various members of the [SHPC research group](http://shpc.ices.utexas.edu/people.html) and its [collaborators](http://shpc.ices.utexas.edu/collaborators.html) routinely provide insight, feedback, and also contribute code (especially kernels) to the BLIS project. +As explained [above](FAQ.md#why-did-you-create-blis?), BLIS was initially a layer within `libflame` that allowed more convenient interfacing to the BLAS. So in some ways, BLIS is a spin-off project. Prior to developing BLIS, [its primary author](http://www.cs.utexas.edu/users/field/) worked as the primary maintainer of `libflame`. If you look closely, you can also see that the design of BLIS was influenced by some of the more useful and innovative aspects of `libflame`, such as internal object abstractions and control trees. 
+ +Note that various members of the [SHPC research group](http://shpc.ices.utexas.edu/people.html) and its [collaborators](http://shpc.ices.utexas.edu/collaborators.html) routinely provide insight, feedback, and also contribute code (especially kernels) to the BLIS project. ### What is the difference between BLIS and the AMD fork of BLIS found in AOCL? @@ -67,6 +71,10 @@ BLIS, also known as "vanilla BLIS" or "upstream BLIS," is maintained by its [ori AMD BLIS sometimes contains certain optimizations specific to AMD hardware. Many of these optimizations are (eventually) merged back into upstream BLIS. However, for various reasons, some changes may remain unique to AMD BLIS for quite some time. Thus, if you want the latest optimizations for AMD hardware, feel free to try AMD BLIS. However, please note that neither The University of Texas at Austin nor BLIS's developers can endorse or offer direct support for any outside fork of BLIS, including AMD BLIS. +### Who do I contact if I have a question about the AMD version of BLIS? + +For questions or support regarding [AMD's fork of BLIS](https://github.com/amd/blis), please contact the [AMD Optimizing CPU Libraries](https://developer.amd.com/amd-aocl/) group at aoclsupport@amd.com. + ### Does BLIS automatically detect my hardware? On certain architectures (most notably x86_64), yes. In order to use auto-detection, you must specify `auto` as your configuration when running `configure` (Please see the BLIS [Build System](BuildSystem.md) guide for more info.) A runtime detection option is also available. (Please see the [Configuration Guide](ConfigurationHowTo.md) for a comprehensive walkthrough.) @@ -75,9 +83,9 @@ If automatic hardware detection is requested at configure-time and the build pro ### I understand that BLIS is mostly a tool for developers? -Yes. In order to achieve high performance, BLIS requires that hand-coded kernels and microkernels be written and referenced in a valid [BLIS configuration](ConfigurationHowTo.md). These components are usually written by developers and then included within BLIS for use by others. +It is certainly the case that BLIS began as a tool targeted at developers. In order to achieve high performance, BLIS requires that hand-coded kernels and microkernels be written and referenced in a valid [BLIS configuration](ConfigurationHowTo.md). These components are usually written by developers and then included within BLIS for use by others. -The good news, however, is that end-users can use BLIS too. Once the aforementioned kernels are integrated into BLIS, they can be used without any developer-level knowledge, and many kernels have already been added! Usually, `./configure auto; make; make install` is sufficient for the typical users with typical hardware. +The good news, however, is that BLIS has matured to the point where end-users can use it too! Once the aforementioned kernels are integrated into BLIS, they can be used without any developer-level knowledge, and many kernels have already been added! Usually, `./configure auto; make; make install` is sufficient for the typical users with typical hardware. ### How do I link against BLIS? @@ -97,9 +105,9 @@ For a more thorough explanation of the microkernel and its role in the overall l ### What is a macrokernel? -The macrokernels are portable codes within the BLIS framework that implement relatively small subproblems within an overall level-3 operation. 
The overall problem (say, general matrix-matrix multiplication, or `gemm`) is partitioned down, according to cache blocksizes, such that its operands are (1) a suitable size and (2) stored in a special packed format. At that time, the macrokernel is called. The macrokernel is implemented as two loops around the microkernel. +The macrokernels are portable codes within the BLIS framework that implement relatively small subproblems within an overall level-3 operation. The overall problem (say, general matrix-matrix multiplication, or `gemm`) is partitioned down, according to cache blocksizes, such that its `A` and `B` operands are (1) a suitable size and (2) stored in a special packed format. At that time, the macrokernel is called. The macrokernel is implemented as two loops around the microkernel. -The macrokernels in BLIS correspond to the so-called "inner kernels" (or simply "kernels") that formed the fundamental unit of computation in Kazushige Goto's GotoBLAS (and now in the successor library, OpenBLAS). +The macrokernels, along with the microkernel that they call, correspond to the so-called "inner kernels" (or simply "kernels") that formed the fundamental unit of computation in Kazushige Goto's GotoBLAS (and now in the successor library, OpenBLAS). For more information on macrokernels, please read our [ACM TOMS papers](https://github.com/flame/blis#citations). @@ -117,13 +125,33 @@ In generalized storage, we have a row stride and a column stride. The row stride BLIS also supports situations where both the row stride and column stride are non-unit. We call this situation "general stride". +### I'm somewhat new to this matrix stuff. Can you remind me, what is the difference between a matrix row and a matrix column? + +Of course! (BLIS's primary author remembers what it was like to get columns and rows confused.) + +Matrix columns consist of elements that are vertically aligned. Matrix rows consist of elements that are horizontally aligned. (One way to remember this distinction is that real-life columns are vertical structures that hold up buildings. A row of seats in a stadium, by contrast, is horizontal to the ground.) + +Furthermore, it is helpful to know that the number of rows in a matrix constitutes its so-called *m* dimension, and the number of columns constitutes its *n* dimension. + +Matrix dimension are always stated as *m x n*: the number of rows *by* the number of columns. + +So, a *3 x 4* matrix contains three rows (each of length four) and four columns (each of length three). + +### Why does BLIS have vector (level-1v) and matrix (level-1m) variations of most level-1 operations? + +At first glance, it might appear that an element-wise operation such as `copym` or `axpym` would be sufficiently general purpose to cover the cases where the operands are vectors. After all, an *m x 1* matrix can be viewed as a vector of length m and vice versa. But in BLIS, operations on vectors are treated slightly differently than operations on matrices. + +If an application wishes to perform an element-wise operation on two objects, and the application calls a level-1m operation, the dimensions of those objects must be conformal, or "match up" (after any transposition implied by the object properties). This includes situations where one of the dimensions is unit. + +However, if an application instead decides to perform an element-wise operation on two objects, and the application calls a level-1v operation, the dimension constraints are slightly relaxed. 
In this scenario, BLIS only checks that the vector *lengths* are equal. This allows for the vectors to have different orientations (row vs column) while still being considered conformal. So, you could perform a `copyv` operation to copy from an *m x 1* vector to a *1 x m* vector. A `copym` operation on such objects would not be allowed (unless it was executed with the source object containing an implicit transposition). + ### What does it mean when a matrix with general stride is column-tilted or row-tilted? When a matrix is stored with general stride, both the row stride and column stride (let's call them `rs` and `cs`) are non-unit. When `rs` < `cs`, we call the general stride matrix "column-tilted" because it is "closer" to being column-stored (than row-stored). Similarly, when `rs` > `cs`, the matrix is "row-tilted" because it is closer to being row-stored. ### I'm not really interested in all of these newfangled features in BLIS. Can I just use BLIS as a BLAS library? -Absolutely. Just link your application to BLIS the same way you would link to a BLAS library. For a simple linking example, see the [Linking to BLIS](KernelsHowTo.md#linking-to-blis) section of the BLIS [Build System](BuildSystem.md) guide. +Absolutely! Just link your application to BLIS the same way you would link to a BLAS library. For a simple linking example, see the [Linking to BLIS](KernelsHowTo.md#linking-to-blis) section of the BLIS [Build System](BuildSystem.md) guide. ### What about CBLAS? @@ -133,11 +161,13 @@ BLIS also contains an optional CBLAS compatibility layer, which leverages the BL In principle, BLIS's native (and BLAS-like) [typed API](BLISTypedAPI) can be called from Fortran. However, you must ensure that the size of the integer in BLIS is equal to the size of integer used by your Fortran program/compiler/environment. The size of BLIS integers is determined at configure-time. Please see `./configure --help` for the syntax for options related to integer sizes. +You may also want to confirm that your Fortran compiler doesn't perform any name-mangling of called functions or subroutines (such as with additional underscores beyond the single trailing underscore found in the BLAS APIs), and if so, take steps to disable this additional name-mangling. For example, if your source code calls `dgemm()` but your Fortran compiler name-mangles that call to `_dgemm_()` or `dgemm__()`, your program will fail to link against BLIS since BLIS only defines `dgemm_()`. + As for bindings to other languages, please contact the [blis-devel](http://groups.google.com/group/blis-devel) mailing list. ### Do I need to call initialization/finalization functions before being able to use BLIS from my application? -Originally, BLIS did indeed require the application to explicitly setup (initialize) various internal data structures via `bli_init()`. Likewise, calling `bli_finalize()` was recommended to cleanup (finalize) the library. However, since commit 9804adf (circa December 2017), BLIS has implemented self-initialization. These explicit calls to `bli_init()` and `bli_finalize()` are no longer necessary, though experts may still use them in special cases to control the allocation and freeing of resources. This topic is discussed in the BLIS [typed API reference](BLISTypedAPI.md#initialization-and-cleanup). +Originally, BLIS did indeed require the application to explicitly setup (initialize) various internal data structures via `bli_init()`. Likewise, calling `bli_finalize()` was recommended to cleanup (finalize) the library. 
However, since commit `9804adf` (circa December 2017), BLIS has implemented self-initialization. These explicit calls to `bli_init()` and `bli_finalize()` are no longer necessary, though experts may still use them in special cases to control the allocation and freeing of resources. This topic is discussed in the BLIS [typed API reference](BLISTypedAPI.md#initialization-and-cleanup). ### Does BLIS support multithreading? @@ -151,7 +181,7 @@ We have integrated some early foundational support for NUMA *development*, but c ### Does BLIS work with GPUs? -BLIS does not currently support graphical processing units (GPUs). However, others have applied the BLIS approach towards frameworks that provide BLAS-like functionality on GPUs. To see how NVIDIA's implementation compares to an analagous approach based on the principles that underlie BLIS, please see a paper by some of our collaborators, ["Implementing Strassen’s Algorithm with CUTLASSon NVIDIA Volta GPUs"](https://apps.cs.utexas.edu/apps/sites/default/files/tech_reports/GPUStrassen.pdf). +BLIS does not currently support graphical processing units (GPUs). However, others have applied the BLIS approach towards frameworks that provide BLAS-like functionality on GPUs. To see how NVIDIA's implementation compares to an analogous approach based on the principles that underlie BLIS, please see a paper by some of our collaborators, ["Implementing Strassen’s Algorithm with CUTLASS on NVIDIA Volta GPUs"](https://apps.cs.utexas.edu/apps/sites/default/files/tech_reports/GPUStrassen.pdf). ### Does BLIS work on _(some architecture)_? @@ -163,7 +193,7 @@ No. BLIS is a framework for sequential and shared-memory/multicore implementatio ### Can I build BLIS on Mac OS X? -BLIS was designed for use in a GNU/Linux environment. However, we've gone to greath lengths to keep BLIS compatible with other UNIX-like systems as well, such as BSD and OS X. System software requirements for UNIX-like systems are discussed in the BLIS [Build System](BuildSystem.md) guide. +BLIS was designed for use in a GNU/Linux environment. However, we've gone to great lengths to keep BLIS compatible with other UNIX-like systems as well, such as BSD and OS X. System software requirements for UNIX-like systems are discussed in the BLIS [Build System](BuildSystem.md) guide. ### Can I build BLIS on Windows? @@ -192,7 +222,7 @@ Yes. By default, most configurations output only a static library archive (e.g. ### Can I use the mixed domain / mixed precision support in BLIS? -Yes! As of 5fec95b (circa October 2018), BLIS supports mixed-datatype (mixed domain and/or mixed precision) computation via the `gemm` operation. Documentation on utilizing this new functionality is provided via the [MixedDatatype.md](docs/MixedDatatypes.md) document in the source distribution. +Yes! As of 5fec95b (circa October 2018), BLIS supports mixed-datatype (mixed domain and/or mixed precision) computation via the `gemm` operation. Documentation on utilizing this new functionality is provided via the [MixedDatatype.md](MixedDatatypes.md) document in the source distribution. If this feature is important or useful to your work, we would love to hear from you. Please contact us via the [blis-devel](http://groups.google.com/group/blis-devel) mailing list and tell us about your application and why you need/want support for BLAS-like operations with mixed-domain/mixed-precision operands. @@ -203,33 +233,27 @@ Lots of people! For a full list of those involved, see the ### Who funded the development of BLIS? 
-BLIS was primarily funded by grants from [Microsoft](https://www.microsoft.com/), -[Intel](https://www.intel.com/), [Texas -Instruments](https://www.ti.com/), [AMD](https://www.amd.com/), [Huawei](https://www.hauwei.com/us/), [Oracle](https://www.oracle.com/), and [Facebook](https://www.facebook.com/) as well as grants from the [National Science Foundation](http://www.nsf.gov/) (Awards CCF-0917167 ACI-1148125/1340293, and CCF-1320112). +BLIS was primarily funded by a variety of gifts/grants from industry and the National Science Foundation. Please see the "Funding" section of the [BLIS homepage](https://github.com/flame/blis#funding) for more details. Reminder: _Any opinions, findings and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation (NSF)._ ### I found a bug. How do I report it? -If you think you've found a bug, we request that you [open an issue](http://github.com/flame/blis/issues). Don't be shy! Really, it's the best and most convenient way for us to track your issues/bugs/concerns. Other discussions that are not primarily bug-reports should take place via the [blis-devel](http://groups.google.com/group/blis-devel) mailing list. +If you think you've found a bug, we request that you [open an issue](http://github.com/flame/blis/issues). Don't be shy! Really, it's the best and most convenient way for us to track your issues/bugs/concerns. ### How do I request a new feature? Feature requests should also be submitted by [opening a new issue](http://github.com/flame/blis/issues). -### What is the difference between this version of BLIS and the one that AMD maintains? - -AMD has chosen BLIS as the open-source foundation for the BLAS component of their [AMD Optimizing CPU Libraries (AOCL)](https://developer.amd.com/amd-aocl/) toolkit. Our group enjoys a great collaboration and partnership with AMD, and we are pleased to have their enthusiastic support for our project. +### I'm a developer and I'd like to study the way matrix multiplication is implemented in BLIS. Where should I start? -At a technical level, AMD's fork of BLIS is considered to be a downstream variant. AMD uses their fork to develop optimizations specific to AMD hardware. Occasionally, AMD will submit pull requests to merge their features, enhancements, and fixes back into our "plain vanilla" upstream repository. So our upstream BLIS will eventually contain most of the modifications originally developed by AMD in their fork, but with a lag. Similarly, features introduced into the upstream BLIS may not be immediately available in AMD's fork, but eventually their team will perform a merge and synchronize with our latest code. +Great question! The first thing you should know is that the core framework of [level-3 operations](BLISTypedAPI.md#operation-index) was *not* designed to be used to teach or explain a high-performance implementation of matrix multiplication. Rather, it was designed to encode the family of level-3 operations with as little code duplication as possible. Because of this, and also for historical/evolutionary reasons, it can be a little difficult to trace the execution of, say, `gemm` from within the core framework. -AMD also uses a different versioning system for AOCL which is independent of the versions used by the [upstream BLIS](http://github.com/flame/blis) project. 
+Thankfully, we have an alternative environment in which experts, application developers, and other curious individuals can study BLIS's matrix multiplication implementation. This so-called "sandbox" is a simplified collection of code that strips away much of the framework complexity while also maintaining local definitions for many of the interesting bits. You may find this `gemmlike` sandbox in `sandbox/gemmlike`. -### Who do I contact if I have a question about the AMD version of BLIS? - -For questions or support regarding [AMD's fork of BLIS](https://github.com/amd/blis), please contact the [AMD Optimizing CPU Libraries](https://developer.amd.com/amd-aocl/) group at aoclsupport@amd.com. +Sandboxes go beyond the scope of this FAQ. For an introduction, please refer to the [Sandboxes](Sandboxes.md) document, and/or contact the BLIS developers for more information. ### Where did you get the photo for the BLIS logo / mascot? -The sleeping ["BLIS cat"](https://github.com/flame/blis/blob/master/README.md) photo was taken by Petar Mitchev and is used with his permission. +The sleeping ["BLIS cat"](README.md) photo was taken by Petar Mitchev and is used with his permission. diff --git a/docs/Performance.md b/docs/Performance.md index 0a296c12a7..051be7aea9 100644 --- a/docs/Performance.md +++ b/docs/Performance.md @@ -24,6 +24,9 @@ * **[A64fx](Performance.md#a64fx)** * **[Experiment details](Performance.md#a64fx-experiment-details)** * **[Results](Performance.md#a64fx-results)** + * **[Neoverse N1](Performance.md#neoverse-n1)** + * **[Experiment details](Performance.md#neoverse-n1-experiment-details)** + * **[Results](Performance.md#neoverse-n1-results)** * **[Feedback](Performance.md#feedback)** # Introduction @@ -534,7 +537,7 @@ The `runthese.m` file will contain example invocations of the function. ### A64fx experiment details * Location: RIKEN Center of Computational Science in Kobe, Japan - * These test results were gathered on the Fugaku supercomputer under project "量子物質の創発と機能のための基礎科学 ―「富岳」と最先端実験の密連携による革新的強相関電子科学" (hp200132) + * These test results were gathered on the Fugaku supercomputer under project "量子物質の創発と機能のための基礎科学 ―「富岳」と最先端実験の密連携による革新的強相関電子科学" (hp200132) (Basic Science for Emergence and Functionality in Quantum Matter: Innovative Strongly-Correlated Electron Science by Integration of "Fugaku" and Frontier Experiments) * Processor model: Fujitsu A64fx * Core topology: one socket, 4 NUMA groups per socket, 13 cores per group (one reserved for the OS), 48 cores total * SMT status: Unknown @@ -546,23 +549,17 @@ The `runthese.m` file will contain example invocations of the function. 
* multicore: 70.4 GFLOPS/core (double-precision), 140.8 GFLOPS/core (single-precision) * Operating system: RHEL 8.3 * Page size: 256 bytes -* Compiler: gcc 9.3.0 -* Results gathered: 2 April 2021 +* Compiler: gcc 10.1.0 +* Results gathered: 2 April 2021; BLIS and SSL2 updated on 20 May 2021 * Implementations tested: - * BLIS 757cb1c (post-0.8.1) - * configured with `./configure -t openmp --sve-vector-size=vla CFLAGS="-D_A64FX -DPREFETCH256 -DSVE_NO_NAT_COMPLEX_KERNELS" arm64_sve` (single- and multithreaded) - * sub-configuration exercised: `arm64_sve` - * Single-threaded (1 core) execution requested via: - * `export BLIS_SVE_KC_D=2048 BLIS_SVE_MC_D=128 BLIS_SVE_NC_D=26880 BLIS_SVE_KERNEL_IDX_D=14` (double precision) - * `export BLIS_SVE_KC_S=2048 BLIS_SVE_MC_S=256 BLIS_SVE_NC_S=23040 BLIS_SVE_KERNEL_IDX_S=2` (single precision) - * Multithreaded (12 core) execution requested via: - * `export BLIS_JC_NT=1 BLIS_IC_NT=2 BLIS_JR_NT=6` - * `export BLIS_SVE_KC_D=2400 BLIS_SVE_MC_D=64 BLIS_SVE_NC_D=26880 BLIS_SVE_KERNEL_IDX_D=14` (double precision) - * `export BLIS_SVE_KC_S=2400 BLIS_SVE_MC_S=128 BLIS_SVE_NC_S=23040 BLIS_SVE_KERNEL_IDX_S=2` (single precision) - * Multithreaded (48 core) execution requested via: - * `export BLIS_JC_NT=1 BLIS_IC_NT=4 BLIS_JR_NT=12` - * `export BLIS_SVE_KC_D=2048 BLIS_SVE_MC_D=128 BLIS_SVE_NC_D=26880 BLIS_SVE_KERNEL_IDX_D=14` (double precision) - * `export BLIS_SVE_KC_S=2048 BLIS_SVE_MC_S=256 BLIS_SVE_NC_S=23040 BLIS_SVE_KERNEL_IDX_S=2` (single precision) + * BLIS 61584de (post-0.8.1) + * configured with: + * `../configure -t none CFLAGS="-DCACHE_SECTOR_SIZE_READONLY" a64fx` (single-threaded) + * `../configure -t openmp CFLAGS="-DCACHE_SECTOR_SIZE_READONLY" a64fx` (multithreaded) + * sub-configuration exercised: `a64fx` + * Single-threaded (1 core) execution requested via no change in environment variables + * Multithreaded (12 core) execution requested via `export BLIS_JC_NT=1 BLIS_IC_NT=1 BLIS_JR_NT=12` + * Multithreaded (48 core) execution requested via `export BLIS_JC_NT=1 BLIS_IC_NT=4 BLIS_JR_NT=12` * Eigen 3.3.9 * Obtained via the [Eigen GitLab homepage](https://gitlab.com/libeigen/eigen) * configured and built BLAS library via `mkdir build; cd build; cmake ..; make blas` @@ -593,7 +590,7 @@ The `runthese.m` file will contain example invocations of the function. #### pdf * [A64fx single-threaded](graphs/large/l3_perf_a64fx_nt1.pdf) -* [A64fx multithreaded (12 cores)](graphs/large/l3_perf_a64fx_jc1ic2jr6_nt12.pdf) +* [A64fx multithreaded (12 cores)](graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.pdf) * [A64fx multithreaded (48 cores)](graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.pdf) #### png (inline) @@ -601,12 +598,64 @@ The `runthese.m` file will contain example invocations of the function. 
* **A64fx single-threaded** ![single-threaded](graphs/large/l3_perf_a64fx_nt1.png) * **A64fx multithreaded (12 cores)** -![multithreaded (12 cores)](graphs/large/l3_perf_a64fx_jc1ic2jr6_nt12.png) +![multithreaded (12 cores)](graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.png) * **A64fx multithreaded (48 cores)** ![multithreaded (48 cores)](graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.png) --- +## Neoverse N1 + +### Neoverse N1 experiment details + +* Location: AWS cloud +* Processor model: Graviton2 Neoverse N1 +* Core topology: one socket, 64 cores per socket, 64 cores total +* SMT status: none +* Max clock rate: 2.5GHz (single-core and multicore) +* Max vector register length: 128 bits (NEON) +* Max FMA vector IPC: 2 +* Peak performance: + * single-core: 20.0 GFLOPS (double-precision), 40.0 GFLOPS (single-precision) + * multicore: 20.0 GFLOPS/core (double-precision), 40.0 GFLOPS/core (single-precision) +* Operating system: unknown +* Page size: unknown +* Compiler: gcc 10.3.0 +* Results gathered: 15 July 2021 +* Implementations tested: + * BLIS fab5c86d (0.8.1-67) + * configured with `./configure -t openmp thunderx2` (single- and multithreaded) + * sub-configuration exercised: `thunderx2` + * Single-threaded (1 core) execution requested via no change in environment variables + * Multithreaded (64 core) execution requested via `export BLIS_NUM_THREADS=64` + * OpenBLAS 0.3.17 + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=0` (single-threaded) + * configured `Makefile.rule` with `BINARY=64 NO_CBLAS=1 NO_LAPACK=1 NO_LAPACKE=1 USE_THREAD=1 NUM_THREADS=64` (multithreaded, 64 cores) + * Single-threaded (1 core) execution requested via `export OPENBLAS_NUM_THREADS=1` + * Multithreaded (64 core) execution requested via `export OPENBLAS_NUM_THREADS=64` +* Affinity: + * Thread affinity for BLIS was specified manually via `GOMP_CPU_AFFINITY="0-63"`. However, multithreaded OpenBLAS appears to revert to single-threaded execution if `GOMP_CPU_AFFINITY` is set. Therefore, when measuring OpenBLAS performance, the `GOMP_CPU_AFFINITY` environment variable was unset. +* Frequency throttling (via `cpupower`): + * No changes made. +* Comments: + * N/A + +### Neoverse N1 results + +#### pdf + +* [Neoverse N1 single-threaded](graphs/large/l3_perf_nn1_nt1.pdf) +* [Neoverse N1 multithreaded (64 cores)](graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.pdf) + +#### png (inline) + +* **Neoverse N1 single-threaded** +![single-threaded](graphs/large/l3_perf_nn1_nt1.png) +* **Neoverse N1 multithreaded (64 cores)** +![multithreaded (64 cores)](graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.png) + +--- + # Feedback Please let us know what you think of these performance results! Similarly, if you have any questions or concerns, or are interested in reproducing these performance experiments on your own hardware, we invite you to [open an issue](https://github.com/flame/blis/issues) and start a conversation with BLIS developers. diff --git a/docs/Sandboxes.md b/docs/Sandboxes.md index ce1548f6e0..8f404d0a6b 100644 --- a/docs/Sandboxes.md +++ b/docs/Sandboxes.md @@ -37,11 +37,11 @@ utility functions. To enable a sandbox at configure-time, you simply specify it as an option to `configure`. 
Either of the following usages are accepted: ``` -$ ./configure --enable-sandbox=ref99 auto -$ ./configure -s ref99 auto +$ ./configure --enable-sandbox=gemmlike auto +$ ./configure -s gemmlike auto ``` -Here, we tell `configure` that we want to use the `ref99` sandbox, which -corresponds to a sub-directory of `sandbox` named `ref99`. (Reminder: the +Here, we tell `configure` that we want to use the `gemmlike` sandbox, which +corresponds to a sub-directory of `sandbox` named `gemmlike`. (Reminder: the `auto` argument is the configuration target and thus unrelated to sandboxes.) @@ -50,7 +50,7 @@ sizes and shapes, you'll need to disable the skinny/unpacked "sup" sub-framework within BLIS, which is enabled by default. This can be done by passing the `--disable-sup-handling` option to configure: ``` -$ ./configure --enable-sandbox=ref99 --disable-sup-handling auto +$ ./configure --enable-sandbox=gemmlike --disable-sup-handling auto ``` If you leave sup enabled, the sup implementation will, at runtime, detect and handle certain smaller problem sizes upstream of where BLIS calls @@ -62,13 +62,14 @@ As `configure` runs, you should get output that includes lines similar to: ``` configure: configuring for alternate gemm implementation: -configure: sandbox/ref99 +configure: sandbox/gemmlike ``` And when you build BLIS, the last files to be compiled will be the source code in the specified sandbox: ``` -Compiling obj/haswell/sandbox/ref99/blx_gemm_ref_var2.o ('haswell' CFLAGS for sandboxes) -Compiling obj/haswell/sandbox/ref99/oapi/bli_gemmnat.o ('haswell' CFLAGS for sandboxes) +Compiling obj/haswell/sandbox/gemmlike/bli_gemmnat.o ('haswell' CFLAGS for sandboxes) +Compiling obj/haswell/sandbox/gemmlike/bls_gemm.o ('haswell' CFLAGS for sandboxes) +Compiling obj/haswell/sandbox/gemmlike/bls_gemm_bp_var1.o ('haswell' CFLAGS for sandboxes) ... ``` That's it! After the BLIS library is built, it will contain your chosen @@ -92,16 +93,19 @@ will be found! 2. Your sandbox must be written in C99 or C++11. If you write your sandbox in C++11, you must use one of the BLIS-approved file extensions for your source files (`.cc`, `.cpp`, `.cxx`) and your header files (`.hh`, `.hpp`, `.hxx`). -Note that `blis.h` -already contains all of its definitions inside of an `extern "C"` block, so -you should be able to `#include "blis.h"` from your C++11 source code without -any issues. +Note that `blis.h` already contains all of its definitions inside of an +`extern "C"` block, so you should be able to `#include "blis.h"` from your +C++11 source code without any issues. 3. All of your code to replace BLIS's default implementation of `bli_gemmnat()` should reside in the named sandbox directory, or some directory therein. -(Obviously.) For example, the "reference" sandbox is located in -`sandbox/ref99`. All of the code associated with this sandbox will be -contained within `sandbox/ref99`. +(Obviously.) For example, the "gemmlike" sandbox is located in +`sandbox/gemmlike`. All of the code associated with this sandbox will be +contained within `sandbox/gemmlike`. Note that you absolutely *may* include +additional code and interfaces within the sandbox, if you wish -- code and +interfaces that are not directly or indirectly needed for satisfying the +the "contract" set forth by the sandbox (i.e., including a local definition +of`bli_gemmnat()`). 4. The *only* header file that is required of your sandbox is `bli_sandbox.h`. 
It must be named `bli_sandbox.h` because `blis.h` will `#include` this file @@ -116,16 +120,17 @@ you should only place things (e.g. prototypes or type definitions) in Usually, neither of these situations will require any of your local definitions since those local definitions are only needed to define your sandbox implementation of `bli_gemmnat()`, and this function is already prototyped by -BLIS. +BLIS. *But if you are adding additional APIs and/or operations to the sandbox +that are unrelated to `bli_gemmnat()`, then you'll want to `#include` those +function prototypes from within `bli_sandbox.h`* 5. Your definition of `bli_gemmnat()` should be the **only function you define** in your sandbox that begins with `bli_`. If you define other functions that begin with `bli_`, you risk a namespace collision with existing framework functions. To guarantee safety, please prefix your locally-defined sandbox -functions with another prefix. Here, in the `ref99` sandbox, we use the prefix -`blx_`. (The `x` is for sandbox. Or experimental.) Also, please avoid the -prefix `bla_` since that prefix is also used in BLIS for BLAS compatibility -functions. +functions with another prefix. Here, in the `gemmlike` sandbox, we use the prefix +`bls_`. (The `s` is for sandbox.) Also, please avoid the prefix `bla_` since that +prefix is also used in BLIS for BLAS compatibility functions. If you follow these rules, you will be much more likely to have a pleasant experience integrating your BLIS sandbox into the larger framework. @@ -207,15 +212,9 @@ enabled in `input.general`. However, if those options *are* enabled and BLIS was built with mixed datatype support, then BLIS assumes that the implementation of `gemm` will support mixing of datatypes. BLIS *must* assume this, because there's no way for it to confirm at runtime that an implementation was written -to support mixing datatypes. Note that even the `ref99` sandbox included with +to support mixing datatypes. Note that even the `gemmlike` sandbox included with BLIS does not support mixed-datatype computation. -* **Multithreading in ref99.** The current reference sandbox, `ref99`, does not -currently implement multithreading. - -* **Packing matrices in ref99.** The current reference sandbox, `ref99`, does not -currently implement packing of matrices A or B. 
- ## Conclusion If you encounter any problems, or are really bummed-out that `gemm` is the diff --git a/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.pdf b/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.pdf new file mode 100644 index 0000000000..e273d1d098 Binary files /dev/null and b/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.pdf differ diff --git a/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.png b/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.png new file mode 100644 index 0000000000..1316647d65 Binary files /dev/null and b/docs/graphs/large/l3_perf_a64fx_jc1ic1jr12_nt12.png differ diff --git a/docs/graphs/large/l3_perf_a64fx_jc1ic2jr6_nt12.pdf b/docs/graphs/large/l3_perf_a64fx_jc1ic2jr6_nt12.pdf deleted file mode 100644 index 6802a39008..0000000000 Binary files a/docs/graphs/large/l3_perf_a64fx_jc1ic2jr6_nt12.pdf and /dev/null differ diff --git a/docs/graphs/large/l3_perf_a64fx_jc1ic2jr6_nt12.png b/docs/graphs/large/l3_perf_a64fx_jc1ic2jr6_nt12.png deleted file mode 100644 index b55765a8f5..0000000000 Binary files a/docs/graphs/large/l3_perf_a64fx_jc1ic2jr6_nt12.png and /dev/null differ diff --git a/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.pdf b/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.pdf index 3249a9acf8..b311e0f5db 100644 Binary files a/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.pdf and b/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.pdf differ diff --git a/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.png b/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.png index 6841f3e623..c2719da87a 100644 Binary files a/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.png and b/docs/graphs/large/l3_perf_a64fx_jc1ic4jr12_nt48.png differ diff --git a/docs/graphs/large/l3_perf_a64fx_nt1.pdf b/docs/graphs/large/l3_perf_a64fx_nt1.pdf index bce34bdb2e..6f0b8c74fc 100644 Binary files a/docs/graphs/large/l3_perf_a64fx_nt1.pdf and b/docs/graphs/large/l3_perf_a64fx_nt1.pdf differ diff --git a/docs/graphs/large/l3_perf_a64fx_nt1.png b/docs/graphs/large/l3_perf_a64fx_nt1.png index 6d13b1c900..f2cb381786 100644 Binary files a/docs/graphs/large/l3_perf_a64fx_nt1.png and b/docs/graphs/large/l3_perf_a64fx_nt1.png differ diff --git a/docs/graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.pdf b/docs/graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.pdf new file mode 100644 index 0000000000..517aee9ed1 Binary files /dev/null and b/docs/graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.pdf differ diff --git a/docs/graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.png b/docs/graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.png new file mode 100644 index 0000000000..c77159dd5a Binary files /dev/null and b/docs/graphs/large/l3_perf_nn1_jc2ic8jr4_nt64.png differ diff --git a/docs/graphs/large/l3_perf_nn1_nt1.pdf b/docs/graphs/large/l3_perf_nn1_nt1.pdf new file mode 100644 index 0000000000..6c5ff9f063 Binary files /dev/null and b/docs/graphs/large/l3_perf_nn1_nt1.pdf differ diff --git a/docs/graphs/large/l3_perf_nn1_nt1.png b/docs/graphs/large/l3_perf_nn1_nt1.png new file mode 100644 index 0000000000..750ccf0997 Binary files /dev/null and b/docs/graphs/large/l3_perf_nn1_nt1.png differ diff --git a/docs/styling/footer.html b/docs/styling/footer.html index d68520e1e9..160e30530e 100644 --- a/docs/styling/footer.html +++ b/docs/styling/footer.html @@ -1,5 +1,5 @@ L8 + Fringe loops : In blocks of 6 --> L6 + In blocks of 4 --> L4 + In blocks of 2 --> L2 + + For non-unit strides : A single loop, to process element wise. 
+ NOTE : Any size, requiring the fringe case of 1 with unit stride falls to + the non-unit stride loop and executes it once for just the last element. + + With regards to exception value testing, every loop is tested separately. + The indices for setting exception values on the vectors are such that + every load associated with the loop has an exception value in it. Thus, + every arithmetic instruction associated with each load will be tested + for exception value handling. +*/ + +// Exception value testing(on vectors) for L8 +INSTANTIATE_TEST_SUITE_P( + bli_zaxpbyv_zen_int_evt_vec_L8, + zaxpbyvEVTTest, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(8)), // m, size of vector to enter L8 directly. + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(1), gtint_t(3), gtint_t(4), gtint_t(7)), // indices to set exception values on x + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, + dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, + dcomplex{NaN, -Inf}), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(5), gtint_t(6)), // indices to set exception values on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, + dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, + dcomplex{NaN, -Inf}), // exception values to set on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{0.9, 4.5}) // beta + ), + ::zaxpbyvEVTVecPrint()); + +// Exception value testing(on vectors) for L6 +INSTANTIATE_TEST_SUITE_P( + bli_zaxpbyv_zen_int_evt_vec_L6, + zaxpbyvEVTTest, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(6)), // m, size of vector to enter L8 directly. + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(1), gtint_t(3), gtint_t(4)), // indices to set exception values on x + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, + dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, + dcomplex{NaN, -Inf}), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(2), gtint_t(5)), // indices to set exception values on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, + dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, + dcomplex{NaN, -Inf}), // exception values to set on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{0.9, 4.5}) // beta + ), + ::zaxpbyvEVTVecPrint()); + +// Exception value testing(on vectors) for L4 +INSTANTIATE_TEST_SUITE_P( + bli_zaxpbyv_zen_int_evt_vec_L4, + zaxpbyvEVTTest, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(4)), // m, size of vector to enter L8 directly. 
+ ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(1), gtint_t(3)), // indices to set exception values on x + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, + dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, + dcomplex{NaN, -Inf}), // exception values to set on x + ::testing::Values(gtint_t(0), gtint_t(2)), // indices to set exception values on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, + dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, + dcomplex{NaN, -Inf}), // exception values to set on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{0.9, 4.5}) // beta + ), + ::zaxpbyvEVTVecPrint()); + +// Exception value testing(on vectors) for L2 +INSTANTIATE_TEST_SUITE_P( + bli_zaxpbyv_zen_int_evt_vec_L2, + zaxpbyvEVTTest, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(2)), // m, size of vector to enter L8 directly. + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(1)), // indices to set exception values on x + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, + dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, + dcomplex{NaN, -Inf}), // exception values to set on x + ::testing::Values(gtint_t(0)), // indices to set exception values on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, + dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}, + dcomplex{NaN, -Inf}), // exception values to set on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{0.9, 4.5}) // beta + ), + ::zaxpbyvEVTVecPrint()); + +// Exception value testing(on vectors) with non unit strides +INSTANTIATE_TEST_SUITE_P( + bli_zaxpbyv_zen_int_evt_vec_NUS, + zaxpbyvEVTTest, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(1), gtint_t(5)), // m, size of vector to enter NUS loop directly. + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(-4)), // stride size for y + ::testing::Values(gtint_t(0)), // indices to set exception values on x + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, + dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}), // exception values to set on x + ::testing::Values(gtint_t(0)), // indices to set exception values on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{NaN, 2.3}, + dcomplex{-Inf, 0.0}, dcomplex{Inf, NaN}), // exception values to set on y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{0.9, 4.5}) // beta + ), + ::zaxpbyvEVTVecPrint()); + +// Exception value testing(on alpha/beta) with unit stride +/* + NOTE : Here, every loop is tested for, with alpha and beta having exception values + Furthermore, the first element of x and second element of y are set to 0, which + includes testing that cover cases where NaN might be induced due to 0 * (Inf or -Inf). 
+*/ +INSTANTIATE_TEST_SUITE_P( + bli_zaxpbyv_zen_int_evt_alphabeta_US, + zaxpbyvEVTTest, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(8), gtint_t(6), gtint_t(4), gtint_t(2)), // m size of vector to enter L8, L6, L4 and L2 respectively. + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(gtint_t(0)), // indices to set exception values on x + ::testing::Values(dcomplex{0.0, 0.0}), // exception values to set on x + ::testing::Values(gtint_t(1)), // indices to set exception values on y + ::testing::Values(dcomplex{0.0, 0.0}), // exception values to set on y + ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, dcomplex{-Inf, NaN}), // alpha + ::testing::Values(dcomplex{-0.9, NaN}, dcomplex{0.0, -Inf}, dcomplex{NaN, Inf}) // beta + ), + ::zaxpbyvEVTVecPrint()); + +// Exception value testing(on alpha/beta) with non-unit stride +INSTANTIATE_TEST_SUITE_P( + bli_zaxpbyv_zen_int_evt_alphabeta_NUS, + zaxpbyvEVTTest, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(5)), // m, size of vector to enter NUS loop directly. + ::testing::Values(gtint_t(3)), // stride size for x + ::testing::Values(gtint_t(-4)), // stride size for y + ::testing::Values(gtint_t(0)), // indices to set exception values on x + ::testing::Values(dcomplex{0.0, 0.0}), // exception values to set on x + ::testing::Values(gtint_t(0)), // indices to set exception values on y + ::testing::Values(dcomplex{0.0, 0.0}), // exception values to set on y + ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, dcomplex{-Inf, NaN}), // alpha + ::testing::Values(dcomplex{-0.9, NaN}, dcomplex{0.0, -Inf}, dcomplex{NaN, Inf}) // beta + ), + ::zaxpbyvEVTVecPrint()); diff --git a/gtestsuite/testsuite/level1/axpbyv/zaxpbyv_generic.cpp b/gtestsuite/testsuite/level1/axpbyv/zaxpbyv_generic.cpp index 690b7d4784..b69a132796 100644 --- a/gtestsuite/testsuite/level1/axpbyv/zaxpbyv_generic.cpp +++ b/gtestsuite/testsuite/level1/axpbyv/zaxpbyv_generic.cpp @@ -9,14 +9,14 @@ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -35,16 +35,15 @@ #include #include "test_axpbyv.h" -class zaxpbyvGenericTest : +class zaxpbyvAccTest : public ::testing::TestWithParam> {}; + dcomplex>> {}; // Tests using random integers as vector elements. -TEST_P( zaxpbyvGenericTest, RandomData ) +TEST_P(zaxpbyvAccTest, RandomData) { using T = dcomplex; //---------------------------------------------------------- @@ -63,115 +62,140 @@ TEST_P( zaxpbyvGenericTest, RandomData ) T alpha = std::get<4>(GetParam()); // beta T beta = std::get<5>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<6>(GetParam()); // Set the threshold for the errors: - double thresh = 2*testinghelpers::getEpsilon(); + double thresh = 20 * testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_axpbyv(conj_x, n, incx, incy, alpha, beta, thresh, datatype); + test_axpbyv(conj_x, n, incx, incy, alpha, beta, thresh); } // Used to generate a test case with a sensible name. // Beware that we cannot use fp numbers (e.g., 2.3) in the names, // so we are only printing int(2.3). This should be enough for debugging purposes. // If this poses an issue, please reach out. -class zaxpbyvGenericTestPrint { +class zaxpbyvAccTestPrint +{ public: std::string operator()( - testing::TestParamInfo> str) const { - char conj = std::get<0>(str.param); - gtint_t n = std::get<1>(str.param); - gtint_t incx = std::get<2>(str.param); - gtint_t incy = std::get<3>(str.param); + testing::TestParamInfo> str) const + { + char conj = std::get<0>(str.param); + gtint_t n = std::get<1>(str.param); + gtint_t incx = std::get<2>(str.param); + gtint_t incy = std::get<3>(str.param); dcomplex alpha = std::get<4>(str.param); - dcomplex beta = std::get<5>(str.param); - char datatype = std::get<6>(str.param); + dcomplex beta = std::get<5>(str.param); #ifdef TEST_BLAS std::string str_name = "zaxpby_"; #elif TEST_CBLAS std::string str_name = "cblas_zaxpby"; -#else //#elif TEST_BLIS_TYPED +#else // #elif TEST_BLIS_TYPED std::string str_name = "bli_zaxpbyv"; #endif str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conj, 1); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); + std::string incx_str = (incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); str_name += "_" + incx_str; - std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); + std::string incy_str = (incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); str_name += "_" + incy_str; - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); + std::string alpha_str = (alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); + alpha_str = alpha_str + "pi" + ((alpha.imag > 0) ? 
std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); + std::string beta_str = (beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); + beta_str = beta_str + "pi" + ((beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); str_name = str_name + "_a" + alpha_str; str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + datatype; return str_name; } }; -// Black box testing for generic and main use of zaxpby. +/* + The code structure for bli_zaxpbyv_zen_int( ... ) is as follows : + For unit strides : + Main loop : In blocks of 8 --> L8 + Fringe loops : In blocks of 6 --> L6 + In blocks of 4 --> L4 + In blocks of 2 --> L2 + + For non-unit strides : A single loop, to process element wise. + NOTE : Any size, requiring the fringe case of 1 with unit stride falls to + the non-unit stride loop and executes it once for just the last element. +*/ + +// Accuracy testing of the main loop, single and multiple runs INSTANTIATE_TEST_SUITE_P( - Blackbox, - zaxpbyvGenericTest, - ::testing::Combine( - ::testing::Values('n' // n: use x, c: use conj(x) + bli_zaxpbyv_zen_int_acc_US_main, + zaxpbyvAccTest, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) #ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. + , + 'c' // this option is BLIS-api specific. #endif - ), - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(dcomplex{-3.0, 1.0}, dcomplex{1.0, 2.0}), // alpha - ::testing::Values(dcomplex{1.0, 2.0}), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ), + ::testing::Values(gtint_t(8), gtint_t(40)), // m + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{1.0, 2.0}) // beta ), - ::zaxpbyvGenericTestPrint() - ); + ::zaxpbyvAccTestPrint()); -// Test for non-unit increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. +// Accuracy testing of different combinations of fringe loops(L6, L4, L2, 1) INSTANTIATE_TEST_SUITE_P( - NonUnitPositiveIncrements, - zaxpbyvGenericTest, - ::testing::Combine( - ::testing::Values('n' + bli_zaxpbyv_zen_int_acc_US_fringe, + zaxpbyvAccTest, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) #ifdef TEST_BLIS_TYPED - , 'c' // this option is BLIS-api specific. + , + 'c' // this option is BLIS-api specific. #endif - ), // n: use x, c: use conj(x) - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
- ::testing::Values(gtint_t(2)), /*(gtint_t(-5), gtint_t(-17))*/// stride size for x - ::testing::Values(gtint_t(4)), /*(gtint_t(-12), gtint_t(-4))*/// stride size for y - ::testing::Values(dcomplex{4.0, 3.1}), // alpha - ::testing::Values(dcomplex{1.0, 2.0}), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ), + ::testing::Range(gtint_t(1), gtint_t(7), 1), // m + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{1.0, 2.0}) // beta ), - ::zaxpbyvGenericTestPrint() - ); + ::zaxpbyvAccTestPrint()); -#ifndef TEST_BLIS_TYPED -// Test for negative increments. -// Only test very few cases as sanity check. -// We can modify the values using implementantion details. +// Accuracy testing of 3*L8 + L6 + L4 + L2 + 1, a case of main + all fringe cases taken INSTANTIATE_TEST_SUITE_P( - NegativeIncrements, - zaxpbyvGenericTest, - ::testing::Combine( - ::testing::Values('n'), // n: use x, c: use conj(x) - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(11), gtint_t(-11)), // stride size for x - ::testing::Values(gtint_t(-3), gtint_t(4)), // stride size for y - ::testing::Values(dcomplex{4.0, 3.1}), // alpha - ::testing::Values(dcomplex{1.0, -2.0}), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + bli_zaxpbyv_zen_int_acc_US_combine, + zaxpbyvAccTest, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. +#endif + ), + ::testing::Values(gtint_t(30), gtint_t(34), gtint_t(36), gtint_t(37)), // m + ::testing::Values(gtint_t(1)), // stride size for x + ::testing::Values(gtint_t(1)), // stride size for y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{1.0, 2.0}) // beta ), - ::zaxpbyvGenericTestPrint() - ); + ::zaxpbyvAccTestPrint()); + +// Accuracy testing with non-unit strides +INSTANTIATE_TEST_SUITE_P( + bli_zaxpbyv_zen_int_acc_NUS, + zaxpbyvAccTest, + ::testing::Combine( + ::testing::Values('n' // n: use x, c: use conj(x) +#ifdef TEST_BLIS_TYPED + , + 'c' // this option is BLIS-api specific. #endif + ), + ::testing::Values(gtint_t(10), gtint_t(17)), // m + ::testing::Values(gtint_t(-3), gtint_t(4)), // stride size for x + ::testing::Values(gtint_t(6), gtint_t(-2)), // stride size for y + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{2.2, -3.3}), // alpha + ::testing::Values(dcomplex{0.0, 0.0}, dcomplex{1.0, 0.0}, dcomplex{1.0, 2.0}) // beta + ), + ::zaxpbyvAccTestPrint()); diff --git a/gtestsuite/testsuite/level1/axpyv/caxpyv_generic.cpp b/gtestsuite/testsuite/level1/axpyv/caxpyv_generic.cpp index 77cd26c285..ad4db3c95b 100644 --- a/gtestsuite/testsuite/level1/axpyv/caxpyv_generic.cpp +++ b/gtestsuite/testsuite/level1/axpyv/caxpyv_generic.cpp @@ -40,8 +40,7 @@ class caxpyvGenericTest : gtint_t, gtint_t, gtint_t, - scomplex, - char>> {}; + scomplex>> {}; // Tests using random integers as vector elements. 
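The unit-stride accuracy instantiations above pick their m values against the loop structure documented in the comment block (main loop of 8, fringe loops of 6, 4 and 2, with a leftover single element falling to the element-wise loop). The small sketch below is a greedy model of that structure and only reports which block sizes a given m would visit; the actual fringe selection inside bli_zaxpbyv_zen_int may differ, so treat the output as an illustration of why sizes such as 30, 34, 36 and 37 were chosen, not as a statement of the kernel's exact path.

// Sketch (assumption): greedy decomposition of m into the documented block sizes.
#include <cstdio>

static void report_blocks(int m)
{
    std::printf("m = %2d ->", m);
    const int blocks[] = {8, 6, 4, 2};
    for (int b : blocks)
        while (m >= b) { std::printf(" L%d", b); m -= b; }
    if (m == 1) std::printf(" +1 (element-wise loop)");
    std::printf("\n");
}

int main()
{
    // Sizes used by the unit-stride accuracy instantiations above.
    for (int m : {2, 4, 6, 8, 30, 34, 36, 37, 40})
        report_blocks(m);
    return 0;
}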
TEST_P( caxpyvGenericTest, RandomData ) { @@ -60,8 +59,6 @@ TEST_P( caxpyvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // alpha T alpha = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); // Set the threshold for the errors: double thresh = 2*testinghelpers::getEpsilon(); @@ -69,7 +66,7 @@ TEST_P( caxpyvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_axpyv(conj_x, n, incx, incy, alpha, thresh, datatype); + test_axpyv( conj_x, n, incx, incy, alpha, thresh ); } // Used to generate a test case with a sensible name. @@ -79,13 +76,12 @@ TEST_P( caxpyvGenericTest, RandomData ) class caxpyvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); scomplex alpha = std::get<4>(str.param); - char datatype = std::get<5>(str.param); #ifdef TEST_BLAS std::string str_name = "caxpy_"; #elif TEST_CBLAS @@ -102,7 +98,6 @@ class caxpyvGenericTestPrint { std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -120,8 +115,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}) // alpha ), ::caxpyvGenericTestPrint() ); @@ -141,8 +135,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(2)), // stride size for x ::testing::Values(gtint_t(3)), // stride size for y - ::testing::Values(scomplex{4.0, 3.1}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(scomplex{4.0, 3.1}) // alpha ), ::caxpyvGenericTestPrint() ); @@ -159,8 +152,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(-4)), // stride size for x ::testing::Values(gtint_t(-3)), // stride size for y - ::testing::Values(scomplex{4.0, 3.1}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(scomplex{4.0, 3.1}) // alpha ), ::caxpyvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/axpyv/daxpyv_generic.cpp b/gtestsuite/testsuite/level1/axpyv/daxpyv_generic.cpp index 792d582782..19d65ed5a3 100644 --- a/gtestsuite/testsuite/level1/axpyv/daxpyv_generic.cpp +++ b/gtestsuite/testsuite/level1/axpyv/daxpyv_generic.cpp @@ -40,8 +40,7 @@ class daxpyvGenericTest : gtint_t, gtint_t, gtint_t, - double, - char>> {}; + double>> {}; // Tests using random integers as vector elements. TEST_P( daxpyvGenericTest, RandomData ) { @@ -60,8 +59,6 @@ TEST_P( daxpyvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // alpha T alpha = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); // Set the threshold for the errors: double thresh = testinghelpers::getEpsilon(); @@ -69,7 +66,7 @@ TEST_P( daxpyvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_axpyv(conj_x, n, incx, incy, alpha, thresh, datatype); + test_axpyv( conj_x, n, incx, incy, alpha, thresh ); } // Used to generate a test case with a sensible name. @@ -79,13 +76,12 @@ TEST_P( daxpyvGenericTest, RandomData ) class daxpyvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); double alpha = std::get<4>(str.param); - char datatype = std::get<5>(str.param); #ifdef TEST_BLAS std::string str_name = "daxpy_"; #elif TEST_CBLAS @@ -101,7 +97,6 @@ class daxpyvGenericTestPrint { str_name += "_" + incy_str; std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -115,8 +110,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(double(2.0), double(-2.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(double(2.0), double(-2.0)) // alpha ), ::daxpyvGenericTestPrint() ); @@ -133,8 +127,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(double(2.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(double(2.0)) // alpha ), ::daxpyvGenericTestPrint() ); @@ -151,8 +144,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(2)), /*(gtint_t(-5), gtint_t(-17))*/// stride size for x ::testing::Values(gtint_t(3)), /*(gtint_t(-12), gtint_t(-4))*/// stride size for y - ::testing::Values(double(4.0)), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(double(4.0)) // beta ), ::daxpyvGenericTestPrint() ); @@ -169,8 +161,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(-4)), // stride size for x ::testing::Values(gtint_t(-3)), // stride size for y - ::testing::Values(4.0), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(4.0) // alpha ), ::daxpyvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/axpyv/saxpyv_generic.cpp b/gtestsuite/testsuite/level1/axpyv/saxpyv_generic.cpp index 67699e8337..10c1daefa2 100644 --- a/gtestsuite/testsuite/level1/axpyv/saxpyv_generic.cpp +++ b/gtestsuite/testsuite/level1/axpyv/saxpyv_generic.cpp @@ -40,8 +40,7 @@ class saxpyvGenericTest : gtint_t, gtint_t, gtint_t, - float, - char>> {}; + float>> {}; // Tests using random integers as vector elements. TEST_P( saxpyvGenericTest, RandomData ) { @@ -60,8 +59,6 @@ TEST_P( saxpyvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // alpha T alpha = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); // Set the threshold for the errors: double thresh = testinghelpers::getEpsilon(); @@ -69,7 +66,7 @@ TEST_P( saxpyvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_axpyv(conj_x, n, incx, incy, alpha, thresh, datatype); + test_axpyv( conj_x, n, incx, incy, alpha, thresh ); } // Used to generate a test case with a sensible name. @@ -79,13 +76,12 @@ TEST_P( saxpyvGenericTest, RandomData ) class saxpyvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); float alpha = std::get<4>(str.param); - char datatype = std::get<5>(str.param); #ifdef TEST_BLAS std::string str_name = "saxpy_"; #elif TEST_CBLAS @@ -101,7 +97,6 @@ class saxpyvGenericTestPrint { str_name += "_" + incy_str; std::string alpha_str = ( alpha > 0) ? 
std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -115,8 +110,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(float(2.0), float(-2.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(float(2.0), float(-2.0)) // alpha ), ::saxpyvGenericTestPrint() ); @@ -133,8 +127,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(float(2.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(float(2.0)) // alpha ), ::saxpyvGenericTestPrint() ); @@ -149,10 +142,9 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector - ::testing::Values(gtint_t(2), gtint_t(-2)), /*(gtint_t(-5), gtint_t(-17))*/// stride size for x - ::testing::Values(gtint_t(3), gtint_t(-3)), /*(gtint_t(-12), gtint_t(-4))*/// stride size for y - ::testing::Values(float(4.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(2)), // stride size for x + ::testing::Values(gtint_t(3)), // stride size for y + ::testing::Values(float(4.0)) // alpha ), ::saxpyvGenericTestPrint() ); @@ -169,8 +161,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(-4)), // stride size for x ::testing::Values(gtint_t(-3)), // stride size for y - ::testing::Values(4.0), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(4.0) // alpha ), ::saxpyvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/axpyv/test_axpyv.h b/gtestsuite/testsuite/level1/axpyv/test_axpyv.h index a2d6af583f..1cc375da00 100644 --- a/gtestsuite/testsuite/level1/axpyv/test_axpyv.h +++ b/gtestsuite/testsuite/level1/axpyv/test_axpyv.h @@ -44,12 +44,13 @@ template static void test_axpyv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, - T alpha, double thresh, char datatype ) { + T alpha, double thresh ) +{ //---------------------------------------------------------- // Initialize vectors with random numbers. //---------------------------------------------------------- - std::vector x = testinghelpers::get_random_vector(-10, 10, n, incx, datatype); - std::vector y = testinghelpers::get_random_vector(-10, 10, n, incy, datatype); + std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); + std::vector y = testinghelpers::get_random_vector( -10, 10, n, incy ); //---------------------------------------------------------- // Call reference implementation to get ref results. 
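The updated test_axpyv.h above (shown up to the reference call) follows the pattern shared by these helpers now that the datatype argument is gone: generate random x and y, compute a reference result, run the implementation under test, and compare within thresh. The stand-alone analogue below mirrors that flow for double-precision axpy using plain std::vector; random_vector, ref_axpyv_plain and the element-wise tolerance check are illustrative stand-ins, not the gtestsuite helpers (the suite's computediff<T> may use a different error measure).

// Sketch (assumption): generate / reference / run / compare, as in test_axpyv().
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <limits>
#include <vector>

static std::vector<double> random_vector(std::size_t n)
{
    std::vector<double> v(n);
    for (double& e : v) e = -10.0 + 20.0 * (std::rand() / (double)RAND_MAX);
    return v;
}

static void ref_axpyv_plain(double alpha, const std::vector<double>& x,
                            std::vector<double>& y)
{
    for (std::size_t i = 0; i < x.size(); ++i) y[i] = y[i] + alpha * x[i];
}

int main()
{
    const double thresh = 2.0 * std::numeric_limits<double>::epsilon();

    std::vector<double> x = random_vector(100), y = random_vector(100);
    std::vector<double> y_ref = y;                 // copy kept for the reference result

    ref_axpyv_plain(2.0, x, y_ref);                // reference implementation
    for (std::size_t i = 0; i < x.size(); ++i)     // stand-in for the BLIS call
        y[i] = 2.0 * x[i] + y[i];

    for (std::size_t i = 0; i < x.size(); ++i)     // compare within thresh
        assert(std::fabs(y[i] - y_ref[i]) <=
               thresh * std::max(1.0, std::fabs(y_ref[i])));
    return 0;
}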
diff --git a/gtestsuite/testsuite/level1/axpyv/zaxpyv_generic.cpp b/gtestsuite/testsuite/level1/axpyv/zaxpyv_generic.cpp index a8cf1a6983..64b98f1b04 100644 --- a/gtestsuite/testsuite/level1/axpyv/zaxpyv_generic.cpp +++ b/gtestsuite/testsuite/level1/axpyv/zaxpyv_generic.cpp @@ -40,8 +40,7 @@ class zaxpyvGenericTest : gtint_t, gtint_t, gtint_t, - dcomplex, - char>> {}; + dcomplex>> {}; // Tests using random integers as vector elements. TEST_P( zaxpyvGenericTest, RandomData ) { @@ -60,15 +59,13 @@ TEST_P( zaxpyvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // alpha T alpha = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); // Set the threshold for the errors: double thresh = 2*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_axpyv(conj_x, n, incx, incy, alpha, thresh, datatype); + test_axpyv( conj_x, n, incx, incy, alpha, thresh ); } // Used to generate a test case with a sensible name. @@ -78,13 +75,12 @@ TEST_P( zaxpyvGenericTest, RandomData ) class zaxpyvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); dcomplex alpha = std::get<4>(str.param); - char datatype = std::get<5>(str.param); #ifdef TEST_BLAS std::string str_name = "zaxpy_"; #elif TEST_CBLAS @@ -101,7 +97,6 @@ class zaxpyvGenericTestPrint { std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -119,8 +114,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(dcomplex{-3.0, 1.0}, dcomplex{1.0, 2.0}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(dcomplex{-3.0, 1.0}, dcomplex{1.0, 2.0}) // alpha ), ::zaxpyvGenericTestPrint() ); @@ -140,8 +134,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(2)), // stride size for x ::testing::Values(gtint_t(3)), // stride size for y - ::testing::Values(dcomplex{-1.0, 2.0}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(dcomplex{-1.0, 2.0}) // alpha ), ::zaxpyvGenericTestPrint() ); @@ -158,8 +151,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(-4)), // stride size for x ::testing::Values(gtint_t(-3)), // stride size for y - ::testing::Values(dcomplex{4.0, 3.1}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(dcomplex{4.0, 3.1}) // alpha ), ::zaxpyvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/copyv/ccopyv_generic.cpp b/gtestsuite/testsuite/level1/copyv/ccopyv_generic.cpp index 5186cdecb5..29f988005b 100644 --- a/gtestsuite/testsuite/level1/copyv/ccopyv_generic.cpp +++ b/gtestsuite/testsuite/level1/copyv/ccopyv_generic.cpp @@ -39,8 +39,7 @@ class ccopyvGenericTest : public ::testing::TestWithParam> {}; + gtint_t>> {}; // Tests using random integers as vector elements. TEST_P( ccopyvGenericTest, RandomData ) @@ -58,8 +57,6 @@ TEST_P( ccopyvGenericTest, RandomData ) gtint_t incx = std::get<2>(GetParam()); // stride size for y: gtint_t incy = std::get<3>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<4>(GetParam()); // Set the threshold for the errors: double thresh = testinghelpers::getEpsilon(); @@ -67,7 +64,7 @@ TEST_P( ccopyvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_copyv(conjx, n, incx, incy, thresh, datatype); + test_copyv( conjx, n, incx, incy, thresh ); } // Used to generate a test case with a sensible name. @@ -77,12 +74,11 @@ TEST_P( ccopyvGenericTest, RandomData ) class ccopyvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conjx = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); - char datatype = std::get<4>(str.param); #ifdef TEST_BLAS std::string str_name = "ccopy_"; #elif TEST_CBLAS @@ -96,7 +92,6 @@ class ccopyvGenericTestPrint { str_name += "_" + incx_str; std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); str_name += "_" + incy_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -113,8 +108,7 @@ INSTANTIATE_TEST_SUITE_P( ), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1)) // stride size for y ), ::ccopyvGenericTestPrint() ); @@ -133,8 +127,7 @@ INSTANTIATE_TEST_SUITE_P( ), // n: use x, c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x - ::testing::Values(gtint_t(3), gtint_t(33)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(3), gtint_t(33)) // stride size for y ), ::ccopyvGenericTestPrint() ); @@ -150,8 +143,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(-5), gtint_t(7)), // stride size for x - ::testing::Values(gtint_t(13), gtint_t(-9)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(13), gtint_t(-9)) // stride size for y ), ::ccopyvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/copyv/dcopyv_generic.cpp b/gtestsuite/testsuite/level1/copyv/dcopyv_generic.cpp index b97b992ba3..1c7824b8f4 100644 --- a/gtestsuite/testsuite/level1/copyv/dcopyv_generic.cpp +++ b/gtestsuite/testsuite/level1/copyv/dcopyv_generic.cpp @@ -39,8 +39,7 @@ class dcopyvGenericTest : public ::testing::TestWithParam> {}; + gtint_t>> {}; // Tests using random integers as vector elements. TEST_P( dcopyvGenericTest, RandomData ) @@ -58,8 +57,6 @@ TEST_P( dcopyvGenericTest, RandomData ) gtint_t incx = std::get<2>(GetParam()); // stride size for y: gtint_t incy = std::get<3>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<4>(GetParam()); // Set the threshold for the errors: double thresh = testinghelpers::getEpsilon(); @@ -67,7 +64,7 @@ TEST_P( dcopyvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_copyv(conjx, n, incx, incy, thresh, datatype); + test_copyv( conjx, n, incx, incy, thresh ); } // Used to generate a test case with a sensible name. @@ -77,12 +74,11 @@ TEST_P( dcopyvGenericTest, RandomData ) class dcopyvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conjx = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); - char datatype = std::get<4>(str.param); #ifdef TEST_BLAS std::string str_name = "dcopy_"; #elif TEST_CBLAS @@ -96,7 +92,6 @@ class dcopyvGenericTestPrint { str_name += "_" + incx_str; std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); str_name += "_" + incy_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -109,8 +104,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1)) // stride size for y ), ::dcopyvGenericTestPrint() ); @@ -126,8 +120,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('c'), // c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1)) // stride size for y ), ::dcopyvGenericTestPrint() ); @@ -143,8 +136,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // use x, not conj(x) (since it is real) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x - ::testing::Values(gtint_t(3), gtint_t(33)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(3), gtint_t(33)) // stride size for y ), ::dcopyvGenericTestPrint() ); @@ -160,8 +152,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(-5), gtint_t(7)), // stride size for x - ::testing::Values(gtint_t(13), gtint_t(-9)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(13), gtint_t(-9)) // stride size for y ), ::dcopyvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/copyv/scopyv_generic.cpp b/gtestsuite/testsuite/level1/copyv/scopyv_generic.cpp index 2035f92d60..e86d2f320f 100644 --- a/gtestsuite/testsuite/level1/copyv/scopyv_generic.cpp +++ b/gtestsuite/testsuite/level1/copyv/scopyv_generic.cpp @@ -39,8 +39,7 @@ class scopyvGenericTest : public ::testing::TestWithParam> {}; + gtint_t>> {}; // Tests using random integers as vector elements. TEST_P( scopyvGenericTest, RandomData ) @@ -58,8 +57,6 @@ TEST_P( scopyvGenericTest, RandomData ) gtint_t incx = std::get<2>(GetParam()); // stride size for y: gtint_t incy = std::get<3>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<4>(GetParam()); // Set the threshold for the errors: double thresh = testinghelpers::getEpsilon(); @@ -67,7 +64,7 @@ TEST_P( scopyvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_copyv(conjx, n, incx, incy, thresh, datatype); + test_copyv( conjx, n, incx, incy, thresh ); } // Used to generate a test case with a sensible name. @@ -77,12 +74,11 @@ TEST_P( scopyvGenericTest, RandomData ) class scopyvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conjx = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); - char datatype = std::get<4>(str.param); #ifdef TEST_BLAS std::string str_name = "scopy_"; #elif TEST_CBLAS @@ -96,7 +92,6 @@ class scopyvGenericTestPrint { str_name += "_" + incx_str; std::string incy_str = ( incy > 0) ? 
std::to_string(incy) : "m" + std::to_string(std::abs(incy)); str_name += "_" + incy_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -109,8 +104,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1)) // stride size for y ), ::scopyvGenericTestPrint() ); @@ -126,8 +120,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('c'), // c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1)) // stride size for y ), ::scopyvGenericTestPrint() ); @@ -143,8 +136,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x - ::testing::Values(gtint_t(3), gtint_t(33)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(3), gtint_t(33)) // stride size for y ), ::scopyvGenericTestPrint() ); @@ -160,8 +152,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(-5), gtint_t(7)), // stride size for x - ::testing::Values(gtint_t(13), gtint_t(-9)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(13), gtint_t(-9)) // stride size for y ), ::scopyvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/copyv/test_copyv.h b/gtestsuite/testsuite/level1/copyv/test_copyv.h index 95f27925e2..6ab5a12bca 100644 --- a/gtestsuite/testsuite/level1/copyv/test_copyv.h +++ b/gtestsuite/testsuite/level1/copyv/test_copyv.h @@ -43,13 +43,12 @@ */ template -static void test_copyv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, - double thresh, char datatype ) { - +static void test_copyv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, double thresh ) +{ //---------------------------------------------------------- // Initialize vectors with random numbers. //---------------------------------------------------------- - std::vector x = testinghelpers::get_random_vector(-10, 10, n, incx, datatype); + std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); std::vector y( testinghelpers::buff_dim(n, incy), T{-1} ); //---------------------------------------------------------- @@ -58,12 +57,12 @@ static void test_copyv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, // Create a copy of y so that we can check reference results. std::vector y_ref(y); - testinghelpers::ref_copyv(conjx, n, x.data(), incx, y_ref.data(), incy); + testinghelpers::ref_copyv( conjx, n, x.data(), incx, y_ref.data(), incy ); //---------------------------------------------------------- // Call BLIS function. 
//---------------------------------------------------------- - copyv(conjx, n, x.data(), incx, y.data(), incy); + copyv( conjx, n, x.data(), incx, y.data(), incy ); //---------------------------------------------------------- // Compute error. diff --git a/gtestsuite/testsuite/level1/copyv/zcopyv_generic.cpp b/gtestsuite/testsuite/level1/copyv/zcopyv_generic.cpp index b76b11386e..eeb9b13e37 100644 --- a/gtestsuite/testsuite/level1/copyv/zcopyv_generic.cpp +++ b/gtestsuite/testsuite/level1/copyv/zcopyv_generic.cpp @@ -39,8 +39,7 @@ class zcopyvGenericTest : public ::testing::TestWithParam> {}; + gtint_t>> {}; // Tests using random integers as vector elements. TEST_P( zcopyvGenericTest, RandomData ) @@ -58,8 +57,6 @@ TEST_P( zcopyvGenericTest, RandomData ) gtint_t incx = std::get<2>(GetParam()); // stride size for y: gtint_t incy = std::get<3>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<4>(GetParam()); // Set the threshold for the errors: double thresh = testinghelpers::getEpsilon(); @@ -67,7 +64,7 @@ TEST_P( zcopyvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_copyv(conjx, n, incx, incy, thresh, datatype); + test_copyv( conjx, n, incx, incy, thresh ); } // Used to generate a test case with a sensible name. @@ -77,12 +74,11 @@ TEST_P( zcopyvGenericTest, RandomData ) class zcopyvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conjx = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); - char datatype = std::get<4>(str.param); #ifdef TEST_BLAS std::string str_name = "zcopy_"; #elif TEST_CBLAS @@ -96,7 +92,6 @@ class zcopyvGenericTestPrint { str_name += "_" + incx_str; std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); str_name += "_" + incy_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -113,8 +108,7 @@ INSTANTIATE_TEST_SUITE_P( ), ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1)) // stride size for y ), ::zcopyvGenericTestPrint() ); @@ -133,8 +127,7 @@ INSTANTIATE_TEST_SUITE_P( ), ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x - ::testing::Values(gtint_t(3), gtint_t(33)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(3), gtint_t(33)) // stride size for y ), ::zcopyvGenericTestPrint() ); @@ -150,8 +143,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(-5), gtint_t(7)), // stride size for x - ::testing::Values(gtint_t(13), gtint_t(-9)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(13), gtint_t(-9)) // stride size for y ), ::zcopyvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/dotv/cdotv_generic.cpp b/gtestsuite/testsuite/level1/dotv/cdotv_generic.cpp index 3584be5f08..0a662d96b4 100644 --- a/gtestsuite/testsuite/level1/dotv/cdotv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotv/cdotv_generic.cpp @@ -40,8 +40,7 @@ class cdotvGenericTest : char, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; // Tests using random integers as vector elements. TEST_P( cdotvGenericTest, RandomData ) @@ -61,8 +60,6 @@ TEST_P( cdotvGenericTest, RandomData ) gtint_t incx = std::get<3>(GetParam()); // stride size for y: gtint_t incy = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); // Set the threshold for the errors: double thresh = 2*n*testinghelpers::getEpsilon(); @@ -70,7 +67,7 @@ TEST_P( cdotvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_dotv(conjx, conjy, n, incx, incy, thresh, datatype); + test_dotv( conjx, conjy, n, incx, incy, thresh ); } // Used to generate a test case with a sensible name. @@ -80,13 +77,12 @@ TEST_P( cdotvGenericTest, RandomData ) class cdotvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conjx = std::get<0>(str.param); char conjy = std::get<1>(str.param); gtint_t n = std::get<2>(str.param); gtint_t incx = std::get<3>(str.param); gtint_t incy = std::get<4>(str.param); - char datatype = std::get<5>(str.param); #ifdef TEST_BLAS std::string str_name = "cdotu_"; #elif TEST_CBLAS @@ -101,7 +97,6 @@ class cdotvGenericTestPrint { str_name += "_" + incx_str; std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); str_name += "_" + incy_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -123,8 +118,7 @@ INSTANTIATE_TEST_SUITE_P( ), // n: use y, c: use conj(y) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1)) // stride size for y ), ::cdotvGenericTestPrint() ); @@ -148,8 +142,7 @@ INSTANTIATE_TEST_SUITE_P( ), // n: use y, c: use conj(y) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(11)), // stride size for x - ::testing::Values(gtint_t(3)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(3)) // stride size for y ), ::cdotvGenericTestPrint() ); @@ -166,8 +159,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use y, c: use conj(y) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(-2)), // stride size for x - ::testing::Values(gtint_t(-3)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(-3)) // stride size for y ), ::cdotvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/dotv/ddotv_generic.cpp b/gtestsuite/testsuite/level1/dotv/ddotv_generic.cpp index 250144e3f0..5af449fb32 100644 --- a/gtestsuite/testsuite/level1/dotv/ddotv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotv/ddotv_generic.cpp @@ -40,8 +40,7 @@ class ddotvGenericTest : char, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; // Tests using random integers as vector elements. TEST_P( ddotvGenericTest, RandomData ) @@ -61,8 +60,6 @@ TEST_P( ddotvGenericTest, RandomData ) gtint_t incx = std::get<3>(GetParam()); // stride size for y: gtint_t incy = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); // Set the threshold for the errors: double thresh = n*testinghelpers::getEpsilon(); @@ -70,7 +67,7 @@ TEST_P( ddotvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_dotv(conjx, conjy, n, incx, incy, thresh, datatype); + test_dotv( conjx, conjy, n, incx, incy, thresh ); } // Used to generate a test case with a sensible name. @@ -80,13 +77,12 @@ TEST_P( ddotvGenericTest, RandomData ) class ddotvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conjx = std::get<0>(str.param); char conjy = std::get<1>(str.param); gtint_t n = std::get<2>(str.param); gtint_t incx = std::get<3>(str.param); gtint_t incy = std::get<4>(str.param); - char datatype = std::get<5>(str.param); #ifdef TEST_BLAS std::string str_name = "ddot_"; #elif TEST_CBLAS @@ -101,7 +97,6 @@ class ddotvGenericTestPrint { str_name += "_" + incx_str; std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); str_name += "_" + incy_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -115,8 +110,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use y, not conj(y) (since it is real) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1)) // stride size for y ), ::ddotvGenericTestPrint() ); @@ -133,8 +127,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('c'), // c: use conj(y) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1)) // stride size for y ), ::ddotvGenericTestPrint() ); @@ -151,8 +144,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // use y, not conj(y) (since it is real) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x - ::testing::Values(gtint_t(3), gtint_t(33)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(3), gtint_t(33)) // stride size for y ), ::ddotvGenericTestPrint() ); @@ -169,8 +161,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use y, c: use conj(y) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(-2)), // stride size for x - ::testing::Values(gtint_t(-3)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(-3)) // stride size for y ), ::ddotvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/dotv/dotv.h b/gtestsuite/testsuite/level1/dotv/dotv.h index dad9802345..7917868e56 100644 --- a/gtestsuite/testsuite/level1/dotv/dotv.h +++ b/gtestsuite/testsuite/level1/dotv/dotv.h @@ -54,16 +54,24 @@ template static void dotv_(gtint_t n, T* x, gtint_t incx, T* y, gtint_t incy, T* rho) { - if constexpr (std::is_same::value) - *rho = sdot_( &n, x, &incx, y, &incy ); - else if constexpr (std::is_same::value) - *rho = ddot_( &n, x, &incx, y, &incy ); - else if constexpr (std::is_same::value) - *rho = cdotu_( &n, x, &incx, y, &incy ); - else if constexpr (std::is_same::value) - *rho = zdotu_( &n, x, &incx, y, &incy ); - else - throw std::runtime_error("Error in testsuite/level1/dotv.h: Invalid typename in dotv_()."); + if constexpr (std::is_same::value) + *rho = sdot_(&n, x, &incx, y, &incy); + else if constexpr (std::is_same::value) + *rho = ddot_( &n, x, &incx, y, &incy ); + else if constexpr (std::is_same::value) + #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + *rho = cdotu_(&n, x, &incx, y, &incy); + #else + cdotu_(rho, &n, x, &incx, y, &incy); + #endif + else if constexpr (std::is_same::value) + #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + *rho = zdotu_(&n, x, &incx, y, &incy); + #else + zdotu_(rho, &n, x, &incx, y, &incy); + #endif + else + throw std::runtime_error("Error in testsuite/level1/dotv.h: Invalid typename in dotv_()."); } template diff --git a/gtestsuite/testsuite/level1/dotv/sdotv_generic.cpp b/gtestsuite/testsuite/level1/dotv/sdotv_generic.cpp index ce57c4f59b..9d69ac6e7a 100644 --- a/gtestsuite/testsuite/level1/dotv/sdotv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotv/sdotv_generic.cpp @@ -40,8 +40,7 @@ class sdotvGenericTest : char, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; // Tests using random integers as vector elements. 
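The dotv.h change above makes the cdotu_/zdotu_ call sites depend on BLIS_DISABLE_COMPLEX_RETURN_INTEL: when it is defined, the BLAS routine returns the complex result by value; otherwise the result is written through an extra leading pointer argument (the hidden-argument, "Intel"-style convention). The sketch below shows the two call shapes with mock functions; the symbols here are stand-ins, not the real zdotu_/cdotu_ provided by the linked BLAS library.

// Sketch (assumption): the two complex-return conventions selected by dotv.h.
#include <complex>
#include <cstdio>

using dcomplex_t = std::complex<double>;

// Return-by-value convention.
static dcomplex_t zdotu_by_value(int n, const dcomplex_t* x, int incx,
                                 const dcomplex_t* y, int incy)
{
    dcomplex_t rho{0.0, 0.0};
    for (int i = 0; i < n; ++i) rho += x[i * incx] * y[i * incy];
    return rho;
}

// Hidden-argument convention: result written through the leading pointer.
static void zdotu_hidden_arg(dcomplex_t* rho, int n, const dcomplex_t* x, int incx,
                             const dcomplex_t* y, int incy)
{
    *rho = zdotu_by_value(n, x, incx, y, incy);
}

int main()
{
    dcomplex_t x[3] = {{1, 1}, {2, -1}, {0, 3}};
    dcomplex_t y[3] = {{1, 0}, {0, 1}, {2, 2}};
    dcomplex_t rho;

#ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL
    rho = zdotu_by_value(3, x, 1, y, 1);       // call shape used when defined
#else
    zdotu_hidden_arg(&rho, 3, x, 1, y, 1);     // call shape used otherwise
#endif
    std::printf("rho = (%g, %g)\n", rho.real(), rho.imag());
    return 0;
}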
TEST_P( sdotvGenericTest, RandomData ) @@ -61,8 +60,6 @@ TEST_P( sdotvGenericTest, RandomData ) gtint_t incx = std::get<3>(GetParam()); // stride size for y: gtint_t incy = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); // Set the threshold for the errors: double thresh = n*testinghelpers::getEpsilon(); @@ -70,7 +67,7 @@ TEST_P( sdotvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_dotv(conjx, conjy, n, incx, incy, thresh, datatype); + test_dotv( conjx, conjy, n, incx, incy, thresh ); } // Used to generate a test case with a sensible name. @@ -80,13 +77,12 @@ TEST_P( sdotvGenericTest, RandomData ) class sdotvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conjx = std::get<0>(str.param); char conjy = std::get<1>(str.param); gtint_t n = std::get<2>(str.param); gtint_t incx = std::get<3>(str.param); gtint_t incy = std::get<4>(str.param); - char datatype = std::get<5>(str.param); #ifdef TEST_BLAS std::string str_name = "sdot_"; #elif TEST_CBLAS @@ -101,7 +97,6 @@ class sdotvGenericTestPrint { str_name += "_" + incx_str; std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); str_name += "_" + incy_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -115,8 +110,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use y, not conj(y) (since it is real) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1)) // stride size for y ), ::sdotvGenericTestPrint() ); @@ -133,8 +127,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('c'), // c: use conj(y) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1)) // stride size for y ), ::sdotvGenericTestPrint() ); @@ -151,8 +144,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use y, not conj(y) (since it is real) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x - ::testing::Values(gtint_t(3), gtint_t(33)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(3), gtint_t(33)) // stride size for y ), ::sdotvGenericTestPrint() ); @@ -169,8 +161,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use y, c: use conj(y) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(-2)), // stride size for x - ::testing::Values(gtint_t(-3)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(-3)) // stride size for y ), ::sdotvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/dotv/test_dotv.h b/gtestsuite/testsuite/level1/dotv/test_dotv.h index 
1faf3120a2..3f9610f7da 100644 --- a/gtestsuite/testsuite/level1/dotv/test_dotv.h +++ b/gtestsuite/testsuite/level1/dotv/test_dotv.h @@ -44,15 +44,13 @@ template static void test_dotv( char conjx, char conjy, gtint_t n, gtint_t incx, - gtint_t incy, double thresh, char datatype ) + gtint_t incy, double thresh ) { - - //---------------------------------------------------------- // Initialize vectors with random numbers. //---------------------------------------------------------- - std::vector x = testinghelpers::get_random_vector(-10, 10, n, incx, datatype); - std::vector y = testinghelpers::get_random_vector(-10, 10, n, incy, datatype); + std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); + std::vector y = testinghelpers::get_random_vector( -10, 10, n, incy ); //---------------------------------------------------------- // Call reference implementation to get ref results. @@ -63,13 +61,13 @@ static void test_dotv( char conjx, char conjy, gtint_t n, gtint_t incx, if constexpr (testinghelpers::type_info::is_real) testinghelpers::ref_dotv( n, x.data(), incx, y_ref.data(), incy, &rho_ref ); else - testinghelpers::ref_dotv(conjx, conjy, n, x.data(), incx, y_ref.data(), incy, &rho_ref); + testinghelpers::ref_dotv( conjx, conjy, n, x.data(), incx, y_ref.data(), incy, &rho_ref ); //---------------------------------------------------------- // Call BLIS function. //---------------------------------------------------------- T rho; - dotv(conjx, conjy, n, x.data(), incx, y.data(), incy, &rho); + dotv( conjx, conjy, n, x.data(), incx, y.data(), incy, &rho ); //---------------------------------------------------------- // Compute error. diff --git a/gtestsuite/testsuite/level1/dotv/zdotv_generic.cpp b/gtestsuite/testsuite/level1/dotv/zdotv_generic.cpp index 4b0f3fbcdb..7d7d3aabd0 100644 --- a/gtestsuite/testsuite/level1/dotv/zdotv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotv/zdotv_generic.cpp @@ -40,8 +40,7 @@ class zdotvGenericTest : char, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; // Tests using random integers as vector elements. TEST_P( zdotvGenericTest, RandomData ) @@ -61,8 +60,6 @@ TEST_P( zdotvGenericTest, RandomData ) gtint_t incx = std::get<3>(GetParam()); // stride size for y: gtint_t incy = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); // Set the threshold for the errors: double thresh = 2*n*testinghelpers::getEpsilon(); @@ -70,7 +67,7 @@ TEST_P( zdotvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_dotv(conjx, conjy, n, incx, incy, thresh, datatype); + test_dotv( conjx, conjy, n, incx, incy, thresh ); } // Used to generate a test case with a sensible name. @@ -80,13 +77,12 @@ TEST_P( zdotvGenericTest, RandomData ) class zdotvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conjx = std::get<0>(str.param); char conjy = std::get<1>(str.param); gtint_t n = std::get<2>(str.param); gtint_t incx = std::get<3>(str.param); gtint_t incy = std::get<4>(str.param); - char datatype = std::get<5>(str.param); #ifdef TEST_BLAS std::string str_name = "zdotu_"; #elif TEST_CBLAS @@ -101,7 +97,6 @@ class zdotvGenericTestPrint { str_name += "_" + incx_str; std::string incy_str = ( incy > 0) ? 
std::to_string(incy) : "m" + std::to_string(std::abs(incy)); str_name += "_" + incy_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -123,8 +118,7 @@ INSTANTIATE_TEST_SUITE_P( ), // n: use y, c: use conj(y) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1)) // stride size for y ), ::zdotvGenericTestPrint() ); @@ -148,8 +142,7 @@ INSTANTIATE_TEST_SUITE_P( ), // n: use y, c: use conj(y) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x - ::testing::Values(gtint_t(3), gtint_t(33)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(3), gtint_t(33)) // stride size for y ), ::zdotvGenericTestPrint() ); @@ -166,8 +159,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use y, c: use conj(y) ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(-2)), // stride size for x - ::testing::Values(gtint_t(-3)), // stride size for y - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(-3)) // stride size for y ), ::zdotvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/dotxv/cdotxv_generic.cpp b/gtestsuite/testsuite/level1/dotxv/cdotxv_generic.cpp index 17377a7f0c..5ed6f67d96 100644 --- a/gtestsuite/testsuite/level1/dotxv/cdotxv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotxv/cdotxv_generic.cpp @@ -36,7 +36,7 @@ #include "test_dotxv.h" class cdotxvGenericTest : - public ::testing::TestWithParam> {}; + public ::testing::TestWithParam> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cdotxvGenericTest); @@ -62,8 +62,6 @@ TEST_P( cdotxvGenericTest, RandomData ) T alpha = std::get<5>(GetParam()); // beta T beta = std::get<6>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<7>(GetParam()); // Set the threshold for the errors: double thresh = n*testinghelpers::getEpsilon(); @@ -71,7 +69,7 @@ TEST_P( cdotxvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_dotxv(n, conj_x, conj_y, alpha, incx, incy, beta, thresh, datatype); + test_dotxv( n, conj_x, conj_y, alpha, incx, incy, beta, thresh ); } // Used to generate a test case with a sensible name. @@ -81,7 +79,7 @@ TEST_P( cdotxvGenericTest, RandomData ) class cdotxvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { gtint_t n = std::get<0>(str.param); char conjx = std::get<1>(str.param); char conjy = std::get<2>(str.param); @@ -89,7 +87,6 @@ class cdotxvGenericTestPrint { gtint_t incy = std::get<4>(str.param); scomplex alpha = std::get<5>(str.param); scomplex beta = std::get<6>(str.param); - char datatype = std::get<7>(str.param); std::string str_name = "bli_cdotxv"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conjx, 1); @@ -104,7 +101,6 @@ class cdotxvGenericTestPrint { beta_str = beta_str + "pi" + (( beta.imag > 0) ? 
std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); str_name = str_name + "_a" + alpha_str; str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -121,8 +117,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(scomplex{1.0, -1.0}), // alpha - ::testing::Values(scomplex{-1.0, 1.0}), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(scomplex{-1.0, 1.0}) // beta ), ::cdotxvGenericTestPrint() ); @@ -138,8 +133,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(scomplex{1.0, -1.0}), // alpha - ::testing::Values(scomplex{-1.0, 1.0}), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(scomplex{-1.0, 1.0}) // beta ), ::cdotxvGenericTestPrint() ); @@ -157,10 +151,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x ::testing::Values(gtint_t(3), gtint_t(33)), // stride size for y ::testing::Values(scomplex{1.0, -1.0}), // alpha - ::testing::Values(scomplex{-1.0, 1.0}), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(scomplex{-1.0, 1.0}) // beta ), ::cdotxvGenericTestPrint() ); - #endif diff --git a/gtestsuite/testsuite/level1/dotxv/ddotxv_generic.cpp b/gtestsuite/testsuite/level1/dotxv/ddotxv_generic.cpp index 8cd33a861e..75376ed4b9 100644 --- a/gtestsuite/testsuite/level1/dotxv/ddotxv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotxv/ddotxv_generic.cpp @@ -36,7 +36,7 @@ #include "test_dotxv.h" class ddotxvGenericTest : - public ::testing::TestWithParam> {}; + public ::testing::TestWithParam> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ddotxvGenericTest); @@ -62,8 +62,6 @@ TEST_P( ddotxvGenericTest, RandomData ) T alpha = std::get<5>(GetParam()); // beta T beta = std::get<6>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<7>(GetParam()); // Set the threshold for the errors: double thresh = n*testinghelpers::getEpsilon(); @@ -71,7 +69,7 @@ TEST_P( ddotxvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_dotxv(n, conj_x, conj_y, alpha, incx, incy, beta, thresh, datatype); + test_dotxv(n, conj_x, conj_y, alpha, incx, incy, beta, thresh ); } // Used to generate a test case with a sensible name. @@ -81,7 +79,7 @@ TEST_P( ddotxvGenericTest, RandomData ) class ddotxvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { gtint_t n = std::get<0>(str.param); char conjx = std::get<1>(str.param); char conjy = std::get<2>(str.param); @@ -89,7 +87,6 @@ class ddotxvGenericTestPrint { gtint_t incy = std::get<4>(str.param); double alpha = std::get<5>(str.param); double beta = std::get<6>(str.param); - char datatype = std::get<7>(str.param); std::string str_name = "bli_ddotxv"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conjx, 1); @@ -102,7 +99,6 @@ class ddotxvGenericTestPrint { str_name = str_name + "_a" + alpha_str; std::string beta_str = ( beta > 0) ? 
std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -119,8 +115,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(1.0, 2.0), // alpha - ::testing::Values(2.0, 3.0), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(2.0, 3.0) // beta ), ::ddotxvGenericTestPrint() ); @@ -138,8 +133,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(1.0, 2.0), // alpha - ::testing::Values(2.0, 3.0), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(2.0, 3.0) // beta ), ::ddotxvGenericTestPrint() ); @@ -157,8 +151,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x ::testing::Values(gtint_t(3), gtint_t(33)), // stride size for y ::testing::Values(1.0, 2.0), // alpha - ::testing::Values(2.0, 3.0), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(2.0, 3.0) // beta ), ::ddotxvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/dotxv/sdotxv_generic.cpp b/gtestsuite/testsuite/level1/dotxv/sdotxv_generic.cpp index ea0ad22b6b..9ee47c18a7 100644 --- a/gtestsuite/testsuite/level1/dotxv/sdotxv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotxv/sdotxv_generic.cpp @@ -36,7 +36,7 @@ #include "test_dotxv.h" class sdotxvGenericTest : - public ::testing::TestWithParam> {}; + public ::testing::TestWithParam> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sdotxvGenericTest); @@ -62,8 +62,6 @@ TEST_P( sdotxvGenericTest, RandomData ) T alpha = std::get<5>(GetParam()); // beta T beta = std::get<6>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<7>(GetParam()); // Set the threshold for the errors: double thresh = n*testinghelpers::getEpsilon(); @@ -71,7 +69,7 @@ TEST_P( sdotxvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_dotxv(n, conj_x, conj_y, alpha, incx, incy, beta, thresh, datatype); + test_dotxv( n, conj_x, conj_y, alpha, incx, incy, beta, thresh ); } // Used to generate a test case with a sensible name. @@ -81,7 +79,7 @@ TEST_P( sdotxvGenericTest, RandomData ) class sdotxvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { gtint_t n = std::get<0>(str.param); char conjx = std::get<1>(str.param); char conjy = std::get<2>(str.param); @@ -89,7 +87,6 @@ class sdotxvGenericTestPrint { gtint_t incy = std::get<4>(str.param); float alpha = std::get<5>(str.param); float beta = std::get<6>(str.param); - char datatype = std::get<7>(str.param); std::string str_name = "bli_sdotxv"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conjx, 1); @@ -102,7 +99,6 @@ class sdotxvGenericTestPrint { str_name = str_name + "_a" + alpha_str; std::string beta_str = ( beta > 0) ? 
std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -119,8 +115,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(1.0, 2.0), // alpha - ::testing::Values(2.0, 3.0), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(2.0, 3.0) // beta ), ::sdotxvGenericTestPrint() ); @@ -138,8 +133,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(1.0, 2.0), // alpha - ::testing::Values(2.0, 3.0), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(2.0, 3.0) // beta ), ::sdotxvGenericTestPrint() ); @@ -157,8 +151,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x ::testing::Values(gtint_t(3), gtint_t(33)), // stride size for y ::testing::Values(1.0, 2.0), // alpha - ::testing::Values(2.0, 3.0), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(2.0, 3.0) // beta ), ::sdotxvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/dotxv/test_dotxv.h b/gtestsuite/testsuite/level1/dotxv/test_dotxv.h index 6d0f74d5f0..729e172b8f 100644 --- a/gtestsuite/testsuite/level1/dotxv/test_dotxv.h +++ b/gtestsuite/testsuite/level1/dotxv/test_dotxv.h @@ -44,13 +44,13 @@ template static void test_dotxv( gtint_t n, char conjx, char conjy, T alpha, - gtint_t incx, gtint_t incy, T beta, double thresh, char datatype ) + gtint_t incx, gtint_t incy, T beta, double thresh ) { //---------------------------------------------------------- // Initialize vectors with random numbers. //---------------------------------------------------------- - std::vector x = testinghelpers::get_random_vector(-10, 10, n, incx, datatype); - std::vector y = testinghelpers::get_random_vector(-10, 10, n, incy, datatype); + std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); + std::vector y = testinghelpers::get_random_vector( -10, 10, n, incy ); //---------------------------------------------------------- // Call reference implementation to get ref results. @@ -58,15 +58,15 @@ static void test_dotxv( gtint_t n, char conjx, char conjy, T alpha, // Create a copy of y so that we can check reference results. std::vector y_ref(y); T rho_ref; - testinghelpers::initone(rho_ref); - testinghelpers::ref_dotxv(conjx, conjy, n, alpha, x.data(), incx, y.data(), incy, beta, &rho_ref); + testinghelpers::initone(rho_ref); + testinghelpers::ref_dotxv( conjx, conjy, n, alpha, x.data(), incx, y.data(), incy, beta, &rho_ref ); //---------------------------------------------------------- // Call BLIS function. //---------------------------------------------------------- T rho; - testinghelpers::initone(rho); - dotxv(conjx, conjy, n, &alpha, x.data(), incx, y.data(), incy, &beta, &rho); + testinghelpers::initone(rho); + dotxv( conjx, conjy, n, &alpha, x.data(), incx, y.data(), incy, &beta, &rho ); //---------------------------------------------------------- // Compute error. 
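The hunks above all apply one mechanical refactor to the dotv/dotxv tests: the trailing `char datatype` element is dropped from each parameter tuple, the matching `std::get<N>(GetParam())` read goes away, and random inputs are requested as `testinghelpers::get_random_vector( -10, 10, n, incx )` with no datatype tag (the element kind is now chosen globally, e.g. via `BLIS_ELEMENT_TYPE` as in `scalv_extreme_cases.cpp` further down). A minimal, self-contained GoogleTest sketch of the resulting shape follows; `DemoDotTest`, `demo_dot`, and the `gtint_t` alias are hypothetical stand-ins, not gtestsuite code.

    // demo_dot_param.cpp -- illustrative sketch only (assumes GoogleTest is available).
    #include <gtest/gtest.h>
    #include <tuple>
    #include <vector>

    using gtint_t = long;   // stand-in for the gtestsuite integer type

    // Hypothetical routine under test: unit-stride dot product.
    static double demo_dot( gtint_t n, const double* x, const double* y )
    {
        double rho = 0.0;
        for ( gtint_t i = 0; i < n; ++i ) rho += x[i] * y[i];
        return rho;
    }

    // After the refactor the tuple carries only the math parameters.
    class DemoDotTest : public ::testing::TestWithParam<std::tuple<gtint_t, gtint_t>> {};

    TEST_P( DemoDotTest, RandomData )
    {
        gtint_t n    = std::get<0>(GetParam());  // vector length
        gtint_t incx = std::get<1>(GetParam());  // stride size (kept for shape, unused here)
        (void)incx;
        std::vector<double> x(n, 1.0), y(n, 2.0);
        EXPECT_DOUBLE_EQ( demo_dot( n, x.data(), y.data() ), 2.0 * n );
    }

    INSTANTIATE_TEST_SUITE_P(
        Blackbox, DemoDotTest,
        ::testing::Combine(
            ::testing::Values(gtint_t(10), gtint_t(100)),  // n
            ::testing::Values(gtint_t(1))                  // incx
            // no trailing ::testing::Values(ELEMENT_TYPE) element any more
        ) );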
diff --git a/gtestsuite/testsuite/level1/dotxv/zdotxv_generic.cpp b/gtestsuite/testsuite/level1/dotxv/zdotxv_generic.cpp index 829532afde..10bfcac45f 100644 --- a/gtestsuite/testsuite/level1/dotxv/zdotxv_generic.cpp +++ b/gtestsuite/testsuite/level1/dotxv/zdotxv_generic.cpp @@ -36,7 +36,7 @@ #include "test_dotxv.h" class zdotxvGenericTest : - public ::testing::TestWithParam> {}; + public ::testing::TestWithParam> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zdotxvGenericTest); @@ -62,8 +62,6 @@ TEST_P( zdotxvGenericTest, RandomData ) T alpha = std::get<5>(GetParam()); // beta T beta = std::get<6>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<7>(GetParam()); // Set the threshold for the errors: double thresh = n*testinghelpers::getEpsilon(); @@ -71,7 +69,7 @@ TEST_P( zdotxvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_dotxv(n, conj_x, conj_y, alpha, incx, incy, beta, thresh, datatype); + test_dotxv(n, conj_x, conj_y, alpha, incx, incy, beta, thresh ); } // Used to generate a test case with a sensible name. @@ -81,7 +79,7 @@ TEST_P( zdotxvGenericTest, RandomData ) class zdotxvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { gtint_t n = std::get<0>(str.param); char conjx = std::get<1>(str.param); char conjy = std::get<2>(str.param); @@ -89,7 +87,6 @@ class zdotxvGenericTestPrint { gtint_t incy = std::get<4>(str.param); dcomplex alpha = std::get<5>(str.param); dcomplex beta = std::get<6>(str.param); - char datatype = std::get<7>(str.param); std::string str_name = "bli_zdotxv"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conjx, 1); @@ -104,7 +101,6 @@ class zdotxvGenericTestPrint { beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); str_name = str_name + "_a" + alpha_str; str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -121,8 +117,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y ::testing::Values(dcomplex{1.0, -1.0}), // alpha - ::testing::Values(dcomplex{-1.0, 1.0}), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(dcomplex{-1.0, 1.0}) // beta ), ::zdotxvGenericTestPrint() ); @@ -140,8 +135,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x ::testing::Values(gtint_t(3), gtint_t(33)), // stride size for y ::testing::Values(dcomplex{1.0, -1.0}), // alpha - ::testing::Values(dcomplex{-1.0, 1.0}), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(dcomplex{-1.0, 1.0}) // beta ), ::zdotxvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/scal2v/cscal2v_generic.cpp b/gtestsuite/testsuite/level1/scal2v/cscal2v_generic.cpp index d25419606f..e9c1d53189 100644 --- a/gtestsuite/testsuite/level1/scal2v/cscal2v_generic.cpp +++ b/gtestsuite/testsuite/level1/scal2v/cscal2v_generic.cpp @@ -40,38 +40,35 @@ class cscal2vGenericTest : gtint_t, gtint_t, gtint_t, - scomplex, - char>> {}; + scomplex>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cscal2vGenericTest); // Tests using random integers as vector elements. 
TEST_P( cscal2vGenericTest, RandomData ) { - using T = scomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // denotes whether alpha or conj(alpha) will be used: - char conj_alpha = std::get<0>(GetParam()); - // vector length: - gtint_t n = std::get<1>(GetParam()); - // stride size for x: - gtint_t incx = std::get<2>(GetParam()); - // stride size for y: - gtint_t incy = std::get<3>(GetParam()); - // alpha - T alpha = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // stride size for y: + gtint_t incy = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- - // Call generic test body using those parameters - //---------------------------------------------------------- - test_scal2v(conj_alpha, n, incx, incy, alpha, thresh, datatype); + // Set the threshold for the errors: + double thresh = testinghelpers::getEpsilon(); + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scal2v( conj_alpha, n, incx, incy, alpha, thresh ); } // Used to generate a test case with a sensible name. @@ -81,13 +78,12 @@ TEST_P( cscal2vGenericTest, RandomData ) class cscal2vGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); scomplex alpha = std::get<4>(str.param); - char datatype = std::get<5>(str.param); std::string str_name = "bli_cscal2v"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conj, 1); @@ -98,7 +94,6 @@ class cscal2vGenericTestPrint { std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -112,8 +107,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}) // alpha ), ::cscal2vGenericTestPrint() ); @@ -130,8 +124,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x ::testing::Values(gtint_t(4)), // stride size for y - ::testing::Values(scomplex{4.0, 3.1}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(scomplex{4.0, 3.1}) // alpha ), ::cscal2vGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/scal2v/dscal2v_generic.cpp b/gtestsuite/testsuite/level1/scal2v/dscal2v_generic.cpp index 396bf99ba1..66b624c382 100644 --- a/gtestsuite/testsuite/level1/scal2v/dscal2v_generic.cpp +++ b/gtestsuite/testsuite/level1/scal2v/dscal2v_generic.cpp @@ -40,38 +40,35 @@ class dscal2vGenericTest : gtint_t, gtint_t, gtint_t, - double, - char>> {}; + double>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dscal2vGenericTest); // Tests using random integers as vector elements. TEST_P( dscal2vGenericTest, RandomData ) { - using T = double; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // denotes whether alpha or conj(alpha) will be used: - char conj_alpha = std::get<0>(GetParam()); - // vector length: - gtint_t n = std::get<1>(GetParam()); - // stride size for x: - gtint_t incx = std::get<2>(GetParam()); - // stride size for y: - gtint_t incy = std::get<3>(GetParam()); - // alpha - T alpha = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // stride size for y: + gtint_t incy = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); - // Set the threshold for the errors: - float thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- - // Call generic test body using those parameters - //---------------------------------------------------------- - test_scal2v(conj_alpha, n, incx, incy, alpha, thresh, datatype); + // Set the threshold for the errors: + float thresh = testinghelpers::getEpsilon(); + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scal2v( conj_alpha, n, incx, incy, alpha, thresh ); } // Used to generate a test case with a sensible name. 
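Because the parameter tuples shrink, each `*GenericTestPrint` name generator is trimmed in lock-step, and the generated test names no longer carry the trailing `_i`/`_f` datatype suffix. A rough standalone sketch of that naming pattern is below; `DemoScalvPrint` and the `demo_scalv` prefix are hypothetical, and the formatting only approximates what the gtestsuite emits.

    // Illustrative GoogleTest name generator for a tuple without the datatype char.
    #include <gtest/gtest.h>
    #include <cmath>
    #include <cstdlib>
    #include <string>
    #include <tuple>

    using gtint_t = long;   // stand-in for the gtestsuite integer type

    class DemoScalvPrint {
    public:
        std::string operator()(
            const testing::TestParamInfo<std::tuple<char, gtint_t, gtint_t, double>>& info ) const
        {
            char    conj  = std::get<0>(info.param);
            gtint_t n     = std::get<1>(info.param);
            gtint_t incx  = std::get<2>(info.param);
            double  alpha = std::get<3>(info.param);

            std::string name = "demo_scalv";
            name += "_" + std::to_string(n);
            name += "_" + std::string(&conj, 1);
            // Negative values are spelled with a leading 'm' so the name stays a valid identifier.
            name += "_" + ( incx > 0 ? std::to_string(incx)
                                     : "m" + std::to_string(std::abs(incx)) );
            name += "_a" + ( alpha > 0 ? std::to_string(int(alpha))
                                       : "m" + std::to_string(int(std::abs(alpha))) );
            // No trailing "_" + datatype suffix after this patch.
            return name;
        }
    };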
@@ -81,13 +78,12 @@ TEST_P( dscal2vGenericTest, RandomData ) class dscal2vGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); double alpha = std::get<4>(str.param); - char datatype = std::get<5>(str.param); std::string str_name = "bli_dscal2v"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conj, 1); @@ -97,7 +93,6 @@ class dscal2vGenericTestPrint { str_name += "_" + incy_str; std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -111,8 +106,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(double(2.0), double(-3.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(double(2.0), double(-3.0)) // alpha ), ::dscal2vGenericTestPrint() ); @@ -128,8 +122,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(double(-3.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(double(-3.0)) // alpha ), ::dscal2vGenericTestPrint() ); @@ -145,8 +138,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x ::testing::Values(gtint_t(5)), // stride size for y - ::testing::Values(double(3.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(double(3.0)) // alpha ), ::dscal2vGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/scal2v/sscal2v_generic.cpp b/gtestsuite/testsuite/level1/scal2v/sscal2v_generic.cpp index ef02a4c225..366d649ead 100644 --- a/gtestsuite/testsuite/level1/scal2v/sscal2v_generic.cpp +++ b/gtestsuite/testsuite/level1/scal2v/sscal2v_generic.cpp @@ -40,38 +40,35 @@ class sscal2vGenericTest : gtint_t, gtint_t, gtint_t, - float, - char>> {}; + float>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sscal2vGenericTest); // Tests using random integers as vector elements. TEST_P( sscal2vGenericTest, RandomData ) { - using T = float; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). 
- //---------------------------------------------------------- - // denotes whether alpha or conj(alpha) will be used: - char conj_alpha = std::get<0>(GetParam()); - // vector length: - gtint_t n = std::get<1>(GetParam()); - // stride size for x: - gtint_t incx = std::get<2>(GetParam()); - // stride size for y: - gtint_t incy = std::get<3>(GetParam()); - // alpha - T alpha = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // stride size for y: + gtint_t incy = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); - // Set the threshold for the errors: - float thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- - // Call generic test body using those parameters - //---------------------------------------------------------- - test_scal2v(conj_alpha, n, incx, incy, alpha, thresh, datatype); + // Set the threshold for the errors: + float thresh = testinghelpers::getEpsilon(); + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scal2v( conj_alpha, n, incx, incy, alpha, thresh ); } // Used to generate a test case with a sensible name. @@ -81,13 +78,12 @@ TEST_P( sscal2vGenericTest, RandomData ) class sscal2vGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); float alpha = std::get<4>(str.param); - char datatype = std::get<5>(str.param); std::string str_name = "bli_sscal2v"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conj, 1); @@ -97,10 +93,10 @@ class sscal2vGenericTestPrint { str_name += "_" + incy_str; std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + datatype; return str_name; } }; + #ifdef TEST_BLIS_TYPED // Black box testing for generic and main use of sscal2. INSTANTIATE_TEST_SUITE_P( @@ -111,8 +107,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(float(3.0), float(-5.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(float(3.0), float(-5.0)) // alpha ), ::sscal2vGenericTestPrint() ); @@ -128,8 +123,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(float(9.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(float(9.0)) // alpha ), ::sscal2vGenericTestPrint() ); @@ -145,8 +139,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x ::testing::Values(gtint_t(7)), // stride size for y - ::testing::Values(float(2.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(float(2.0)) // alpha ), ::sscal2vGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/scal2v/test_scal2v.h b/gtestsuite/testsuite/level1/scal2v/test_scal2v.h index 8edb967ab2..9cb621acb6 100644 --- a/gtestsuite/testsuite/level1/scal2v/test_scal2v.h +++ b/gtestsuite/testsuite/level1/scal2v/test_scal2v.h @@ -43,12 +43,12 @@ */ template -static void test_scal2v(char conjx, gtint_t n, gtint_t incx, gtint_t incy, T alpha, double thresh, char datatype) +static void test_scal2v(char conjx, gtint_t n, gtint_t incx, gtint_t incy, T alpha, double thresh ) { //---------------------------------------------------------- // Initialize vector with random numbers. //---------------------------------------------------------- - std::vector x = testinghelpers::get_random_vector(-10, 10, n, incx, datatype); + std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); std::vector y( testinghelpers::buff_dim(n, incy), T{-112} ); //---------------------------------------------------------- @@ -56,12 +56,12 @@ static void test_scal2v(char conjx, gtint_t n, gtint_t incx, gtint_t incy, T alp //---------------------------------------------------------- // Create a copy of y so that we can check reference results. std::vector y_ref(y); - testinghelpers::ref_scal2v(conjx, n, alpha, x.data(), incx, y_ref.data(), incy); + testinghelpers::ref_scal2v( conjx, n, alpha, x.data(), incx, y_ref.data(), incy ); //---------------------------------------------------------- // Call BLIS function. //---------------------------------------------------------- - scal2v(conjx, n, alpha, x.data(), incx, y.data(), incy); + scal2v( conjx, n, alpha, x.data(), incx, y.data(), incy ); //---------------------------------------------------------- // Compute component-wise error. diff --git a/gtestsuite/testsuite/level1/scal2v/zscal2v_generic.cpp b/gtestsuite/testsuite/level1/scal2v/zscal2v_generic.cpp index 0308cbd10b..5c413192d6 100644 --- a/gtestsuite/testsuite/level1/scal2v/zscal2v_generic.cpp +++ b/gtestsuite/testsuite/level1/scal2v/zscal2v_generic.cpp @@ -40,8 +40,7 @@ class zscal2vGenericTest : gtint_t, gtint_t, gtint_t, - dcomplex, - char>> {}; + dcomplex>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zscal2vGenericTest); @@ -49,30 +48,28 @@ GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zscal2vGenericTest); // Tests using random integers as vector elements. TEST_P( zscal2vGenericTest, RandomData ) { - using T = dcomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). 
- //---------------------------------------------------------- - // denotes whether alpha or conj(alpha) will be used: - char conj_alpha = std::get<0>(GetParam()); - // vector length: - gtint_t n = std::get<1>(GetParam()); - // stride size for x: - gtint_t incx = std::get<2>(GetParam()); - // stride size for y: - gtint_t incy = std::get<3>(GetParam()); - // alpha - T alpha = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // stride size for y: + gtint_t incy = std::get<3>(GetParam()); + // alpha + T alpha = std::get<4>(GetParam()); - // Set the threshold for the errors: - float thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- - // Call generic test body using those parameters - //---------------------------------------------------------- - test_scal2v(conj_alpha, n, incx, incy, alpha, thresh, datatype); + // Set the threshold for the errors: + float thresh = testinghelpers::getEpsilon(); + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scal2v( conj_alpha, n, incx, incy, alpha, thresh ); } // Used to generate a test case with a sensible name. @@ -82,13 +79,12 @@ TEST_P( zscal2vGenericTest, RandomData ) class zscal2vGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); dcomplex alpha = std::get<4>(str.param); - char datatype = std::get<5>(str.param); std::string str_name = "bli_zscal2v"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conj, 1); @@ -99,7 +95,6 @@ class zscal2vGenericTestPrint { std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -113,8 +108,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(dcomplex{3.0, -2.0}, dcomplex{-1.0, 4.0}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(dcomplex{3.0, -2.0}, dcomplex{-1.0, 4.0}) // alpha ), ::zscal2vGenericTestPrint() ); @@ -131,8 +125,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(2), gtint_t(11)), // stride size for x ::testing::Values(gtint_t(3)), // stride size for y - ::testing::Values(dcomplex{1.0, 2.1}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(dcomplex{1.0, 2.1}) // alpha ), ::zscal2vGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/scalv/cscalv_generic.cpp b/gtestsuite/testsuite/level1/scalv/cscalv_generic.cpp index 223fec91d7..bf367f73d8 100644 --- a/gtestsuite/testsuite/level1/scalv/cscalv_generic.cpp +++ b/gtestsuite/testsuite/level1/scalv/cscalv_generic.cpp @@ -39,35 +39,32 @@ class cscalvGenericTest : public ::testing::TestWithParam> {}; + scomplex>> {}; // Tests using random integers as vector elements. TEST_P( cscalvGenericTest, RandomData ) { - using T = scomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // denotes whether alpha or conj(alpha) will be used: - char conj_alpha = std::get<0>(GetParam()); - // vector length: - gtint_t n = std::get<1>(GetParam()); - // stride size for x: - gtint_t incx = std::get<2>(GetParam()); - // alpha - T alpha = std::get<3>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<4>(GetParam()); + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // alpha + T alpha = std::get<3>(GetParam()); - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- - // Call generic test body using those parameters - //---------------------------------------------------------- - test_scalv(conj_alpha, n, incx, alpha, thresh, datatype); + // Set the threshold for the errors: + double thresh = testinghelpers::getEpsilon(); + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv( conj_alpha, n, incx, alpha, thresh ); } // Used to generate a test case with a sensible name. @@ -77,12 +74,11 @@ TEST_P( cscalvGenericTest, RandomData ) class cscalvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); scomplex alpha = std::get<3>(str.param); - char datatype = std::get<4>(str.param); #ifdef TEST_BLAS std::string str_name = "cscal_"; #elif TEST_CBLAS @@ -97,7 +93,6 @@ class cscalvGenericTestPrint { std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? 
std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -114,8 +109,7 @@ INSTANTIATE_TEST_SUITE_P( ), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}) // alpha ), ::cscalvGenericTestPrint() ); @@ -135,8 +129,7 @@ INSTANTIATE_TEST_SUITE_P( ), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(2), gtint_t(11)), //(gtint_t(-5), gtint_t(-17)) // stride size for x - ::testing::Values(scomplex{4.0, 3.1}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(scomplex{4.0, 3.1}) // alpha ), ::cscalvGenericTestPrint() ); @@ -152,8 +145,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(-2), gtint_t(-1)), // stride size for x - ::testing::Values(scomplex{4.0, 3.1}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(scomplex{4.0, 3.1}) // alpha ), ::cscalvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/scalv/dscalv_generic.cpp b/gtestsuite/testsuite/level1/scalv/dscalv_generic.cpp index 6410481560..b73db053c6 100644 --- a/gtestsuite/testsuite/level1/scalv/dscalv_generic.cpp +++ b/gtestsuite/testsuite/level1/scalv/dscalv_generic.cpp @@ -39,35 +39,32 @@ class dscalvGenericTest : public ::testing::TestWithParam> {}; + double>> {}; // Tests using random integers as vector elements. TEST_P( dscalvGenericTest, RandomData ) { - using T = double; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // denotes whether alpha or conj(alpha) will be used: - char conj_alpha = std::get<0>(GetParam()); - // vector length: - gtint_t n = std::get<1>(GetParam()); - // stride size for x: - gtint_t incx = std::get<2>(GetParam()); - // alpha - T alpha = std::get<3>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<4>(GetParam()); + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). 
+ //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // alpha + T alpha = std::get<3>(GetParam()); - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- - // Call generic test body using those parameters - //---------------------------------------------------------- - test_scalv(conj_alpha, n, incx, alpha, thresh, datatype); + // Set the threshold for the errors: + double thresh = testinghelpers::getEpsilon(); + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv( conj_alpha, n, incx, alpha, thresh ); } // Used to generate a test case with a sensible name. @@ -77,12 +74,11 @@ TEST_P( dscalvGenericTest, RandomData ) class dscalvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); double alpha = std::get<3>(str.param); - char datatype = std::get<4>(str.param); #ifdef TEST_BLAS std::string str_name = "dscal_"; #elif TEST_CBLAS @@ -96,7 +92,6 @@ class dscalvGenericTestPrint { str_name += "_" + incx_str; std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -109,8 +104,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(double(2.0), double(-3.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(double(2.0), double(-3.0)) // alpha ), ::dscalvGenericTestPrint() ); @@ -126,8 +120,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('c'), // c: use conjugate ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(double(-3.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(double(-3.0)) // alpha ), ::dscalvGenericTestPrint() ); @@ -143,8 +136,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use x ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(2), gtint_t(11)), //(gtint_t(-5), gtint_t(-17)) // stride size for x - ::testing::Values(double(3.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(double(3.0)) // alpha ), ::dscalvGenericTestPrint() ); @@ -160,8 +152,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(-2), gtint_t(-1)), // stride size for x - ::testing::Values(3), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(3) // alpha ), ::dscalvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/scalv/scalv_extreme_cases.cpp b/gtestsuite/testsuite/level1/scalv/scalv_extreme_cases.cpp index df350f91b5..9ac6c0d4ed 100644 --- a/gtestsuite/testsuite/level1/scalv/scalv_extreme_cases.cpp +++ b/gtestsuite/testsuite/level1/scalv/scalv_extreme_cases.cpp @@ -46,7 +46,7 @@ TYPED_TEST(xscalv, zero_alpha_x_fp) gtint_t n = 10, incx = 1; std::vector x(n); // Initialize x with random numbers. - testinghelpers::datagenerators::randomgenerators(n, incx, x.data(), 'f'); + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incx, x.data(), BLIS_ELEMENT_TYPE ); std::vector x_ref(x); T alpha = T{0}; @@ -70,7 +70,7 @@ TYPED_TEST(xscalv, zero_alpha_x_inf) gtint_t n = 10, incx = 1; std::vector x(n); // Initialize x with random numbers. - testinghelpers::datagenerators::randomgenerators(n, incx, x.data(), 'f'); + testinghelpers::datagenerators::randomgenerators( -10, 10, n, incx, x.data(), BLIS_ELEMENT_TYPE ); x[3] = 1.0/0.0; std::vector x_ref(x); T alpha = T{0}; diff --git a/gtestsuite/testsuite/level1/scalv/sscalv_generic.cpp b/gtestsuite/testsuite/level1/scalv/sscalv_generic.cpp index 7e37a0e8fc..e00f5effa2 100644 --- a/gtestsuite/testsuite/level1/scalv/sscalv_generic.cpp +++ b/gtestsuite/testsuite/level1/scalv/sscalv_generic.cpp @@ -39,35 +39,32 @@ class sscalvGenericTest : public ::testing::TestWithParam> {}; + float>> {}; // Tests using random integers as vector elements. TEST_P( sscalvGenericTest, RandomData ) { - using T = float; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // denotes whether alpha or conj(alpha) will be used: - char conj_alpha = std::get<0>(GetParam()); - // vector length: - gtint_t n = std::get<1>(GetParam()); - // stride size for x: - gtint_t incx = std::get<2>(GetParam()); - // alpha - T alpha = std::get<3>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<4>(GetParam()); - - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- - // Call generic test body using those parameters - //---------------------------------------------------------- - test_scalv(conj_alpha, n, incx, alpha, thresh, datatype); + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). 
+ //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // alpha + T alpha = std::get<3>(GetParam()); + + // Set the threshold for the errors: + double thresh = testinghelpers::getEpsilon(); + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv( conj_alpha, n, incx, alpha, thresh ); } // Used to generate a test case with a sensible name. @@ -77,12 +74,11 @@ TEST_P( sscalvGenericTest, RandomData ) class sscalvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); float alpha = std::get<3>(str.param); - char datatype = std::get<4>(str.param); #ifdef TEST_BLAS std::string str_name = "sscal_"; #elif TEST_CBLAS @@ -96,7 +92,6 @@ class sscalvGenericTestPrint { str_name += "_" + incx_str; std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -109,8 +104,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use x, not conj(x) (since it is real) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(float(3.0), float(-5.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(float(3.0), float(-5.0)) // alpha ), ::sscalvGenericTestPrint() ); @@ -126,8 +120,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('c'), // c: use conjugate ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(float(9.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(float(9.0)) // alpha ), ::sscalvGenericTestPrint() ); @@ -143,8 +136,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use x ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(2), gtint_t(11)), //(gtint_t(-5), gtint_t(-17)) // stride size for x - ::testing::Values(float(2.0)), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(float(2.0)) // alpha ), ::sscalvGenericTestPrint() ); @@ -161,8 +153,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(-2), gtint_t(-1)), // stride size for x - ::testing::Values(3), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(3) // alpha ), ::sscalvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/scalv/test_scalv.h b/gtestsuite/testsuite/level1/scalv/test_scalv.h index bfe7f9bfde..4c5437d722 100644 --- a/gtestsuite/testsuite/level1/scalv/test_scalv.h +++ b/gtestsuite/testsuite/level1/scalv/test_scalv.h @@ -43,24 +43,24 @@ */ template -static void test_scalv(char conja_alpha, gtint_t n, gtint_t incx, T alpha, double thresh, char datatype) +static void test_scalv( char conja_alpha, gtint_t n, gtint_t incx, T alpha, double thresh ) { //---------------------------------------------------------- // Initialize vector with random numbers. //---------------------------------------------------------- - std::vector x = testinghelpers::get_random_vector(-10, 10, n, incx, datatype); + std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); //---------------------------------------------------------- // Call reference implementation to get ref results. //---------------------------------------------------------- // Create a copy of y so that we can check reference results. std::vector x_ref(x); - testinghelpers::ref_scalv(conja_alpha, n, alpha, x_ref.data(), incx); + testinghelpers::ref_scalv( conja_alpha, n, alpha, x_ref.data(), incx ); //---------------------------------------------------------- // Call BLIS function. //---------------------------------------------------------- - scalv(conja_alpha, n, alpha, x.data(), incx); + scalv( conja_alpha, n, alpha, x.data(), incx ); //---------------------------------------------------------- // Compute component-wise error. diff --git a/gtestsuite/testsuite/level1/scalv/zscalv_generic.cpp b/gtestsuite/testsuite/level1/scalv/zscalv_generic.cpp index 6ddf2489d9..66419cbd4c 100644 --- a/gtestsuite/testsuite/level1/scalv/zscalv_generic.cpp +++ b/gtestsuite/testsuite/level1/scalv/zscalv_generic.cpp @@ -39,35 +39,32 @@ class zscalvGenericTest : public ::testing::TestWithParam> {}; + dcomplex>> {}; // Tests using random integers as vector elements. TEST_P( zscalvGenericTest, RandomData ) { - using T = dcomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). - //---------------------------------------------------------- - // denotes whether alpha or conj(alpha) will be used: - char conj_alpha = std::get<0>(GetParam()); - // vector length: - gtint_t n = std::get<1>(GetParam()); - // stride size for x: - gtint_t incx = std::get<2>(GetParam()); - // alpha - T alpha = std::get<3>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<4>(GetParam()); + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). 
+ //---------------------------------------------------------- + // denotes whether alpha or conj(alpha) will be used: + char conj_alpha = std::get<0>(GetParam()); + // vector length: + gtint_t n = std::get<1>(GetParam()); + // stride size for x: + gtint_t incx = std::get<2>(GetParam()); + // alpha + T alpha = std::get<3>(GetParam()); - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - //---------------------------------------------------------- - // Call generic test body using those parameters - //---------------------------------------------------------- - test_scalv(conj_alpha, n, incx, alpha, thresh, datatype); + // Set the threshold for the errors: + double thresh = testinghelpers::getEpsilon(); + //---------------------------------------------------------- + // Call generic test body using those parameters + //---------------------------------------------------------- + test_scalv( conj_alpha, n, incx, alpha, thresh ); } // Used to generate a test case with a sensible name. @@ -77,12 +74,11 @@ TEST_P( zscalvGenericTest, RandomData ) class zscalvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); dcomplex alpha = std::get<3>(str.param); - char datatype = std::get<4>(str.param); #ifdef TEST_BLAS std::string str_name = "zscal_"; #elif TEST_CBLAS @@ -97,7 +93,6 @@ class zscalvGenericTestPrint { std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -114,8 +109,7 @@ INSTANTIATE_TEST_SUITE_P( ), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(dcomplex{3.0, -2.0}, dcomplex{-1.0, 4.0}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(dcomplex{3.0, -2.0}, dcomplex{-1.0, 4.0}) // alpha ), ::zscalvGenericTestPrint() ); @@ -135,8 +129,7 @@ INSTANTIATE_TEST_SUITE_P( ), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(2), gtint_t(11)), //(gtint_t(-5), gtint_t(-17)) // stride size for x - ::testing::Values(dcomplex{1.0, 2.1}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(dcomplex{1.0, 2.1}) // alpha ), ::zscalvGenericTestPrint() ); @@ -152,8 +145,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: use x, c: use conj(x) ::testing::Range(gtint_t(10), gtint_t(31), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(-2), gtint_t(-1)), // stride size for x - ::testing::Values(dcomplex{4.0, 3.1}), // alpha - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(dcomplex{4.0, 3.1}) // alpha ), ::zscalvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/setv/test_setv.h b/gtestsuite/testsuite/level1/setv/test_setv.h index 09bd121f6e..da98788ecc 100644 --- a/gtestsuite/testsuite/level1/setv/test_setv.h +++ b/gtestsuite/testsuite/level1/setv/test_setv.h @@ -43,7 +43,8 @@ */ template -void test_setv( char conjalpha, gtint_t n, T alpha, gtint_t incx ) { +void test_setv( char conjalpha, gtint_t n, T alpha, gtint_t incx ) +{ //---------------------------------------------------------- // Initialize vectors with random numbers. //---------------------------------------------------------- @@ -60,7 +61,7 @@ void test_setv( char conjalpha, gtint_t n, T alpha, gtint_t incx ) { //---------------------------------------------------------- // Call BLIS function. //---------------------------------------------------------- - setv( conjalpha, n, &alpha, x.data(), incx ); + setv( conjalpha, n, &alpha, x.data(), incx ); //---------------------------------------------------------- // Compute component-wise error. diff --git a/gtestsuite/testsuite/level1/subv/csubv_generic.cpp b/gtestsuite/testsuite/level1/subv/csubv_generic.cpp index 7b98a8ebfb..70797d5e5a 100644 --- a/gtestsuite/testsuite/level1/subv/csubv_generic.cpp +++ b/gtestsuite/testsuite/level1/subv/csubv_generic.cpp @@ -36,7 +36,7 @@ #include "test_subv.h" class csubvGenericTest : - public ::testing::TestWithParam> {}; + public ::testing::TestWithParam> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(csubvGenericTest); @@ -55,8 +55,6 @@ TEST_P( csubvGenericTest, RandomData ) gtint_t incx = std::get<2>(GetParam()); // stride size for y: gtint_t incy = std::get<3>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<4>(GetParam()); // Set the threshold for the errors: double thresh = testinghelpers::getEpsilon(); @@ -64,19 +62,18 @@ TEST_P( csubvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_subv(conj_x, n, incx, incy, thresh, datatype); + test_subv( conj_x, n, incx, incy, thresh ); } // Prints the test case combination class csubvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); - char datatype = std::get<4>(str.param); std::string str_name = "bli_csubv"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conj, 1); @@ -84,7 +81,6 @@ class csubvGenericTestPrint { str_name += "_" + incx_str; std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); str_name += "_" + incy_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -98,8 +94,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n','c'), // n: not transpose for x, c: conjugate for x ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(1), gtint_t(4)), // stride size for x - ::testing::Values(gtint_t(1), gtint_t(7)), // stride size for y - ::testing::Values(ELEMENT_TYPE,'f') // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1), gtint_t(7)) // stride size for y ), ::csubvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/subv/dsubv_generic.cpp b/gtestsuite/testsuite/level1/subv/dsubv_generic.cpp index 9b31bcb102..63a63a9274 100644 --- a/gtestsuite/testsuite/level1/subv/dsubv_generic.cpp +++ b/gtestsuite/testsuite/level1/subv/dsubv_generic.cpp @@ -36,7 +36,7 @@ #include "test_subv.h" class dsubvGenericTest : - public ::testing::TestWithParam> {}; + public ::testing::TestWithParam> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dsubvGenericTest); @@ -55,8 +55,6 @@ TEST_P( dsubvGenericTest, RandomData ) gtint_t incx = std::get<2>(GetParam()); // stride size for y: gtint_t incy = std::get<3>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<4>(GetParam()); // Set the threshold for the errors: double thresh = testinghelpers::getEpsilon(); @@ -64,19 +62,18 @@ TEST_P( dsubvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_subv(conj_x, n, incx, incy, thresh, datatype); + test_subv( conj_x, n, incx, incy, thresh ); } // Prints the test case combination class dsubvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); - char datatype = std::get<4>(str.param); std::string str_name = "bli_dsubv"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conj, 1); @@ -84,7 +81,6 @@ class dsubvGenericTestPrint { str_name += "_" + incx_str; std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); str_name += "_" + incy_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -98,8 +94,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: not transpose for x ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(1), gtint_t(4)), // stride size for x - ::testing::Values(gtint_t(1), gtint_t(7)), // stride size for y - ::testing::Values(ELEMENT_TYPE,'f') // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1), gtint_t(7)) // stride size for y ), ::dsubvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/subv/ssubv_generic.cpp b/gtestsuite/testsuite/level1/subv/ssubv_generic.cpp index 4d96efc4e1..50e004cb07 100644 --- a/gtestsuite/testsuite/level1/subv/ssubv_generic.cpp +++ b/gtestsuite/testsuite/level1/subv/ssubv_generic.cpp @@ -36,7 +36,7 @@ #include "test_subv.h" class ssubvGenericTest : - public ::testing::TestWithParam> {}; + public ::testing::TestWithParam> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ssubvGenericTest); @@ -55,8 +55,6 @@ TEST_P( ssubvGenericTest, RandomData ) gtint_t incx = std::get<2>(GetParam()); // stride size for y: gtint_t incy = std::get<3>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<4>(GetParam()); // Set the threshold for the errors: double thresh = testinghelpers::getEpsilon(); @@ -64,19 +62,18 @@ TEST_P( ssubvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_subv(conj_x, n, incx, incy, thresh, datatype); + test_subv( conj_x, n, incx, incy, thresh ); } // Prints the test case combination class ssubvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); - char datatype = std::get<4>(str.param); std::string str_name = "bli_ssubv"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conj, 1); @@ -84,7 +81,6 @@ class ssubvGenericTestPrint { str_name += "_" + incx_str; std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); str_name += "_" + incy_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -98,8 +94,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n'), // n: not transpose for x ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1), gtint_t(4)), // stride size for x - ::testing::Values(gtint_t(1), gtint_t(7)), // stride size for y - ::testing::Values(ELEMENT_TYPE,'f') // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1), gtint_t(7)) // stride size for y ), ::ssubvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/subv/test_subv.h b/gtestsuite/testsuite/level1/subv/test_subv.h index 9406823bd3..ffdf86a3db 100644 --- a/gtestsuite/testsuite/level1/subv/test_subv.h +++ b/gtestsuite/testsuite/level1/subv/test_subv.h @@ -43,29 +43,28 @@ */ template -void test_subv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, - double thresh, char datatype ) { +void test_subv( char conjx, gtint_t n, gtint_t incx, gtint_t incy, double thresh ) +{ //---------------------------------------------------------- // Initialize vectors with random numbers. 
//---------------------------------------------------------- - std::vector x = testinghelpers::get_random_vector(-10, 10, n, incx, datatype); - std::vector y = testinghelpers::get_random_vector(-10, 10, n, incy, datatype); + std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); + std::vector y = testinghelpers::get_random_vector( -10, 10, n, incy ); //---------------------------------------------------------- // Call reference implementation to get ref results. //---------------------------------------------------------- // Create a copy of y so that we can check reference results. std::vector y_ref(y); - testinghelpers::ref_subv(conjx, n, x.data(), incx, y_ref.data(), incy); + testinghelpers::ref_subv( conjx, n, x.data(), incx, y_ref.data(), incy ); //---------------------------------------------------------- // Call BLIS function. //---------------------------------------------------------- - subv(conjx, n, x.data(), incx, y.data(), incy); + subv( conjx, n, x.data(), incx, y.data(), incy ); //---------------------------------------------------------- // Compute component-wise error. //---------------------------------------------------------- computediff( n, y.data(), y_ref.data(), incy, thresh ); - } diff --git a/gtestsuite/testsuite/level1/subv/zsubv_generic.cpp b/gtestsuite/testsuite/level1/subv/zsubv_generic.cpp index 2fa7236e64..f4e634f4c5 100644 --- a/gtestsuite/testsuite/level1/subv/zsubv_generic.cpp +++ b/gtestsuite/testsuite/level1/subv/zsubv_generic.cpp @@ -36,7 +36,7 @@ #include "test_subv.h" class zsubvGenericTest : - public ::testing::TestWithParam> {}; + public ::testing::TestWithParam> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zsubvGenericTest); @@ -55,8 +55,6 @@ TEST_P( zsubvGenericTest, RandomData ) gtint_t incx = std::get<2>(GetParam()); // stride size for y: gtint_t incy = std::get<3>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<4>(GetParam()); // Set the threshold for the errors: double thresh = testinghelpers::getEpsilon(); @@ -64,19 +62,18 @@ TEST_P( zsubvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_subv(conj_x, n, incx, incy, thresh, datatype); + test_subv( conj_x, n, incx, incy, thresh ); } // Prints the test case combination class zsubvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); - char datatype = std::get<4>(str.param); std::string str_name = "bli_zsubv"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conj, 1); @@ -84,7 +81,6 @@ class zsubvGenericTestPrint { str_name += "_" + incx_str; std::string incy_str = ( incy > 0) ? std::to_string(incy) : "m" + std::to_string(std::abs(incy)); str_name += "_" + incy_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -98,8 +94,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values('n','c'), // n: not transpose for x, c: conjugate for x ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(1), gtint_t(4)), // stride size for x - ::testing::Values(gtint_t(1), gtint_t(7)), // stride size for y - ::testing::Values(ELEMENT_TYPE,'f') // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(1), gtint_t(7)) // stride size for y ), ::zsubvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/xpbyv/cxpbyv_generic.cpp b/gtestsuite/testsuite/level1/xpbyv/cxpbyv_generic.cpp index 7af0647138..6fb81b92aa 100644 --- a/gtestsuite/testsuite/level1/xpbyv/cxpbyv_generic.cpp +++ b/gtestsuite/testsuite/level1/xpbyv/cxpbyv_generic.cpp @@ -40,8 +40,7 @@ class cxpbyvGenericTest : gtint_t, gtint_t, gtint_t, - scomplex, - char>> {}; + scomplex>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cxpbyvGenericTest); @@ -63,15 +62,13 @@ TEST_P( cxpbyvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // beta T beta = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); // Set the threshold for the errors: double thresh = 2*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_xpbyv(conj_x, n, incx, incy, beta, thresh, datatype); + test_xpbyv( conj_x, n, incx, incy, beta, thresh ); } // Used to generate a test case with a sensible name. @@ -81,13 +78,12 @@ TEST_P( cxpbyvGenericTest, RandomData ) class cxpbyvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); scomplex beta = std::get<4>(str.param); - char datatype = std::get<5>(str.param); std::string str_name = "bli_cxpbyv"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conj, 1); @@ -98,7 +94,6 @@ class cxpbyvGenericTestPrint { std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -113,8 +108,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}) // beta ), ::cxpbyvGenericTestPrint() ); @@ -130,8 +124,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(2), gtint_t(11)), /*(gtint_t(-5), gtint_t(-17))*/ // stride size for x ::testing::Values(gtint_t(3), gtint_t(33)), /*(gtint_t(-12), gtint_t(-4))*/ // stride size for y - ::testing::Values(scomplex{4.0, 3.1}), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(scomplex{4.0, 3.1}) // beta ), ::cxpbyvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/xpbyv/dxpbyv_generic.cpp b/gtestsuite/testsuite/level1/xpbyv/dxpbyv_generic.cpp index 15e06808c0..079867f1f4 100644 --- a/gtestsuite/testsuite/level1/xpbyv/dxpbyv_generic.cpp +++ b/gtestsuite/testsuite/level1/xpbyv/dxpbyv_generic.cpp @@ -40,8 +40,7 @@ class dxpbyvGenericTest : gtint_t, gtint_t, gtint_t, - double, - char>> {}; + double>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dxpbyvGenericTest); @@ -63,8 +62,6 @@ TEST_P( dxpbyvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // beta T beta = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); // Set the threshold for the errors: double thresh = 2*testinghelpers::getEpsilon(); @@ -72,7 +69,7 @@ TEST_P( dxpbyvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_xpbyv(conj_x, n, incx, incy, beta, thresh, datatype); + test_xpbyv( conj_x, n, incx, incy, beta, thresh ); } // Used to generate a test case with a sensible name. @@ -82,13 +79,12 @@ TEST_P( dxpbyvGenericTest, RandomData ) class dxpbyvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); double beta = std::get<4>(str.param); - char datatype = std::get<5>(str.param); std::string str_name = "bli_dxpbyv"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conj, 1); @@ -98,7 +94,6 @@ class dxpbyvGenericTestPrint { str_name += "_" + incy_str; std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -113,8 +108,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(double(2.0), double(-2.0)), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(double(2.0), double(-2.0)) // beta ), ::dxpbyvGenericTestPrint() ); @@ -131,8 +125,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(double(2.0)), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(double(2.0)) // beta ), ::dxpbyvGenericTestPrint() ); @@ -149,8 +142,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)), // m size of vector ::testing::Values(gtint_t(2), gtint_t(11)), /*(gtint_t(-5), gtint_t(-17))*/// stride size for x ::testing::Values(gtint_t(3), gtint_t(33)), /*(gtint_t(-12), gtint_t(-4))*/// stride size for y - ::testing::Values(double(4.0)), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(double(4.0)) // beta ), ::dxpbyvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level1/xpbyv/sxpbyv_generic.cpp b/gtestsuite/testsuite/level1/xpbyv/sxpbyv_generic.cpp index b424025ce7..fe33a81cb8 100644 --- a/gtestsuite/testsuite/level1/xpbyv/sxpbyv_generic.cpp +++ b/gtestsuite/testsuite/level1/xpbyv/sxpbyv_generic.cpp @@ -40,8 +40,7 @@ class sxpbyvGenericTest : gtint_t, gtint_t, gtint_t, - float, - char>> {}; + float>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sxpbyvGenericTest); @@ -63,8 +62,6 @@ TEST_P( sxpbyvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // beta T beta = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); // Set the threshold for the errors: float thresh = 2*testinghelpers::getEpsilon(); @@ -72,7 +69,7 @@ TEST_P( sxpbyvGenericTest, RandomData ) //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_xpbyv(conj_x, n, incx, incy, beta, thresh, datatype); + test_xpbyv( conj_x, n, incx, incy, beta, thresh ); } // Used to generate a test case with a sensible name. @@ -82,13 +79,12 @@ TEST_P( sxpbyvGenericTest, RandomData ) class sxpbyvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); float beta = std::get<4>(str.param); - char datatype = std::get<5>(str.param); std::string str_name = "bli_sxpbyv"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conj, 1); @@ -98,7 +94,6 @@ class sxpbyvGenericTestPrint { str_name += "_" + incy_str; std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -113,8 +108,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
       ::testing::Values(gtint_t(1)),                                   // stride size for x
       ::testing::Values(gtint_t(1)),                                   // stride size for y
-      ::testing::Values(float(2.0), float(-2.0)),                      // beta
-      ::testing::Values(ELEMENT_TYPE)                                  // i : integer, f : float datatype type tested
+      ::testing::Values(float(2.0), float(-2.0))                       // beta
       ),
       ::sxpbyvGenericTestPrint()
   );
@@ -130,8 +124,7 @@ INSTANTIATE_TEST_SUITE_P(
       ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)),        // m size of vector
       ::testing::Values(gtint_t(1)),                                   // stride size for x
       ::testing::Values(gtint_t(1)),                                   // stride size for y
-      ::testing::Values(float(2.0)),                                   // beta
-      ::testing::Values(ELEMENT_TYPE)                                  // i : integer, f : float datatype type tested
+      ::testing::Values(float(2.0))                                    // beta
       ),
       ::sxpbyvGenericTestPrint()
   );
@@ -148,8 +141,7 @@ INSTANTIATE_TEST_SUITE_P(
       ::testing::Values(gtint_t(3), gtint_t(30), gtint_t(112)),        // m size of vector
       ::testing::Values(gtint_t(2), gtint_t(11)), /*(gtint_t(-5), gtint_t(-17))*/// stride size for x
       ::testing::Values(gtint_t(3), gtint_t(33)), /*(gtint_t(-12), gtint_t(-4))*/// stride size for y
-      ::testing::Values(float(4.0)),                                   // beta
-      ::testing::Values(ELEMENT_TYPE)                                  // i : integer, f : float datatype type tested
+      ::testing::Values(float(4.0))                                    // beta
       ),
       ::sxpbyvGenericTestPrint()
   );
diff --git a/gtestsuite/testsuite/level1/xpbyv/test_xpbyv.h b/gtestsuite/testsuite/level1/xpbyv/test_xpbyv.h
index 46af04c30e..1694c2149d 100644
--- a/gtestsuite/testsuite/level1/xpbyv/test_xpbyv.h
+++ b/gtestsuite/testsuite/level1/xpbyv/test_xpbyv.h
@@ -43,26 +43,26 @@
 */

template<typename T>
-static void test_xpbyv(char conjx, gtint_t n, gtint_t incx, gtint_t incy,
-                       T beta, double thresh, char datatype ) {
-
+static void test_xpbyv( char conjx, gtint_t n, gtint_t incx, gtint_t incy,
+                        T beta, double thresh )
+{
    //----------------------------------------------------------
    // Initialize vectors with random numbers.
    //----------------------------------------------------------
-   std::vector<T> x = testinghelpers::get_random_vector<T>(-10, 10, n, incx, datatype);
-   std::vector<T> y = testinghelpers::get_random_vector<T>(-10, 10, n, incy, datatype);
+   std::vector<T> x = testinghelpers::get_random_vector<T>( -10, 10, n, incx );
+   std::vector<T> y = testinghelpers::get_random_vector<T>( -10, 10, n, incy );

    //----------------------------------------------------------
    // Call reference implementation to get ref results.
    //----------------------------------------------------------
    // Create a copy of y so that we can check reference results.
    std::vector<T> y_ref(y);
-   testinghelpers::ref_xpbyv(conjx, n, x.data(), incx, beta, y_ref.data(), incy);
+   testinghelpers::ref_xpbyv( conjx, n, x.data(), incx, beta, y_ref.data(), incy );

    //----------------------------------------------------------
    // Call BLIS function.
    //----------------------------------------------------------
-   xpbyv(conjx, n, x.data(), incx, beta, y.data(), incy);
+   xpbyv( conjx, n, x.data(), incx, beta, y.data(), incy );

    //----------------------------------------------------------
    // Compute component-wise error.
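The hunks above all apply the same refactor to the GoogleTest value-parameterized suites: the trailing char datatype element is dropped from the std::tuple parameter pack, from the GetParam() unpacking, from the *GenericTestPrint name generators, and from the ::testing::Combine(...) lists, and the test_* helpers lose their char datatype argument. The standalone sketch below shows the resulting shape of such a suite after the refactor; it is only an illustration, not BLIS code. The names demo_xpbyv, DemoXpbyTest and DemoNamePrint are hypothetical, and only <gtest/gtest.h> (linked against gtest_main) plus the standard library are assumed.

#include <gtest/gtest.h>
#include <cmath>
#include <string>
#include <tuple>
#include <vector>

// Hypothetical stand-in for the operation under test: y := x + beta*y, unit strides only.
static void demo_xpbyv( int n, const float* x, float beta, float* y )
{
    for ( int i = 0; i < n; ++i ) y[i] = x[i] + beta * y[i];
}

// The tuple now carries only ( conj_x, n, incx, incy, beta ) -- no datatype char.
class DemoXpbyTest : public ::testing::TestWithParam<std::tuple<char, int, int, int, float>> {};

TEST_P( DemoXpbyTest, MatchesReference )
{
    int   n    = std::get<1>( GetParam() );
    float beta = std::get<4>( GetParam() );
    std::vector<float> x( n, 1.0f ), y( n, 2.0f ), y_ref( y );
    // Compute the reference result directly, then compare against the routine under test.
    for ( int i = 0; i < n; ++i ) y_ref[i] = x[i] + beta * y_ref[i];
    demo_xpbyv( n, x.data(), beta, y.data() );
    for ( int i = 0; i < n; ++i ) EXPECT_FLOAT_EQ( y[i], y_ref[i] );
}

// Name generator in the style of the *GenericTestPrint classes, minus the datatype suffix.
class DemoNamePrint {
  public:
    std::string operator()(
        testing::TestParamInfo<std::tuple<char, int, int, int, float>> str ) const
    {
        char  conj = std::get<0>( str.param );
        int   n    = std::get<1>( str.param );
        float beta = std::get<4>( str.param );
        std::string name = "demo_xpbyv_" + std::to_string( n ) + "_" + std::string( &conj, 1 );
        name += ( beta > 0 ) ? "_b" + std::to_string( int(beta) )
                             : "_bm" + std::to_string( int(std::abs(beta)) );
        return name;
    }
};

INSTANTIATE_TEST_SUITE_P(
    UnitStride,
    DemoXpbyTest,
    ::testing::Combine(
        ::testing::Values( 'n' ),          // conj_x (unused by the stand-in routine)
        ::testing::Values( 8, 16 ),        // vector length n
        ::testing::Values( 1 ),            // incx (unit stride only in this sketch)
        ::testing::Values( 1 ),            // incy (unit stride only in this sketch)
        ::testing::Values( 2.0f, -2.0f )   // beta
    ),
    ::DemoNamePrint()
);

With the datatype element gone, the generated test names (e.g. UnitStride/DemoXpbyTest.MatchesReference/demo_xpbyv_8_n_b2) no longer carry a trailing datatype suffix, which matches the shortened names produced by the updated *GenericTestPrint classes above.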
diff --git a/gtestsuite/testsuite/level1/xpbyv/zxpbyv_generic.cpp b/gtestsuite/testsuite/level1/xpbyv/zxpbyv_generic.cpp index cea3e8a086..04b781da8c 100644 --- a/gtestsuite/testsuite/level1/xpbyv/zxpbyv_generic.cpp +++ b/gtestsuite/testsuite/level1/xpbyv/zxpbyv_generic.cpp @@ -40,8 +40,7 @@ class zxpbyvGenericTest : gtint_t, gtint_t, gtint_t, - dcomplex, - char>> {}; + dcomplex>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zxpbyvGenericTest); @@ -63,15 +62,13 @@ TEST_P( zxpbyvGenericTest, RandomData ) gtint_t incy = std::get<3>(GetParam()); // beta T beta = std::get<4>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<5>(GetParam()); // Set the threshold for the errors: double thresh = 2*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call generic test body using those parameters //---------------------------------------------------------- - test_xpbyv(conj_x, n, incx, incy, beta, thresh, datatype); + test_xpbyv( conj_x, n, incx, incy, beta, thresh ); } // Used to generate a test case with a sensible name. @@ -81,13 +78,12 @@ TEST_P( zxpbyvGenericTest, RandomData ) class zxpbyvGenericTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char conj = std::get<0>(str.param); gtint_t n = std::get<1>(str.param); gtint_t incx = std::get<2>(str.param); gtint_t incy = std::get<3>(str.param); dcomplex beta = std::get<4>(str.param); - char datatype = std::get<5>(str.param); std::string str_name = "bli_zxpbyv"; str_name += "_" + std::to_string(n); str_name += "_" + std::string(&conj, 1); @@ -98,7 +94,6 @@ class zxpbyvGenericTestPrint { std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); beta_str = beta_str + "pi" + (( beta.imag > 0) ? std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); str_name = str_name + "_b" + beta_str; - str_name = str_name + "_" + datatype; return str_name; } }; @@ -113,8 +108,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. ::testing::Values(gtint_t(1)), /*(gtint_t(-5), gtint_t(-17))*/ // stride size for x ::testing::Values(gtint_t(1)), /*(gtint_t(-12), gtint_t(-4))*/ // stride size for y - ::testing::Values(dcomplex{2.0, -1.0}, dcomplex{-2.0, 3.0}), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(dcomplex{2.0, -1.0}, dcomplex{-2.0, 3.0}) // beta ), ::zxpbyvGenericTestPrint() ); @@ -130,8 +124,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
::testing::Values(gtint_t(2), gtint_t(11)), /*(gtint_t(-5), gtint_t(-17))*/ // stride size for x ::testing::Values(gtint_t(3), gtint_t(33)), /*(gtint_t(-12), gtint_t(-4))*/ // stride size for y - ::testing::Values(dcomplex{4.0, 3.1}), // beta - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(dcomplex{4.0, 3.1}) // beta ), ::zxpbyvGenericTestPrint() ); diff --git a/gtestsuite/testsuite/level2/gemv/cgemv_generic.cpp b/gtestsuite/testsuite/level2/gemv/cgemv_generic.cpp index 8c0cb5200a..8ba1f7a429 100644 --- a/gtestsuite/testsuite/level2/gemv/cgemv_generic.cpp +++ b/gtestsuite/testsuite/level2/gemv/cgemv_generic.cpp @@ -45,12 +45,11 @@ class cgemvTest : scomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(cgemvTest, RandomData) { +TEST_P(cgemvTest, RandomData) +{ using T = scomplex; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -77,22 +76,20 @@ TEST_P(cgemvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = 2*std::max(m,n)*testinghelpers::getEpsilon(); + double thresh = 2*(std::max)(m,n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_gemv(storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh, datatype); + test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh ); } class cgemvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char transa = std::get<1>(str.param); char conjx = std::get<2>(str.param); @@ -103,13 +100,12 @@ class cgemvTestPrint { gtint_t incx = std::get<7>(str.param); gtint_t incy = std::get<8>(str.param); gtint_t ld_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "cgemv_"; #elif TEST_CBLAS std::string str_name = "cblas_cgemv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_cgemv"; + std::string str_name = "bli_cgemv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + transa+conjx; @@ -126,7 +122,6 @@ class cgemvTestPrint { str_name = str_name + "_a" + alpha_str; str_name = str_name + "_b" + beta_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -149,8 +144,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(scomplex{-1.0, 1.0}), // beta ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0)) // increment to the leading dim of a ), ::cgemvTestPrint() ); diff --git a/gtestsuite/testsuite/level2/gemv/dgemv_generic.cpp b/gtestsuite/testsuite/level2/gemv/dgemv_generic.cpp index 4fc91b1f46..33cc9fa57b 100644 --- a/gtestsuite/testsuite/level2/gemv/dgemv_generic.cpp +++ 
b/gtestsuite/testsuite/level2/gemv/dgemv_generic.cpp @@ -45,12 +45,11 @@ class dgemvTest : double, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(dgemvTest, RandomData) { +TEST_P(dgemvTest, RandomData) +{ using T = double; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -77,22 +76,20 @@ TEST_P(dgemvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = 2*std::max(m,n)*testinghelpers::getEpsilon(); + double thresh = 2*(std::max)(m,n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_gemv(storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh, datatype); + test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh ); } class dgemvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char transa = std::get<1>(str.param); char conjx = std::get<2>(str.param); @@ -103,13 +100,12 @@ class dgemvTestPrint { gtint_t incx = std::get<7>(str.param); gtint_t incy = std::get<8>(str.param); gtint_t ld_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "dgemv_"; #elif TEST_CBLAS std::string str_name = "cblas_dgemv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_dgemv"; + std::string str_name = "bli_dgemv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + transa+conjx; @@ -124,7 +120,6 @@ class dgemvTestPrint { str_name = str_name + "_a" + alpha_str; str_name = str_name + "_b" + beta_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -147,8 +142,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(-1.0 ), // beta ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0)) // increment to the leading dim of a ), ::dgemvTestPrint() ); diff --git a/gtestsuite/testsuite/level2/gemv/sgemv_generic.cpp b/gtestsuite/testsuite/level2/gemv/sgemv_generic.cpp index a6906559eb..ec726ff56b 100644 --- a/gtestsuite/testsuite/level2/gemv/sgemv_generic.cpp +++ b/gtestsuite/testsuite/level2/gemv/sgemv_generic.cpp @@ -45,12 +45,11 @@ class sgemvTest : float, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(sgemvTest, RandomData) { +TEST_P(sgemvTest, RandomData) +{ using T = float; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -77,22 +76,20 @@ TEST_P(sgemvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. 
gtint_t lda_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = 2*std::max(m,n)*testinghelpers::getEpsilon(); + double thresh = 2*(std::max)(m,n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_gemv(storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh, datatype); + test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh ); } class sgemvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char transa = std::get<1>(str.param); char conjx = std::get<2>(str.param); @@ -103,13 +100,12 @@ class sgemvTestPrint { gtint_t incx = std::get<7>(str.param); gtint_t incy = std::get<8>(str.param); gtint_t ld_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "sgemv_"; #elif TEST_CBLAS std::string str_name = "cblas_sgemv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_sgemv"; + std::string str_name = "bli_sgemv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + transa+conjx; @@ -124,7 +120,6 @@ class sgemvTestPrint { str_name = str_name + "_a" + alpha_str; str_name = str_name + "_b" + beta_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -147,8 +142,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(-1.0 ), // beta ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0)) // increment to the leading dim of a ), ::sgemvTestPrint() ); diff --git a/gtestsuite/testsuite/level2/gemv/test_gemv.h b/gtestsuite/testsuite/level2/gemv/test_gemv.h index 7d3dfc14d6..76f8970294 100644 --- a/gtestsuite/testsuite/level2/gemv/test_gemv.h +++ b/gtestsuite/testsuite/level2/gemv/test_gemv.h @@ -43,11 +43,10 @@ template void test_gemv( char storage, char trnsa, char conjx, gtint_t m, gtint_t n, - T alpha, gtint_t lda_inc, gtint_t incx, T beta, gtint_t incy, - double thresh, char datatype ) { - + T alpha, gtint_t lda_inc, gtint_t incx, T beta, gtint_t incy, double thresh ) +{ // Compute the leading dimensions for matrix size calculation. - gtint_t lda = testinghelpers::get_leading_dimension(storage, 'n', m, n, lda_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc ); // Get correct vector lengths. gtint_t lenx = ( testinghelpers::chknotrans( trnsa ) ) ? n : m ; @@ -56,22 +55,22 @@ void test_gemv( char storage, char trnsa, char conjx, gtint_t m, gtint_t n, //---------------------------------------------------------- // Initialize matrics with random integer numbers. 
//---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(1, 5, storage, 'n', m, n, lda, datatype); - std::vector x = testinghelpers::get_random_vector(1, 3, lenx, incx, datatype); - std::vector y = testinghelpers::get_random_vector(1, 3, leny, incy, datatype); + std::vector a = testinghelpers::get_random_matrix( 1, 5, storage, 'n', m, n, lda ); + std::vector x = testinghelpers::get_random_vector( 1, 3, lenx, incx ); + std::vector y = testinghelpers::get_random_vector( 1, 3, leny, incy ); // Create a copy of c so that we can check reference results. std::vector y_ref(y); //---------------------------------------------------------- // Call BLIS function //---------------------------------------------------------- - gemv( storage, trnsa, conjx, m, n, &alpha, a.data(), lda, + gemv( storage, trnsa, conjx, m, n, &alpha, a.data(), lda, x.data(), incx, &beta, y.data(), incy ); //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_gemv( storage, trnsa, conjx, m, n, alpha, a.data(), + testinghelpers::ref_gemv( storage, trnsa, conjx, m, n, alpha, a.data(), lda, x.data(), incx, beta, y_ref.data(), incy ); //---------------------------------------------------------- diff --git a/gtestsuite/testsuite/level2/gemv/zgemv_generic.cpp b/gtestsuite/testsuite/level2/gemv/zgemv_generic.cpp index 74d95b5b13..8c27717111 100644 --- a/gtestsuite/testsuite/level2/gemv/zgemv_generic.cpp +++ b/gtestsuite/testsuite/level2/gemv/zgemv_generic.cpp @@ -45,12 +45,11 @@ class zgemvTest : dcomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(zgemvTest, RandomData) { +TEST_P(zgemvTest, RandomData) +{ using T = dcomplex; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -77,22 +76,20 @@ TEST_P(zgemvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. 
gtint_t lda_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = 2*std::max(m,n)*testinghelpers::getEpsilon(); + double thresh = 2*(std::max)(m,n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_gemv(storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh, datatype); + test_gemv( storage, transa, conjx, m, n, alpha, lda_inc, incx, beta, incy, thresh ); } class zgemvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char transa = std::get<1>(str.param); char conjx = std::get<2>(str.param); @@ -103,13 +100,12 @@ class zgemvTestPrint { gtint_t incx = std::get<7>(str.param); gtint_t incy = std::get<8>(str.param); gtint_t ld_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "zgemv_"; #elif TEST_CBLAS std::string str_name = "cblas_zgemv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_zgemv"; + std::string str_name = "bli_zgemv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + transa+conjx; @@ -126,7 +122,6 @@ class zgemvTestPrint { str_name = str_name + "_a" + alpha_str; str_name = str_name + "_b" + beta_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -149,8 +144,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{-1.0, 1.0}), // beta ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0)) // increment to the leading dim of a ), ::zgemvTestPrint() ); diff --git a/gtestsuite/testsuite/level2/ger/cger_generic.cpp b/gtestsuite/testsuite/level2/ger/cger_generic.cpp index 7dcd4fea70..024ac6d4da 100644 --- a/gtestsuite/testsuite/level2/ger/cger_generic.cpp +++ b/gtestsuite/testsuite/level2/ger/cger_generic.cpp @@ -44,12 +44,11 @@ class cgerTest : scomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(cgerTest, RandomData) { +TEST_P(cgerTest, RandomData) +{ using T = scomplex; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -74,22 +73,20 @@ TEST_P(cgerTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. 
gtint_t lda_inc = std::get<8>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<9>(GetParam()); // Set the threshold for the errors: - double thresh = 2*std::max(m,n)*testinghelpers::getEpsilon(); + double thresh = 2*(std::max)(m,n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_ger(storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh, datatype); + test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh ); } class cgerTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char conjx = std::get<1>(str.param); char conjy = std::get<2>(str.param); @@ -99,13 +96,12 @@ class cgerTestPrint { gtint_t incx = std::get<6>(str.param); gtint_t incy = std::get<7>(str.param); gtint_t ld_inc = std::get<8>(str.param); - char datatype = std::get<9>(str.param); #ifdef TEST_BLAS std::string str_name = "cger_"; #elif TEST_CBLAS std::string str_name = "cblas_cger"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_cger"; + std::string str_name = "bli_cger"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + conjx+conjy; @@ -119,7 +115,6 @@ class cgerTestPrint { alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -141,8 +136,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(scomplex{1.0, -2.0}), // alpha ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of a ), ::cgerTestPrint() ); diff --git a/gtestsuite/testsuite/level2/ger/dger_generic.cpp b/gtestsuite/testsuite/level2/ger/dger_generic.cpp index 043a165407..1fd5efa4f2 100644 --- a/gtestsuite/testsuite/level2/ger/dger_generic.cpp +++ b/gtestsuite/testsuite/level2/ger/dger_generic.cpp @@ -44,12 +44,11 @@ class dgerTest : double, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(dgerTest, RandomData) { +TEST_P(dgerTest, RandomData) +{ using T = double; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -74,22 +73,20 @@ TEST_P(dgerTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. 
gtint_t lda_inc = std::get<8>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<9>(GetParam()); // Set the threshold for the errors: - double thresh = 2*std::max(m,n)*testinghelpers::getEpsilon(); + double thresh = 2*(std::max)(m,n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_ger(storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh, datatype); + test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh ); } class dgerTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char conjx = std::get<1>(str.param); char conjy = std::get<2>(str.param); @@ -99,13 +96,12 @@ class dgerTestPrint { gtint_t incx = std::get<6>(str.param); gtint_t incy = std::get<7>(str.param); gtint_t ld_inc = std::get<8>(str.param); - char datatype = std::get<9>(str.param); #ifdef TEST_BLAS std::string str_name = "dger_"; #elif TEST_CBLAS std::string str_name = "cblas_dger"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_dger"; + std::string str_name = "bli_dger"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + conjx+conjy; @@ -118,7 +114,6 @@ class dgerTestPrint { std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -140,8 +135,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values( 1.0 ), // alpha ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of a ), ::dgerTestPrint() ); diff --git a/gtestsuite/testsuite/level2/ger/sger_generic.cpp b/gtestsuite/testsuite/level2/ger/sger_generic.cpp index 113dee0342..37c832759d 100644 --- a/gtestsuite/testsuite/level2/ger/sger_generic.cpp +++ b/gtestsuite/testsuite/level2/ger/sger_generic.cpp @@ -44,12 +44,11 @@ class sgerTest : float, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(sgerTest, RandomData) { +TEST_P(sgerTest, RandomData) +{ using T = float; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -74,22 +73,20 @@ TEST_P(sgerTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. 
gtint_t lda_inc = std::get<8>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<9>(GetParam()); // Set the threshold for the errors: - double thresh = 4*std::max(m,n)*testinghelpers::getEpsilon(); + double thresh = 4*(std::max)(m,n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_ger(storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh, datatype); + test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh ); } class sgerTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char conjx = std::get<1>(str.param); char conjy = std::get<2>(str.param); @@ -99,13 +96,12 @@ class sgerTestPrint { gtint_t incx = std::get<6>(str.param); gtint_t incy = std::get<7>(str.param); gtint_t ld_inc = std::get<8>(str.param); - char datatype = std::get<9>(str.param); #ifdef TEST_BLAS std::string str_name = "sger_"; #elif TEST_CBLAS std::string str_name = "cblas_sger"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_sger"; + std::string str_name = "bli_sger"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + conjx+conjy; @@ -118,7 +114,6 @@ class sgerTestPrint { std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -140,8 +135,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values( 1.0 ), // alpha ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a ), ::sgerTestPrint() ); diff --git a/gtestsuite/testsuite/level2/ger/test_ger.h b/gtestsuite/testsuite/level2/ger/test_ger.h index a85a13a7e9..3e8e7646d8 100644 --- a/gtestsuite/testsuite/level2/ger/test_ger.h +++ b/gtestsuite/testsuite/level2/ger/test_ger.h @@ -43,31 +43,30 @@ template void test_ger( char storage, char conjx, char conjy, gtint_t m, gtint_t n, - T alpha, gtint_t incx, gtint_t incy, gtint_t lda_inc, double thresh, - char datatype ) { - + T alpha, gtint_t incx, gtint_t incy, gtint_t lda_inc, double thresh ) +{ // Compute the leading dimensions for matrix size calculation. - gtint_t lda = testinghelpers::get_leading_dimension(storage, 'n', m, n, lda_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', m, n, lda_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. 
//---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(-2, 5, storage, 'n', m, n, lda, datatype); - std::vector x = testinghelpers::get_random_vector(-3, 3, m, incx, datatype); - std::vector y = testinghelpers::get_random_vector(-3, 3, n, incy, datatype); + std::vector a = testinghelpers::get_random_matrix( -2, 5, storage, 'n', m, n, lda ); + std::vector x = testinghelpers::get_random_vector( -3, 3, m, incx ); + std::vector y = testinghelpers::get_random_vector( -3, 3, n, incy ); // Create a copy of c so that we can check reference results. std::vector a_ref(a); //---------------------------------------------------------- // Call BLIS function //---------------------------------------------------------- - ger( storage, conjx, conjy, m, n, &alpha, x.data(), incx, + ger( storage, conjx, conjy, m, n, &alpha, x.data(), incx, y.data(), incy, a.data(), lda ); //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_ger( storage, conjx, conjy, m, n, alpha, + testinghelpers::ref_ger( storage, conjx, conjy, m, n, alpha, x.data(), incx, y.data(), incy, a_ref.data(), lda ); //---------------------------------------------------------- diff --git a/gtestsuite/testsuite/level2/ger/zger_generic.cpp b/gtestsuite/testsuite/level2/ger/zger_generic.cpp index 0f32161eaa..5847842c30 100644 --- a/gtestsuite/testsuite/level2/ger/zger_generic.cpp +++ b/gtestsuite/testsuite/level2/ger/zger_generic.cpp @@ -44,12 +44,11 @@ class zgerTest : dcomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(zgerTest, RandomData) { +TEST_P(zgerTest, RandomData) +{ using T = dcomplex; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -74,22 +73,20 @@ TEST_P(zgerTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. 
gtint_t lda_inc = std::get<8>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<9>(GetParam()); // Set the threshold for the errors: - double thresh = 2*std::max(m,n)*testinghelpers::getEpsilon(); + double thresh = 2*(std::max)(m,n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_ger(storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh, datatype); + test_ger( storage, conjx, conjy, m, n, alpha, incx, incy, lda_inc, thresh ); } class zgerTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char conjx = std::get<1>(str.param); char conjy = std::get<2>(str.param); @@ -99,13 +96,12 @@ class zgerTestPrint { gtint_t incx = std::get<6>(str.param); gtint_t incy = std::get<7>(str.param); gtint_t ld_inc = std::get<8>(str.param); - char datatype = std::get<9>(str.param); #ifdef TEST_BLAS std::string str_name = "zger_"; #elif TEST_CBLAS std::string str_name = "cblas_zger"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_zger"; + std::string str_name = "bli_zger"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + conjx+conjy; @@ -119,7 +115,6 @@ class zgerTestPrint { alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -141,8 +136,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{1.0, -2.0}), // alpha ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of a ), ::zgerTestPrint() ); diff --git a/gtestsuite/testsuite/level2/hemv/chemv_generic.cpp b/gtestsuite/testsuite/level2/hemv/chemv_generic.cpp index ed650d0229..ed4b726817 100644 --- a/gtestsuite/testsuite/level2/hemv/chemv_generic.cpp +++ b/gtestsuite/testsuite/level2/hemv/chemv_generic.cpp @@ -45,12 +45,11 @@ class chemvTest : scomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(chemvTest, RandomData) { +TEST_P(chemvTest, RandomData) +{ using T = scomplex; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -77,8 +76,6 @@ TEST_P(chemvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. 
gtint_t lda_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: double thresh = 4*std::sqrt(n)*testinghelpers::getEpsilon(); @@ -86,13 +83,13 @@ TEST_P(chemvTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_hemv(storage, uploa, conja, conjx, n, alpha, lda_inc, incx, beta, incy, thresh, datatype); + test_hemv( storage, uploa, conja, conjx, n, alpha, lda_inc, incx, beta, incy, thresh ); } class chemvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char conja = std::get<2>(str.param); @@ -103,13 +100,12 @@ class chemvTestPrint { gtint_t incx = std::get<7>(str.param); gtint_t incy = std::get<8>(str.param); gtint_t ld_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "chemv_"; #elif TEST_CBLAS std::string str_name = "cblas_chemv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_chemv"; + std::string str_name = "bli_chemv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+conja+conjx; @@ -125,7 +121,6 @@ class chemvTestPrint { str_name = str_name + "_" + incx_str; str_name = str_name + "_" + incy_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -148,8 +143,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(scomplex{2.0, -1.0}), // beta ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), ::chemvTestPrint() ); diff --git a/gtestsuite/testsuite/level2/hemv/test_hemv.h b/gtestsuite/testsuite/level2/hemv/test_hemv.h index 8f8357e96e..a7243cbd2e 100644 --- a/gtestsuite/testsuite/level2/hemv/test_hemv.h +++ b/gtestsuite/testsuite/level2/hemv/test_hemv.h @@ -37,27 +37,25 @@ #include "hemv.h" #include "level2/ref_hemv.h" #include "inc/check_error.h" -#include "inc/utils.h" #include #include template void test_hemv( char storage, char uploa, char conja, char conjx, gtint_t n, - T alpha, gtint_t lda_inc, gtint_t incx, T beta, gtint_t incy, - double thresh, char datatype ) { - + T alpha, gtint_t lda_inc, gtint_t incx, T beta, gtint_t incy, double thresh ) +{ // Compute the leading dimensions of a. - gtint_t lda = testinghelpers::get_leading_dimension(storage, 'n', n, n, lda_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', n, n, lda_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. 
//---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(-2, 5, storage, 'n', n, n, lda, datatype); - std::vector x = testinghelpers::get_random_vector(-3, 3, n, incx, datatype); - std::vector y = testinghelpers::get_random_vector(-3, 3, n, incy, datatype); + std::vector a = testinghelpers::get_random_matrix( -2, 5, storage, 'n', n, n, lda ); + std::vector x = testinghelpers::get_random_vector( -3, 3, n, incx ); + std::vector y = testinghelpers::get_random_vector( -3, 3, n, incy ); - mkherm( storage, uploa, n, a.data(), lda ); - mktrim( storage, uploa, n, a.data(), lda ); + testinghelpers::make_herm( storage, uploa, n, a.data(), lda ); + testinghelpers::make_triangular( storage, uploa, n, a.data(), lda ); // Create a copy of c so that we can check reference results. std::vector y_ref(y); diff --git a/gtestsuite/testsuite/level2/hemv/zhemv_generic.cpp b/gtestsuite/testsuite/level2/hemv/zhemv_generic.cpp index 1f60f25468..81ee763b24 100644 --- a/gtestsuite/testsuite/level2/hemv/zhemv_generic.cpp +++ b/gtestsuite/testsuite/level2/hemv/zhemv_generic.cpp @@ -45,12 +45,11 @@ class zhemvTest : dcomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(zhemvTest, RandomData) { +TEST_P(zhemvTest, RandomData) +{ using T = dcomplex; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -77,8 +76,6 @@ TEST_P(zhemvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: double thresh = 8*std::sqrt(n)*testinghelpers::getEpsilon(); @@ -86,13 +83,13 @@ TEST_P(zhemvTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_hemv(storage, uploa, conja, conjx, n, alpha, lda_inc, incx, beta, incy, thresh, datatype); + test_hemv( storage, uploa, conja, conjx, n, alpha, lda_inc, incx, beta, incy, thresh ); } class zhemvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char conja = std::get<2>(str.param); @@ -103,13 +100,12 @@ class zhemvTestPrint { gtint_t incx = std::get<7>(str.param); gtint_t incy = std::get<8>(str.param); gtint_t ld_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "zhemv_"; #elif TEST_CBLAS std::string str_name = "cblas_zhemv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_zhemv"; + std::string str_name = "bli_zhemv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+conja+conjx; @@ -125,7 +121,6 @@ class zhemvTestPrint { str_name = str_name + "_" + incx_str; str_name = str_name + "_" + incy_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -148,8 +143,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{2.0, -1.0}), // beta ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(5)), 
// increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), ::zhemvTestPrint() ); diff --git a/gtestsuite/testsuite/level2/her/cher_generic.cpp b/gtestsuite/testsuite/level2/her/cher_generic.cpp index 2805f17f23..8be6c2ed49 100644 --- a/gtestsuite/testsuite/level2/her/cher_generic.cpp +++ b/gtestsuite/testsuite/level2/her/cher_generic.cpp @@ -42,12 +42,11 @@ class cherTest : gtint_t, float, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(cherTest, RandomData) { +TEST_P(cherTest, RandomData) +{ using T = scomplex; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -68,8 +67,6 @@ TEST_P(cherTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<6>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<7>(GetParam()); // Set the threshold for the errors: double thresh = 4*std::sqrt(n)*testinghelpers::getEpsilon(); @@ -77,13 +74,13 @@ TEST_P(cherTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_her(storage, uploa, conjx, n, alpha, incx, lda_inc, thresh, datatype); + test_her( storage, uploa, conjx, n, alpha, incx, lda_inc, thresh ); } class cherTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char conjx = std::get<2>(str.param); @@ -91,13 +88,12 @@ class cherTestPrint { float alpha = std::get<4>(str.param); gtint_t incx = std::get<5>(str.param); gtint_t ld_inc = std::get<6>(str.param); - char datatype = std::get<7>(str.param); #ifdef TEST_BLAS std::string str_name = "cher_"; #elif TEST_CBLAS std::string str_name = "cblas_cher"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_cher"; + std::string str_name = "bli_cher"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+conjx; @@ -107,7 +103,6 @@ class cherTestPrint { std::string alpha_str = ( alpha > 0) ? 
std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -127,8 +122,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(31), 10), // n ::testing::Values(1.0), // alpha ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a ), ::cherTestPrint() ); diff --git a/gtestsuite/testsuite/level2/her/test_her.h b/gtestsuite/testsuite/level2/her/test_her.h index ad8a351eb1..db41652975 100644 --- a/gtestsuite/testsuite/level2/her/test_her.h +++ b/gtestsuite/testsuite/level2/her/test_her.h @@ -37,24 +37,23 @@ #include "her.h" #include "level2/ref_her.h" #include "inc/check_error.h" -#include "inc/utils.h" #include #include template void test_her( char storage, char uploa, char conjx, gtint_t n, Tr alpha, - gtint_t incx, gtint_t lda_inc, double thresh, char datatype ) { - + gtint_t incx, gtint_t lda_inc, double thresh ) +{ // Compute the leading dimensions of a. - gtint_t lda = testinghelpers::get_leading_dimension(storage, 'n', n, n, lda_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', n, n, lda_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(-2, 5, storage, 'n', n, n, lda, datatype); - std::vector x = testinghelpers::get_random_vector(-3, 3, n, incx, datatype); + std::vector a = testinghelpers::get_random_matrix( -2, 5, storage, 'n', n, n, lda ); + std::vector x = testinghelpers::get_random_vector( -3, 3, n, incx ); - mktrim( storage, uploa, n, a.data(), lda ); + testinghelpers::make_triangular( storage, uploa, n, a.data(), lda ); // Create a copy of c so that we can check reference results. std::vector a_ref(a); diff --git a/gtestsuite/testsuite/level2/her/zher_generic.cpp b/gtestsuite/testsuite/level2/her/zher_generic.cpp index 902820d3ca..8db149caa5 100644 --- a/gtestsuite/testsuite/level2/her/zher_generic.cpp +++ b/gtestsuite/testsuite/level2/her/zher_generic.cpp @@ -42,12 +42,11 @@ class zherTest : gtint_t, double, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(zherTest, RandomData) { +TEST_P(zherTest, RandomData) +{ using T = dcomplex; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -68,8 +67,6 @@ TEST_P(zherTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. 
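Each of these *_generic.cpp files follows the same GoogleTest value-parameterized pattern: the fixture derives from ::testing::TestWithParam over a std::tuple of the operation's parameters, and the TEST_P body unpacks that tuple with std::get<i>(GetParam()), exactly as in the lines that follow. A minimal stand-alone sketch of the idiom, using a hypothetical two-parameter suite (the names demoParamTest and Demo are illustrative, not from this patch):

#include <gtest/gtest.h>
#include <tuple>

// Hypothetical fixture whose parameters are (uplo, n).
class demoParamTest : public ::testing::TestWithParam<std::tuple<char, int>> {};

TEST_P(demoParamTest, UnpacksTuple)
{
    char uploa = std::get<0>(GetParam());   // same std::get<i>(GetParam()) idiom as above
    int  n     = std::get<1>(GetParam());
    EXPECT_TRUE( uploa == 'u' || uploa == 'l' );
    EXPECT_GT( n, 0 );
}

INSTANTIATE_TEST_SUITE_P(
    Demo, demoParamTest,
    ::testing::Combine( ::testing::Values( 'u', 'l' ),
                        ::testing::Values( 10, 20, 30 ) ) );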
gtint_t lda_inc = std::get<6>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<7>(GetParam()); // Set the threshold for the errors: double thresh = 4*std::sqrt(n)*testinghelpers::getEpsilon(); @@ -77,13 +74,13 @@ TEST_P(zherTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_her(storage, uploa, conjx, n, alpha, incx, lda_inc, thresh, datatype); + test_her( storage, uploa, conjx, n, alpha, incx, lda_inc, thresh ); } class zherTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char conjx = std::get<2>(str.param); @@ -91,13 +88,12 @@ class zherTestPrint { double alpha = std::get<4>(str.param); gtint_t incx = std::get<5>(str.param); gtint_t ld_inc = std::get<6>(str.param); - char datatype = std::get<7>(str.param); #ifdef TEST_BLAS std::string str_name = "zher_"; #elif TEST_CBLAS std::string str_name = "cblas_zher"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_zher"; + std::string str_name = "bli_zher"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+conjx; @@ -107,7 +103,6 @@ class zherTestPrint { std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -127,8 +122,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(31), 10), // n ::testing::Values(1.0), // alpha ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of a ), ::zherTestPrint() ); diff --git a/gtestsuite/testsuite/level2/her2/cher2_generic.cpp b/gtestsuite/testsuite/level2/her2/cher2_generic.cpp index 7c7f16bf72..f6bbd15a06 100644 --- a/gtestsuite/testsuite/level2/her2/cher2_generic.cpp +++ b/gtestsuite/testsuite/level2/her2/cher2_generic.cpp @@ -44,12 +44,11 @@ class cher2Test : scomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(cher2Test, RandomData) { +TEST_P(cher2Test, RandomData) +{ using T = scomplex; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -74,8 +73,6 @@ TEST_P(cher2Test, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. 
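The thresh values computed in these test bodies (4*sqrt(n), 4*n, 6*sqrt(n) times epsilon, and so on) all follow the same recipe: scale the machine epsilon of the underlying real type by a factor that grows with n, since every output element of these level-2 operations accumulates on the order of n floating-point operations. A minimal sketch of such a bound, assuming testinghelpers::getEpsilon<T>() behaves like std::numeric_limits' epsilon for the real type of T (the constants themselves are tuning choices of the suite):

#include <cmath>
#include <cstddef>
#include <limits>

// Illustration of an n-dependent error bound: k * f(n) * machine_epsilon.
template <typename Real>
double error_threshold( std::size_t n, double k, bool sqrt_scaling )
{
    const double eps    = static_cast<double>( std::numeric_limits<Real>::epsilon() );
    const double growth = sqrt_scaling ? std::sqrt( static_cast<double>( n ) )
                                       : static_cast<double>( n );
    return k * growth * eps;   // e.g. 4*sqrt(n)*eps for cher, 3*n*eps for dsyr2
}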
gtint_t lda_inc = std::get<8>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<9>(GetParam()); // Set the threshold for the errors: double thresh = 4*n*testinghelpers::getEpsilon(); @@ -83,13 +80,13 @@ TEST_P(cher2Test, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_her2(storage, uploa, conjx, conjy, n, alpha, incx, incy, lda_inc, thresh, datatype); + test_her2( storage, uploa, conjx, conjy, n, alpha, incx, incy, lda_inc, thresh ); } class cher2TestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char conjx = std::get<2>(str.param); @@ -99,13 +96,12 @@ class cher2TestPrint { gtint_t incx = std::get<6>(str.param); gtint_t incy = std::get<7>(str.param); gtint_t ld_inc = std::get<8>(str.param); - char datatype = std::get<9>(str.param); #ifdef TEST_BLAS std::string str_name = "cher2_"; #elif TEST_CBLAS std::string str_name = "cblas_cher2"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_cher2"; + std::string str_name = "bli_cher2"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+conjx+conjy; @@ -118,7 +114,6 @@ class cher2TestPrint { str_name = str_name + "_" + incx_str; str_name = str_name + "_" + incy_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -140,8 +135,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(scomplex{1.0, -2.0}), // alpha ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of a ), ::cher2TestPrint() ); diff --git a/gtestsuite/testsuite/level2/her2/test_her2.h b/gtestsuite/testsuite/level2/her2/test_her2.h index 10814b90db..b0802d64b4 100644 --- a/gtestsuite/testsuite/level2/her2/test_her2.h +++ b/gtestsuite/testsuite/level2/her2/test_her2.h @@ -37,27 +37,25 @@ #include "her2.h" #include "level2/ref_her2.h" #include "inc/check_error.h" -#include "inc/utils.h" #include #include template void test_her2( char storage, char uploa, char conjx, char conjy, gtint_t n, - T alpha, gtint_t incx, gtint_t incy, gtint_t lda_inc, double thresh, - char datatype ) { - + T alpha, gtint_t incx, gtint_t incy, gtint_t lda_inc, double thresh ) +{ // Compute the leading dimensions of a. - gtint_t lda = testinghelpers::get_leading_dimension(storage, 'n', n, n, lda_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', n, n, lda_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. 
//---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(-2, 5, storage, 'n', n, n, lda, datatype); - std::vector x = testinghelpers::get_random_vector(-3, 3, n, incx, datatype); - std::vector y = testinghelpers::get_random_vector(-2, 5, n, incy, datatype); + std::vector a = testinghelpers::get_random_matrix( -2, 5, storage, 'n', n, n, lda ); + std::vector x = testinghelpers::get_random_vector( -3, 3, n, incx ); + std::vector y = testinghelpers::get_random_vector( -2, 5, n, incy ); - mkherm( storage, uploa, n, a.data(), lda ); - mktrim( storage, uploa, n, a.data(), lda ); + testinghelpers::make_herm( storage, uploa, n, a.data(), lda ); + testinghelpers::make_triangular( storage, uploa, n, a.data(), lda ); // Create a copy of c so that we can check reference results. std::vector a_ref(a); diff --git a/gtestsuite/testsuite/level2/her2/zher2_generic.cpp b/gtestsuite/testsuite/level2/her2/zher2_generic.cpp index c7bc0bcd9a..acd8b4465a 100644 --- a/gtestsuite/testsuite/level2/her2/zher2_generic.cpp +++ b/gtestsuite/testsuite/level2/her2/zher2_generic.cpp @@ -44,12 +44,11 @@ class zher2Test : dcomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(zher2Test, RandomData) { +TEST_P(zher2Test, RandomData) +{ using T = dcomplex; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -74,8 +73,6 @@ TEST_P(zher2Test, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<8>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<9>(GetParam()); // Set the threshold for the errors: double thresh = 6*std::sqrt(n)*testinghelpers::getEpsilon(); @@ -83,13 +80,13 @@ TEST_P(zher2Test, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_her2(storage, uploa, conjx, conjy, n, alpha, incx, incy, lda_inc, thresh, datatype); + test_her2( storage, uploa, conjx, conjy, n, alpha, incx, incy, lda_inc, thresh ); } class zher2TestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char conjx = std::get<2>(str.param); @@ -99,13 +96,12 @@ class zher2TestPrint { gtint_t incx = std::get<6>(str.param); gtint_t incy = std::get<7>(str.param); gtint_t ld_inc = std::get<8>(str.param); - char datatype = std::get<9>(str.param); #ifdef TEST_BLAS std::string str_name = "zher2_"; #elif TEST_CBLAS std::string str_name = "cblas_zher2"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_zher2"; + std::string str_name = "bli_zher2"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+conjx+conjy; @@ -118,7 +114,6 @@ class zher2TestPrint { str_name = str_name + "_" + incx_str; str_name = str_name + "_" + incy_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -140,8 +135,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{1.0, -2.0}), // alpha ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(5)), // increment to 
the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), ::zher2TestPrint() ); diff --git a/gtestsuite/testsuite/level2/symv/dsymv_generic.cpp b/gtestsuite/testsuite/level2/symv/dsymv_generic.cpp index a8ca008deb..a62f20996d 100644 --- a/gtestsuite/testsuite/level2/symv/dsymv_generic.cpp +++ b/gtestsuite/testsuite/level2/symv/dsymv_generic.cpp @@ -45,12 +45,11 @@ class dsymvTest : double, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(dsymvTest, RandomData) { +TEST_P(dsymvTest, RandomData) +{ using T = double; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -77,8 +76,6 @@ TEST_P(dsymvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: double thresh = 10*n*testinghelpers::getEpsilon(); @@ -86,13 +83,13 @@ TEST_P(dsymvTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_symv(storage, uploa, conja, conjx, n, alpha, lda_inc, incx, beta, incy, thresh, datatype); + test_symv( storage, uploa, conja, conjx, n, alpha, lda_inc, incx, beta, incy, thresh ); } class dsymvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char conja = std::get<2>(str.param); @@ -103,13 +100,12 @@ class dsymvTestPrint { gtint_t incx = std::get<7>(str.param); gtint_t incy = std::get<8>(str.param); gtint_t ld_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "dsymv_"; #elif TEST_CBLAS std::string str_name = "cblas_dsymv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_dsymv"; + std::string str_name = "bli_dsymv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+conja+conjx; @@ -123,7 +119,6 @@ class dsymvTestPrint { str_name = str_name + "_" + incx_str; str_name = str_name + "_" + incy_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -146,8 +141,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values( 2.0, -1.0 ), // beta ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a ), ::dsymvTestPrint() ); diff --git a/gtestsuite/testsuite/level2/symv/ssymv_generic.cpp b/gtestsuite/testsuite/level2/symv/ssymv_generic.cpp index 498a7b89c9..d83d75b7dc 100644 --- a/gtestsuite/testsuite/level2/symv/ssymv_generic.cpp +++ b/gtestsuite/testsuite/level2/symv/ssymv_generic.cpp @@ -45,12 +45,11 @@ class ssymvTest : float, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(ssymvTest, RandomData) { +TEST_P(ssymvTest, RandomData) +{ using T = 
float; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -77,8 +76,6 @@ TEST_P(ssymvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: double thresh = 10*n*testinghelpers::getEpsilon(); @@ -86,13 +83,13 @@ TEST_P(ssymvTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_symv(storage, uploa, conja, conjx, n, alpha, lda_inc, incx, beta, incy, thresh, datatype); + test_symv( storage, uploa, conja, conjx, n, alpha, lda_inc, incx, beta, incy, thresh ); } class ssymvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char conja = std::get<2>(str.param); @@ -103,13 +100,12 @@ class ssymvTestPrint { gtint_t incx = std::get<7>(str.param); gtint_t incy = std::get<8>(str.param); gtint_t ld_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "ssymv_"; #elif TEST_CBLAS std::string str_name = "cblas_ssymv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_ssymv"; + std::string str_name = "bli_ssymv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+conja+conjx; @@ -123,7 +119,6 @@ class ssymvTestPrint { str_name = str_name + "_" + incx_str; str_name = str_name + "_" + incy_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -146,8 +141,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values( 2.0, -1.0 ), // beta ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), ::ssymvTestPrint() ); diff --git a/gtestsuite/testsuite/level2/symv/test_symv.h b/gtestsuite/testsuite/level2/symv/test_symv.h index 22c556d346..f0df77c18b 100644 --- a/gtestsuite/testsuite/level2/symv/test_symv.h +++ b/gtestsuite/testsuite/level2/symv/test_symv.h @@ -37,27 +37,25 @@ #include "symv.h" #include "level2/ref_symv.h" #include "inc/check_error.h" -#include "inc/utils.h" #include #include template void test_symv( char storage, char uploa, char conja, char conjx, gtint_t n, - T alpha, gtint_t lda_inc, gtint_t incx, T beta, gtint_t incy, - double thresh, char datatype ) { - + T alpha, gtint_t lda_inc, gtint_t incx, T beta, gtint_t incy, double thresh ) +{ // Compute the leading dimensions of a. - gtint_t lda = testinghelpers::get_leading_dimension(storage, 'n', n, n, lda_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', n, n, lda_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. 
//---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(-2, 5, storage, 'n', n, n, lda, datatype); - std::vector x = testinghelpers::get_random_vector(-3, 3, n, incx, datatype); - std::vector y = testinghelpers::get_random_vector(-2, 5, n, incy, datatype); + std::vector a = testinghelpers::get_random_matrix( -2, 5, storage, 'n', n, n, lda ); + std::vector x = testinghelpers::get_random_vector( -3, 3, n, incx ); + std::vector y = testinghelpers::get_random_vector( -2, 5, n, incy ); - mksymm( storage, uploa, n, a.data(), lda ); - mktrim( storage, uploa, n, a.data(), lda ); + testinghelpers::make_symm( storage, uploa, n, a.data(), lda ); + testinghelpers::make_triangular( storage, uploa, n, a.data(), lda ); // Create a copy of c so that we can check reference results. std::vector y_ref(y); diff --git a/gtestsuite/testsuite/level2/syr/dsyr_generic.cpp b/gtestsuite/testsuite/level2/syr/dsyr_generic.cpp index d80e990298..3d755586a8 100644 --- a/gtestsuite/testsuite/level2/syr/dsyr_generic.cpp +++ b/gtestsuite/testsuite/level2/syr/dsyr_generic.cpp @@ -42,12 +42,11 @@ class dsyrTest : gtint_t, double, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(dsyrTest, RandomData) { +TEST_P(dsyrTest, RandomData) +{ using T = double; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -68,8 +67,6 @@ TEST_P(dsyrTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<6>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<7>(GetParam()); // Set the threshold for the errors: double thresh = 2*n*testinghelpers::getEpsilon(); @@ -77,13 +74,13 @@ TEST_P(dsyrTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syr(storage, uploa, conjx, n, alpha, incx, lda_inc, thresh, datatype); + test_syr( storage, uploa, conjx, n, alpha, incx, lda_inc, thresh ); } class dsyrTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char conjx = std::get<2>(str.param); @@ -91,13 +88,12 @@ class dsyrTestPrint { double alpha = std::get<4>(str.param); gtint_t incx = std::get<5>(str.param); gtint_t ld_inc = std::get<6>(str.param); - char datatype = std::get<7>(str.param); #ifdef TEST_BLAS std::string str_name = "dsyr_"; #elif TEST_CBLAS std::string str_name = "cblas_dsyr"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_dsyr"; + std::string str_name = "bli_dsyr"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+conjx; @@ -107,7 +103,6 @@ class dsyrTestPrint { std::string alpha_str = ( alpha > 0) ? 
std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -127,8 +122,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(31), 10), // n ::testing::Values(1.0), // alpha ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a ), ::dsyrTestPrint() ); diff --git a/gtestsuite/testsuite/level2/syr/ssyr_generic.cpp b/gtestsuite/testsuite/level2/syr/ssyr_generic.cpp index 9e44b518f6..446c2f4743 100644 --- a/gtestsuite/testsuite/level2/syr/ssyr_generic.cpp +++ b/gtestsuite/testsuite/level2/syr/ssyr_generic.cpp @@ -42,12 +42,11 @@ class ssyrTest : gtint_t, float, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(ssyrTest, RandomData) { +TEST_P(ssyrTest, RandomData) +{ using T = float; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -68,8 +67,6 @@ TEST_P(ssyrTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<6>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<7>(GetParam()); // Set the threshold for the errors: double thresh = 2*n*testinghelpers::getEpsilon(); @@ -77,13 +74,13 @@ TEST_P(ssyrTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syr(storage, uploa, conjx, n, alpha, incx, lda_inc, thresh, datatype); + test_syr( storage, uploa, conjx, n, alpha, incx, lda_inc, thresh ); } class ssyrTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char conjx = std::get<2>(str.param); @@ -91,13 +88,12 @@ class ssyrTestPrint { float alpha = std::get<4>(str.param); gtint_t incx = std::get<5>(str.param); gtint_t ld_inc = std::get<6>(str.param); - char datatype = std::get<7>(str.param); #ifdef TEST_BLAS std::string str_name = "ssyr_"; #elif TEST_CBLAS std::string str_name = "cblas_ssyr"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_ssyr"; + std::string str_name = "bli_ssyr"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+conjx; @@ -107,7 +103,6 @@ class ssyrTestPrint { std::string alpha_str = ( alpha > 0) ? 
std::to_string(int(alpha)) : ("m" + std::to_string(int(std::abs(alpha)))); str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -127,8 +122,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(31), 10), // n ::testing::Values(1.0), // alpha ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a ), ::ssyrTestPrint() ); diff --git a/gtestsuite/testsuite/level2/syr/test_syr.h b/gtestsuite/testsuite/level2/syr/test_syr.h index d8cc9e9ada..125445fa19 100644 --- a/gtestsuite/testsuite/level2/syr/test_syr.h +++ b/gtestsuite/testsuite/level2/syr/test_syr.h @@ -37,24 +37,23 @@ #include "syr.h" #include "level2/ref_syr.h" #include "inc/check_error.h" -#include "inc/utils.h" #include #include template void test_syr( char storage, char uploa, char conjx, gtint_t n, T alpha, - gtint_t incx, gtint_t lda_inc, double thresh, char datatype ) { - + gtint_t incx, gtint_t lda_inc, double thresh ) +{ // Compute the leading dimensions for matrix size calculation. - gtint_t lda = testinghelpers::get_leading_dimension(storage, 'n', n, n, lda_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', n, n, lda_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(-2, 5, storage, 'n', n, n, lda, datatype); - std::vector x = testinghelpers::get_random_vector(-3, 3, n, incx, datatype); + std::vector a = testinghelpers::get_random_matrix( -2, 5, storage, 'n', n, n, lda ); + std::vector x = testinghelpers::get_random_vector( -3, 3, n, incx ); - mktrim( storage, uploa, n, a.data(), lda ); + testinghelpers::make_triangular( storage, uploa, n, a.data(), lda ); // Create a copy of c so that we can check reference results. std::vector a_ref(a); diff --git a/gtestsuite/testsuite/level2/syr2/dsyr2_generic.cpp b/gtestsuite/testsuite/level2/syr2/dsyr2_generic.cpp index 896323648c..2a021ea6d8 100644 --- a/gtestsuite/testsuite/level2/syr2/dsyr2_generic.cpp +++ b/gtestsuite/testsuite/level2/syr2/dsyr2_generic.cpp @@ -44,12 +44,11 @@ class dsyr2Test : double, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(dsyr2Test, RandomData) { +TEST_P(dsyr2Test, RandomData) +{ using T = double; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -74,8 +73,6 @@ TEST_P(dsyr2Test, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. 
gtint_t lda_inc = std::get<8>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<9>(GetParam()); // Set the threshold for the errors: double thresh = 3*n*testinghelpers::getEpsilon(); @@ -83,13 +80,13 @@ TEST_P(dsyr2Test, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syr2(storage, uploa, conjx, conjy, n, alpha, incx, incy, lda_inc, thresh, datatype); + test_syr2( storage, uploa, conjx, conjy, n, alpha, incx, incy, lda_inc, thresh ); } class dsyr2TestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char conjx = std::get<2>(str.param); @@ -99,13 +96,12 @@ class dsyr2TestPrint { gtint_t incx = std::get<6>(str.param); gtint_t incy = std::get<7>(str.param); gtint_t ld_inc = std::get<8>(str.param); - char datatype = std::get<9>(str.param); #ifdef TEST_BLAS std::string str_name = "dsyr2_"; #elif TEST_CBLAS std::string str_name = "cblas_dsyr2"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_dsyr2"; + std::string str_name = "bli_dsyr2"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+conjx+conjy; @@ -117,7 +113,6 @@ class dsyr2TestPrint { str_name = str_name + "_" + incx_str; str_name = str_name + "_" + incy_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -139,8 +134,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(1.0, -2.0), // alpha ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a ), ::dsyr2TestPrint() ); diff --git a/gtestsuite/testsuite/level2/syr2/ssyr2_generic.cpp b/gtestsuite/testsuite/level2/syr2/ssyr2_generic.cpp index ced6dfdd89..75df2d0367 100644 --- a/gtestsuite/testsuite/level2/syr2/ssyr2_generic.cpp +++ b/gtestsuite/testsuite/level2/syr2/ssyr2_generic.cpp @@ -44,12 +44,11 @@ class ssyr2Test : float, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(ssyr2Test, RandomData) { +TEST_P(ssyr2Test, RandomData) +{ using T = float; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -74,8 +73,6 @@ TEST_P(ssyr2Test, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. 
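As the comment just above notes, an lda_inc of zero means the array is exactly n-by-n, while a positive lda_inc pads each column. Assuming get_leading_dimension( storage, 'n', n, n, lda_inc ) returns n + lda_inc for column-major storage (an assumption; the helper's exact behaviour is not shown in this patch), the addressing the tests rely on looks like this:

#include <cstddef>
#include <vector>

// Column-major addressing with a padded leading dimension (illustration only).
// Element (i,j) of the n x n operand lives at a[i + j*lda]; the lda - n trailing
// entries of every column are padding the routine under test must not touch.
template <typename T>
T& at( std::vector<T>& a, std::size_t i, std::size_t j, std::size_t lda )
{
    return a[ i + j*lda ];
}

template <typename T>
std::vector<T> alloc_col_major( std::size_t n, std::size_t lda_inc )
{
    const std::size_t lda = n + lda_inc;    // assumed relationship
    return std::vector<T>( lda * n );       // larger than n*n when lda_inc > 0
}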
gtint_t lda_inc = std::get<8>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<9>(GetParam()); // Set the threshold for the errors: double thresh = 3*n*testinghelpers::getEpsilon(); @@ -83,13 +80,13 @@ TEST_P(ssyr2Test, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syr2(storage, uploa, conjx, conjy, n, alpha, incx, incy, lda_inc, thresh, datatype); + test_syr2( storage, uploa, conjx, conjy, n, alpha, incx, incy, lda_inc, thresh ); } class ssyr2TestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char conjx = std::get<2>(str.param); @@ -99,13 +96,12 @@ class ssyr2TestPrint { gtint_t incx = std::get<6>(str.param); gtint_t incy = std::get<7>(str.param); gtint_t ld_inc = std::get<8>(str.param); - char datatype = std::get<9>(str.param); #ifdef TEST_BLAS std::string str_name = "ssyr2_"; #elif TEST_CBLAS std::string str_name = "cblas_ssyr2"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_ssyr2"; + std::string str_name = "bli_ssyr2"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+conjx+conjy; @@ -117,7 +113,6 @@ class ssyr2TestPrint { str_name = str_name + "_" + incx_str; str_name = str_name + "_" + incy_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -139,8 +134,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(1.0, -2.0), // alpha ::testing::Values(gtint_t(1)), // stride size for x ::testing::Values(gtint_t(1)), // stride size for y - ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), ::ssyr2TestPrint() ); diff --git a/gtestsuite/testsuite/level2/syr2/test_syr2.h b/gtestsuite/testsuite/level2/syr2/test_syr2.h index 92b8b64baa..a4a623b6ea 100644 --- a/gtestsuite/testsuite/level2/syr2/test_syr2.h +++ b/gtestsuite/testsuite/level2/syr2/test_syr2.h @@ -37,27 +37,25 @@ #include "syr2.h" #include "level2/ref_syr2.h" #include "inc/check_error.h" -#include "inc/utils.h" #include #include template void test_syr2( char storage, char uploa, char conjx, char conjy, gtint_t n, - T alpha, gtint_t incx, gtint_t incy, gtint_t lda_inc, double thresh, - char datatype ) { - + T alpha, gtint_t incx, gtint_t incy, gtint_t lda_inc, double thresh ) +{ // Compute the leading dimensions for matrix size calculation. - gtint_t lda = testinghelpers::get_leading_dimension(storage, 'n', n, n, lda_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', n, n, lda_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. 
//---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(-2, 5, storage, 'n', n, n, lda, datatype); - std::vector x = testinghelpers::get_random_vector(-3, 3, n, incx, datatype); - std::vector y = testinghelpers::get_random_vector(-3, 3, n, incy, datatype); + std::vector a = testinghelpers::get_random_matrix( -2, 5, storage, 'n', n, n, lda ); + std::vector x = testinghelpers::get_random_vector( -3, 3, n, incx ); + std::vector y = testinghelpers::get_random_vector( -3, 3, n, incy ); - mksymm( storage, uploa, n, a.data(), lda ); - mktrim( storage, uploa, n, a.data(), lda ); + testinghelpers::make_symm( storage, uploa, n, a.data(), lda ); + testinghelpers::make_triangular( storage, uploa, n, a.data(), lda ); // Create a copy of c so that we can check reference results. std::vector a_ref(a); diff --git a/gtestsuite/testsuite/level2/trmv/ctrmv_generic.cpp b/gtestsuite/testsuite/level2/trmv/ctrmv_generic.cpp index 61f048c70d..a82fafcc2b 100644 --- a/gtestsuite/testsuite/level2/trmv/ctrmv_generic.cpp +++ b/gtestsuite/testsuite/level2/trmv/ctrmv_generic.cpp @@ -43,12 +43,11 @@ class ctrmvTest : gtint_t, scomplex, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(ctrmvTest, RandomData) { +TEST_P(ctrmvTest, RandomData) +{ using T = scomplex; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -71,8 +70,6 @@ TEST_P(ctrmvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<7>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<8>(GetParam()); // Set the threshold for the errors: double thresh = 10*n*testinghelpers::getEpsilon(); @@ -80,13 +77,13 @@ TEST_P(ctrmvTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trmv(storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, datatype); + test_trmv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); } class ctrmvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char transa = std::get<2>(str.param); @@ -95,13 +92,12 @@ class ctrmvTestPrint { scomplex alpha = std::get<5>(str.param); gtint_t incx = std::get<6>(str.param); gtint_t ld_inc = std::get<7>(str.param); - char datatype = std::get<8>(str.param); #ifdef TEST_BLAS std::string str_name = "ctrmv_"; #elif TEST_CBLAS std::string str_name = "cblas_ctrmv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_ctrmv"; + std::string str_name = "bli_ctrmv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+transa; @@ -113,7 +109,6 @@ class ctrmvTestPrint { std::string incx_str = ( incx > 0) ? 
std::to_string(incx) : "m" + std::to_string(std::abs(incx)); str_name = str_name + "_" + incx_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -138,8 +133,7 @@ INSTANTIATE_TEST_SUITE_P( #endif ), // alpha ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(9)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(9)) // increment to the leading dim of a ), ::ctrmvTestPrint() ); diff --git a/gtestsuite/testsuite/level2/trmv/dtrmv_generic.cpp b/gtestsuite/testsuite/level2/trmv/dtrmv_generic.cpp index 869cc69744..e7e9e325b9 100644 --- a/gtestsuite/testsuite/level2/trmv/dtrmv_generic.cpp +++ b/gtestsuite/testsuite/level2/trmv/dtrmv_generic.cpp @@ -43,12 +43,11 @@ class dtrmvTest : gtint_t, double, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(dtrmvTest, RandomData) { +TEST_P(dtrmvTest, RandomData) +{ using T = double; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -71,8 +70,6 @@ TEST_P(dtrmvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<7>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<8>(GetParam()); // Set the threshold for the errors: double thresh = 20*n*testinghelpers::getEpsilon(); @@ -80,13 +77,13 @@ TEST_P(dtrmvTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trmv(storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, datatype); + test_trmv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); } class dtrmvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char transa = std::get<2>(str.param); @@ -95,13 +92,12 @@ class dtrmvTestPrint { double alpha = std::get<5>(str.param); gtint_t incx = std::get<6>(str.param); gtint_t ld_inc = std::get<7>(str.param); - char datatype = std::get<8>(str.param); #ifdef TEST_BLAS std::string str_name = "dtrmv_"; #elif TEST_CBLAS std::string str_name = "cblas_dtrmv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_dtrmv"; + std::string str_name = "bli_dtrmv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+transa; @@ -112,7 +108,6 @@ class dtrmvTestPrint { std::string incx_str = ( incx > 0) ? 
std::to_string(incx) : "m" + std::to_string(std::abs(incx)); str_name = str_name + "_" + incx_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -137,8 +132,7 @@ INSTANTIATE_TEST_SUITE_P( #endif ), // alpha ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a ), ::dtrmvTestPrint() ); diff --git a/gtestsuite/testsuite/level2/trmv/strmv_generic.cpp b/gtestsuite/testsuite/level2/trmv/strmv_generic.cpp index 18bbd93b77..470e556814 100644 --- a/gtestsuite/testsuite/level2/trmv/strmv_generic.cpp +++ b/gtestsuite/testsuite/level2/trmv/strmv_generic.cpp @@ -43,12 +43,11 @@ class strmvTest : gtint_t, float, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(strmvTest, RandomData) { +TEST_P(strmvTest, RandomData) +{ using T = float; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -71,8 +70,6 @@ TEST_P(strmvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<7>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<8>(GetParam()); // Set the threshold for the errors: double thresh = 10*n*testinghelpers::getEpsilon(); @@ -80,13 +77,13 @@ TEST_P(strmvTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trmv(storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, datatype); + test_trmv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); } class strmvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char transa = std::get<2>(str.param); @@ -95,13 +92,12 @@ class strmvTestPrint { float alpha = std::get<5>(str.param); gtint_t incx = std::get<6>(str.param); gtint_t ld_inc = std::get<7>(str.param); - char datatype = std::get<8>(str.param); #ifdef TEST_BLAS std::string str_name = "strmv_"; #elif TEST_CBLAS std::string str_name = "cblas_strmv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_strmv"; + std::string str_name = "bli_strmv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+transa; @@ -112,7 +108,6 @@ class strmvTestPrint { std::string incx_str = ( incx > 0) ? 
std::to_string(incx) : "m" + std::to_string(std::abs(incx)); str_name = str_name + "_" + incx_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -137,8 +132,7 @@ INSTANTIATE_TEST_SUITE_P( #endif ), // alpha ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of a ), ::strmvTestPrint() ); diff --git a/gtestsuite/testsuite/level2/trmv/test_trmv.h b/gtestsuite/testsuite/level2/trmv/test_trmv.h index 82d8b0d6a3..d59f4412f7 100644 --- a/gtestsuite/testsuite/level2/trmv/test_trmv.h +++ b/gtestsuite/testsuite/level2/trmv/test_trmv.h @@ -37,24 +37,23 @@ #include "trmv.h" #include "level2/ref_trmv.h" #include "inc/check_error.h" -#include "inc/utils.h" #include #include template void test_trmv( char storage, char uploa, char transa, char diaga, gtint_t n, - T alpha, gtint_t lda_inc, gtint_t incx, double thresh, char datatype ) { - + T alpha, gtint_t lda_inc, gtint_t incx, double thresh ) +{ // Compute the leading dimensions for matrix size calculation. - gtint_t lda = testinghelpers::get_leading_dimension(storage, transa, n, n, lda_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, n, lda_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(-2, 8, storage, transa, n, n, lda, datatype); - std::vector x = testinghelpers::get_random_vector(-10, 10, n, incx, datatype); + std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, n, n, lda ); + std::vector x = testinghelpers::get_random_vector( -10, 10, n, incx ); - mktrim( storage, uploa, n, a.data(), lda ); + testinghelpers::make_triangular( storage, uploa, n, a.data(), lda ); // Create a copy of c so that we can check reference results. std::vector x_ref(x); diff --git a/gtestsuite/testsuite/level2/trmv/ztrmv_generic.cpp b/gtestsuite/testsuite/level2/trmv/ztrmv_generic.cpp index 759202433d..1fb53d2b7d 100644 --- a/gtestsuite/testsuite/level2/trmv/ztrmv_generic.cpp +++ b/gtestsuite/testsuite/level2/trmv/ztrmv_generic.cpp @@ -43,12 +43,11 @@ class ztrmvTest : gtint_t, dcomplex, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(ztrmvTest, RandomData) { +TEST_P(ztrmvTest, RandomData) +{ using T = dcomplex; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -71,8 +70,6 @@ TEST_P(ztrmvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. 
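The *TestPrint classes appearing throughout these files (ztrmvTestPrint below, with their BLAS/CBLAS/BLIS-typed name prefixes) are GoogleTest "name generators": callables that take a testing::TestParamInfo for the suite's parameter type and return the suffix GoogleTest appends to each instantiated test name, passed as the last argument of INSTANTIATE_TEST_SUITE_P. A stripped-down sketch with hypothetical names (demoTrmvTest, demoTrmvPrint):

#include <gtest/gtest.h>
#include <string>
#include <tuple>

class demoTrmvTest : public ::testing::TestWithParam<std::tuple<char, int>> {};

TEST_P(demoTrmvTest, Names) { SUCCEED(); }

// Returned names may only contain [A-Za-z0-9_], hence the explicit string
// building with separators, much like the classes above.
struct demoTrmvPrint
{
    std::string operator()( const testing::TestParamInfo<std::tuple<char, int>>& info ) const
    {
        char uploa = std::get<0>( info.param );
        int  n     = std::get<1>( info.param );
        return std::string( 1, uploa ) + "_n" + std::to_string( n );
    }
};

INSTANTIATE_TEST_SUITE_P(
    Demo, demoTrmvTest,
    ::testing::Combine( ::testing::Values( 'u', 'l' ), ::testing::Values( 32 ) ),
    demoTrmvPrint() );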
gtint_t lda_inc = std::get<7>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<8>(GetParam()); // Set the threshold for the errors: double thresh = 10*n*testinghelpers::getEpsilon(); @@ -80,13 +77,13 @@ TEST_P(ztrmvTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trmv(storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, datatype); + test_trmv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); } class ztrmvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char transa = std::get<2>(str.param); @@ -95,13 +92,12 @@ class ztrmvTestPrint { dcomplex alpha = std::get<5>(str.param); gtint_t incx = std::get<6>(str.param); gtint_t ld_inc = std::get<7>(str.param); - char datatype = std::get<8>(str.param); #ifdef TEST_BLAS std::string str_name = "ztrmv_"; #elif TEST_CBLAS std::string str_name = "cblas_ztrmv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_ztrmv"; + std::string str_name = "bli_ztrmv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+transa; @@ -113,7 +109,6 @@ class ztrmvTestPrint { std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); str_name = str_name + "_" + incx_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -138,8 +133,7 @@ INSTANTIATE_TEST_SUITE_P( #endif ), // alpha ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of a ), ::ztrmvTestPrint() ); diff --git a/gtestsuite/testsuite/level2/trsv/ctrsv_generic.cpp b/gtestsuite/testsuite/level2/trsv/ctrsv_generic.cpp index 45421b8f97..1639e7202c 100644 --- a/gtestsuite/testsuite/level2/trsv/ctrsv_generic.cpp +++ b/gtestsuite/testsuite/level2/trsv/ctrsv_generic.cpp @@ -43,12 +43,11 @@ class ctrsvTest : gtint_t, scomplex, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(ctrsvTest, RandomData) { +TEST_P(ctrsvTest, RandomData) +{ using T = scomplex; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -71,8 +70,6 @@ TEST_P(ctrsvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. 
gtint_t lda_inc = std::get<7>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<8>(GetParam()); // Set the threshold for the errors: double thresh = 5*n*testinghelpers::getEpsilon(); @@ -80,13 +77,13 @@ TEST_P(ctrsvTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trsv(storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, datatype); + test_trsv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); } class ctrsvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char transa = std::get<2>(str.param); @@ -95,13 +92,12 @@ class ctrsvTestPrint { scomplex alpha = std::get<5>(str.param); gtint_t incx = std::get<6>(str.param); gtint_t ld_inc = std::get<7>(str.param); - char datatype = std::get<8>(str.param); #ifdef TEST_BLAS std::string str_name = "ctrsv_"; #elif TEST_CBLAS std::string str_name = "cblas_ctrsv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_ctrsv"; + std::string str_name = "bli_ctrsv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+transa; @@ -113,7 +109,6 @@ class ctrsvTestPrint { std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); str_name = str_name + "_" + incx_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -138,8 +133,7 @@ INSTANTIATE_TEST_SUITE_P( #endif ), // alpha ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of a ), ::ctrsvTestPrint() ); diff --git a/gtestsuite/testsuite/level2/trsv/dtrsv_generic.cpp b/gtestsuite/testsuite/level2/trsv/dtrsv_generic.cpp index 2a4e1c6cac..3ebf2f6076 100644 --- a/gtestsuite/testsuite/level2/trsv/dtrsv_generic.cpp +++ b/gtestsuite/testsuite/level2/trsv/dtrsv_generic.cpp @@ -43,12 +43,11 @@ class dtrsvTest : gtint_t, double, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(dtrsvTest, RandomData) { +TEST_P(dtrsvTest, RandomData) +{ using T = double; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -71,8 +70,6 @@ TEST_P(dtrsvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. 
gtint_t lda_inc = std::get<7>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<8>(GetParam()); // Set the threshold for the errors: double thresh = 100*n*testinghelpers::getEpsilon(); @@ -80,13 +77,13 @@ TEST_P(dtrsvTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trsv(storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, datatype); + test_trsv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); } class dtrsvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char transa = std::get<2>(str.param); @@ -95,13 +92,12 @@ class dtrsvTestPrint { double alpha = std::get<5>(str.param); gtint_t incx = std::get<6>(str.param); gtint_t ld_inc = std::get<7>(str.param); - char datatype = std::get<8>(str.param); #ifdef TEST_BLAS std::string str_name = "dtrsv_"; #elif TEST_CBLAS std::string str_name = "cblas_dtrsv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_dtrsv"; + std::string str_name = "bli_dtrsv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+transa; @@ -112,7 +108,6 @@ class dtrsvTestPrint { std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); str_name = str_name + "_" + incx_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -137,8 +132,7 @@ INSTANTIATE_TEST_SUITE_P( #endif ), // alpha ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of a ), ::dtrsvTestPrint() ); diff --git a/gtestsuite/testsuite/level2/trsv/strsv_generic.cpp b/gtestsuite/testsuite/level2/trsv/strsv_generic.cpp index edd0197070..201223b134 100644 --- a/gtestsuite/testsuite/level2/trsv/strsv_generic.cpp +++ b/gtestsuite/testsuite/level2/trsv/strsv_generic.cpp @@ -43,12 +43,11 @@ class strsvTest : gtint_t, float, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(strsvTest, RandomData) { +TEST_P(strsvTest, RandomData) +{ using T = float; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -71,8 +70,6 @@ TEST_P(strsvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. 
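The problem sizes in these INSTANTIATE blocks come from ::testing::Range(gtint_t(10), gtint_t(31), 10), which generates 10, 20 and 30: GoogleTest's Range excludes the end value, so 31 is simply "one past 30". A tiny check of that behaviour, using int in place of gtint_t (assumed to be an integer typedef of the suite):

#include <gtest/gtest.h>

class demoSizeTest : public ::testing::TestWithParam<int> {};

TEST_P(demoSizeTest, SizeIsMultipleOfTen)
{
    EXPECT_EQ( GetParam() % 10, 0 );   // Range(10, 31, 10) yields 10, 20, 30
}

INSTANTIATE_TEST_SUITE_P( Demo, demoSizeTest, ::testing::Range( 10, 31, 10 ) );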
gtint_t lda_inc = std::get<7>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<8>(GetParam()); // Set the threshold for the errors: double thresh = 20*n*testinghelpers::getEpsilon(); @@ -80,13 +77,13 @@ TEST_P(strsvTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trsv(storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, datatype); + test_trsv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); } class strsvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char transa = std::get<2>(str.param); @@ -95,13 +92,12 @@ class strsvTestPrint { float alpha = std::get<5>(str.param); gtint_t incx = std::get<6>(str.param); gtint_t ld_inc = std::get<7>(str.param); - char datatype = std::get<8>(str.param); #ifdef TEST_BLAS std::string str_name = "strsv_"; #elif TEST_CBLAS std::string str_name = "cblas_strsv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_strsv"; + std::string str_name = "bli_strsv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+transa; @@ -112,7 +108,6 @@ class strsvTestPrint { std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); str_name = str_name + "_" + incx_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -137,8 +132,7 @@ INSTANTIATE_TEST_SUITE_P( #endif ), // alpha ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(7)) // increment to the leading dim of a ), ::strsvTestPrint() ); diff --git a/gtestsuite/testsuite/level2/trsv/test_trsv.h b/gtestsuite/testsuite/level2/trsv/test_trsv.h index 320459c862..2266397200 100644 --- a/gtestsuite/testsuite/level2/trsv/test_trsv.h +++ b/gtestsuite/testsuite/level2/trsv/test_trsv.h @@ -37,24 +37,23 @@ #include "trsv.h" #include "level2/ref_trsv.h" #include "inc/check_error.h" -#include "inc/utils.h" #include #include template void test_trsv( char storage, char uploa, char transa, char diaga, gtint_t n, - T alpha, gtint_t lda_inc, gtint_t incx, double thresh, char datatype ) { - + T alpha, gtint_t lda_inc, gtint_t incx, double thresh ) +{ // Compute the leading dimensions for matrix size calculation. - gtint_t lda = testinghelpers::get_leading_dimension(storage, transa, n, n, lda_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, n, n, lda_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. 
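// Editorial sketch (not part of the patch): what the reference path being checked here
// computes in its simplest configuration. testinghelpers::ref_trsv covers every
// uplo/trans/diag/incx combination (and the typed interface's alpha scaling); the
// function below only illustrates the column-major, lower-triangular, non-unit-diagonal,
// no-transpose, unit-stride case.
#include <vector>
#include <cstddef>

template <typename T>
void lower_trsv_colmajor( std::size_t n, const std::vector<T>& a, std::size_t lda,
                          std::vector<T>& x )
{
    // Solve L * y = x in place, where L is the lower triangle of a (n x n, lda >= n).
    for ( std::size_t j = 0; j < n; ++j )
    {
        x[j] = x[j] / a[j + j*lda];           // divide by the diagonal element
        for ( std::size_t i = j + 1; i < n; ++i )
            x[i] -= a[i + j*lda] * x[j];      // eliminate column j from the remaining rows
    }
}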
//---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(1, 5, storage, transa, n, n, lda, datatype); - std::vector x = testinghelpers::get_random_vector(1, 3, n, incx, datatype); + std::vector a = testinghelpers::get_random_matrix( 1, 5, storage, transa, n, n, lda ); + std::vector x = testinghelpers::get_random_vector( 1, 3, n, incx ); - mktrim( storage, uploa, n, a.data(), lda ); + testinghelpers::make_triangular( storage, uploa, n, a.data(), lda ); // Create a copy of c so that we can check reference results. std::vector x_ref(x); diff --git a/gtestsuite/testsuite/level2/trsv/ztrsv_generic.cpp b/gtestsuite/testsuite/level2/trsv/ztrsv_generic.cpp index e3232f0229..dc8b004575 100644 --- a/gtestsuite/testsuite/level2/trsv/ztrsv_generic.cpp +++ b/gtestsuite/testsuite/level2/trsv/ztrsv_generic.cpp @@ -43,12 +43,11 @@ class ztrsvTest : gtint_t, dcomplex, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(ztrsvTest, RandomData) { +TEST_P(ztrsvTest, RandomData) +{ using T = dcomplex; - //---------------------------------------------------------- // Initialize values from the parameters passed through // test suite instantiation (INSTANTIATE_TEST_SUITE_P). @@ -71,8 +70,6 @@ TEST_P(ztrsvTest, RandomData) { // If increment is zero, then the array size matches the matrix size. // If increment are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<7>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<8>(GetParam()); // Set the threshold for the errors: double thresh = 10*n*testinghelpers::getEpsilon(); @@ -80,13 +77,13 @@ TEST_P(ztrsvTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trsv(storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh, datatype); + test_trsv( storage, uploa, transa, diaga, n, alpha, lda_inc, incx, thresh ); } class ztrsvTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uploa = std::get<1>(str.param); char transa = std::get<2>(str.param); @@ -95,13 +92,12 @@ class ztrsvTestPrint { dcomplex alpha = std::get<5>(str.param); gtint_t incx = std::get<6>(str.param); gtint_t ld_inc = std::get<7>(str.param); - char datatype = std::get<8>(str.param); #ifdef TEST_BLAS std::string str_name = "ztrsv_"; #elif TEST_CBLAS std::string str_name = "cblas_ztrsv"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_ztrsv"; + std::string str_name = "bli_ztrsv"; #endif str_name = str_name + "_" + sfm; str_name = str_name + "_" + uploa+transa; @@ -113,7 +109,6 @@ class ztrsvTestPrint { std::string incx_str = ( incx > 0) ? 
std::to_string(incx) : "m" + std::to_string(std::abs(incx)); str_name = str_name + "_" + incx_str; str_name = str_name + "_" + std::to_string(ld_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -138,8 +133,7 @@ INSTANTIATE_TEST_SUITE_P( #endif ), // alpha ::testing::Values(gtint_t(1)), // stride size for x - ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of a ), ::ztrsvTestPrint() ); diff --git a/gtestsuite/testsuite/level3/gemm/IIT_ERS_test.cpp b/gtestsuite/testsuite/level3/gemm/IIT_ERS_test.cpp new file mode 100644 index 0000000000..9e8ea79d4e --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm/IIT_ERS_test.cpp @@ -0,0 +1,264 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "common/testing_helpers.h" +#include "gemm.h" +#include "inc/check_error.h" +#include "common/wrong_inputs_helpers.h" + +template +class Gemm_IIT_ERS_Test : public ::testing::Test {}; +typedef ::testing::Types TypeParam; // The supported datatypes from BLAS calls for GEMM +TYPED_TEST_SUITE(Gemm_IIT_ERS_Test, TypeParam); // Defining individual testsuites based on the datatype support. + +// Adding namespace to get default parameters(valid case) from testinghelpers/common/wrong_input_helpers.h. +using namespace testinghelpers::IIT; + +#ifdef TEST_BLAS + +/* + Incorrect Input Testing(IIT) + + BLAS exceptions get triggered in the following cases(for GEMM): + 1. When TRANSA != 'N' || TRANSA != 'T' || TRANSA != 'C' (info = 1) + 2. When TRANSB != 'N' || TRANSB != 'T' || TRANSB != 'C' (info = 2) + 3. When m < 0 (info = 3) + 4. When n < 0 (info = 4) + 5. When k < 0 (info = 5) + 6. When lda < max(1, thresh) (info = 8), thresh set based on TRANSA value + 7. 
When ldb < max(1, thresh) (info = 10), thresh set based on TRANSB value + 8. When ldc < max(1, n) (info = 13) + +*/ + +// When info == 1 +TYPED_TEST(Gemm_IIT_ERS_Test, invalid_transa) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for TRANS value for A. + gemm( STORAGE, 'p', TRANS, M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When info == 2 +TYPED_TEST(Gemm_IIT_ERS_Test, invalid_transb) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for TRANS value for B. + gemm( STORAGE, TRANS, 'p', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When info == 3 +TYPED_TEST(Gemm_IIT_ERS_Test, m_lt_zero) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for m. + gemm( STORAGE, TRANS, TRANS, -1, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When info == 4 +TYPED_TEST(Gemm_IIT_ERS_Test, n_lt_zero) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for n. + gemm( STORAGE, TRANS, TRANS, M, -1, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When info == 5 +TYPED_TEST(Gemm_IIT_ERS_Test, k_lt_zero) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for k. + gemm( STORAGE, TRANS, TRANS, M, N, -1, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When info == 8 +TYPED_TEST(Gemm_IIT_ERS_Test, invalid_lda) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for lda. 
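// Editorial sketch (not part of the patch): the argument checks enumerated in the IIT
// comment above, written out in the order the reference Netlib GEMM applies them, so the
// expected xerbla "info" code for each test below is explicit. Illustration only, not
// the BLIS implementation; note that reference BLAS checks LDC against max(1,M).
#include <algorithm>   // std::max

inline int expected_gemm_info( char transa, char transb,
                               long m, long n, long k,
                               long lda, long ldb, long ldc )
{
    auto valid_trans = []( char t )
        { return t=='n' || t=='N' || t=='t' || t=='T' || t=='c' || t=='C'; };
    const long nrowa = ( transa=='n' || transa=='N' ) ? m : k;   // rows of op(A)
    const long nrowb = ( transb=='n' || transb=='N' ) ? k : n;   // rows of op(B)

    if      ( !valid_trans( transa ) )       return 1;
    else if ( !valid_trans( transb ) )       return 2;
    else if ( m < 0 )                        return 3;
    else if ( n < 0 )                        return 4;
    else if ( k < 0 )                        return 5;
    else if ( lda < std::max( 1L, nrowa ) )  return 8;
    else if ( ldb < std::max( 1L, nrowb ) )  return 10;
    else if ( ldc < std::max( 1L, m ) )      return 13;
    return 0;                                // all arguments valid
}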
+ gemm( STORAGE, TRANS, TRANS, M, N, K, nullptr, nullptr, LDA - 1, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When info == 10 +TYPED_TEST(Gemm_IIT_ERS_Test, invalid_ldb) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for ldb. + gemm( STORAGE, TRANS, TRANS, M, N, K, nullptr, nullptr, LDA, nullptr, LDB - 1, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When info == 13 +TYPED_TEST(Gemm_IIT_ERS_Test, invalid_ldc) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + // Call BLIS Gemm with a invalid value for ldc. + gemm( STORAGE, TRANS, TRANS, M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC - 1 ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +/* + Early Return Scenarios(ERS) : + + The GEMM API is expected to return early in the following cases: + + 1. When m == 0. + 2. When n == 0. + 3. When (alpha == 0 or k == 0) and beta == 1. + +*/ + +// When m is 0 +TYPED_TEST(Gemm_IIT_ERS_Test, m_eq_zero) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + gemm( STORAGE, TRANS, TRANS, 0, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When n is 0 +TYPED_TEST(Gemm_IIT_ERS_Test, n_eq_zero) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + gemm( STORAGE, TRANS, TRANS, M, 0, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When alpha is 0 and beta is 1 +TYPED_TEST(Gemm_IIT_ERS_Test, alpha_zero_beta_one) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + T alpha, beta; + + testinghelpers::initzero( alpha ); + testinghelpers::initone( beta ); + + gemm( STORAGE, TRANS, TRANS, M, N, K, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); + // Use bitwise comparison (no threshold). 
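// Editorial sketch (not part of the patch): the early-return condition listed in the ERS
// comment above, written as one predicate. Whenever it holds, GEMM must leave C untouched,
// which is why these tests compare C against c_ref bitwise (no threshold). Works for
// float/double and std::complex; the suite's own scomplex/dcomplex types would need their
// comparison helpers instead.
template <typename T>
inline bool gemm_returns_early( long m, long n, long k, const T& alpha, const T& beta )
{
    return ( m == 0 ) || ( n == 0 ) || ( ( alpha == T(0) || k == 0 ) && beta == T(1) );
}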
+ computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + +// When k is 0 and beta is 1 +TYPED_TEST(Gemm_IIT_ERS_Test, k_zero_beta_one) +{ + using T = TypeParam; + // Defining the C matrix with values for debugging purposes + std::vector c = testinghelpers::get_random_matrix(-10, 10, STORAGE, 'N', N, N, LDC, 'f'); + + // Copy so that we check that the elements of C are not modified. + std::vector c_ref(c); + T alpha, beta; + + testinghelpers::initone( alpha ); + testinghelpers::initone( beta ); + + gemm( STORAGE, TRANS, TRANS, M, N, 0, &alpha, nullptr, LDA, nullptr, LDB, &beta, nullptr, LDC ); + // Use bitwise comparison (no threshold). + computediff( STORAGE, N, N, c.data(), c_ref.data(), LDC); +} + + +#endif diff --git a/gtestsuite/testsuite/level3/gemm/cgemm_generic.cpp b/gtestsuite/testsuite/level3/gemm/cgemm_generic.cpp index fa6b10006a..5043dc44a7 100644 --- a/gtestsuite/testsuite/level3/gemm/cgemm_generic.cpp +++ b/gtestsuite/testsuite/level3/gemm/cgemm_generic.cpp @@ -46,10 +46,10 @@ class CGemmTest : scomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(CGemmTest, RandomData) { +TEST_P(CGemmTest, RandomData) +{ using T = scomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -77,8 +77,6 @@ TEST_P(CGemmTest, RandomData) { gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); gtint_t ldc_inc = std::get<10>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<11>(GetParam()); // Set the threshold for the errors: double thresh = 10*m*n*testinghelpers::getEpsilon(); @@ -86,13 +84,13 @@ TEST_P(CGemmTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_gemm(storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class CGemmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char tsa = std::get<1>(str.param); char tsb = std::get<2>(str.param); @@ -104,7 +102,6 @@ class CGemmTestPrint { gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); gtint_t ldc_inc = std::get<10>(str.param); - char datatype = std::get<11>(str.param); #ifdef TEST_BLAS std::string str_name = "cgemm_"; #elif TEST_CBLAS @@ -126,7 +123,6 @@ class CGemmTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -150,8 +146,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(scomplex{1.0,2.0}), // beta ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c ), ::CGemmTestPrint() ); diff --git a/gtestsuite/testsuite/level3/gemm/dgemm_generic.cpp b/gtestsuite/testsuite/level3/gemm/dgemm_generic.cpp index 5a7bcbd910..8d07668cc4 100644 --- 
a/gtestsuite/testsuite/level3/gemm/dgemm_generic.cpp +++ b/gtestsuite/testsuite/level3/gemm/dgemm_generic.cpp @@ -46,10 +46,10 @@ class DGemmTest : double, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(DGemmTest, RandomData) { +TEST_P(DGemmTest, RandomData) +{ using T = double; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -77,22 +77,20 @@ TEST_P(DGemmTest, RandomData) { gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); gtint_t ldc_inc = std::get<10>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<11>(GetParam()); // Set the threshold for the errors: - double thresh = 10*m*n*k*testinghelpers::getEpsilon(); + double thresh = 10*m*n*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_gemm(storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class DGemmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char tsa = std::get<1>(str.param); char tsb = std::get<2>(str.param); @@ -104,13 +102,12 @@ class DGemmTestPrint { gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); gtint_t ldc_inc = std::get<10>(str.param); - char datatype = std::get<11>(str.param); #ifdef TEST_BLAS std::string str_name = "dgemm_"; #elif TEST_CBLAS std::string str_name = "cblas_dgemm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_dgemm"; + std::string str_name = "bli_dgemm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + tsa + tsb; @@ -124,7 +121,6 @@ class DGemmTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -148,8 +144,167 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(-1.0, 1.0), // beta ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c + ), + ::DGemmTestPrint() + ); + + +// Tests 5 loops +INSTANTIATE_TEST_SUITE_P( + tiny_dgemm_kernel, + DGemmTest, + ::testing::Combine( + // No condition based on storage scheme of matrices + ::testing::Values('c'), // storage format + // No conditions based on trans of matrices + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + + ::testing::Values(13, 25, 48, 60, 256, 512, 1000), // m + + ::testing::Values(8, 48, 72, 144, 237), // n + + ::testing::Values(16, 24, 48, 64, 128, 557), // k + // No condition based on alpha + ::testing::Values( -1.0), // alpha + // No condition based on betaa + ::testing::Values(-1.0), // beta + ::testing::Values(0,3), // increment to the leading dim of a + ::testing::Values(0,3), // increment to the leading dim of b + ::testing::Values(0,3) // increment to the leading 
dim of c + ), + ::DGemmTestPrint() + ); + +//zero beta test case +INSTANTIATE_TEST_SUITE_P( + zero_beta, + DGemmTest, + ::testing::Combine( + // No condition based on storage scheme of matrices + ::testing::Values('c'), // storage format + // No conditions based on trans of matrices + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + + ::testing::Values(13, 25, 48, 60, 256, 512, 1000), // m + + ::testing::Values(8, 48, 72, 144, 237), // n + + ::testing::Values(16, 24, 48, 64, 128, 557), // k + + ::testing::Values( -1.0), // alpha + ::testing::Values(0.0), // beta + ::testing::Values(0,3), // increment to the leading dim of a + ::testing::Values(0,3), // increment to the leading dim of b + ::testing::Values(0,3) // increment to the leading dim of c + ), + ::DGemmTestPrint() + ); + +//zero alpha test case +INSTANTIATE_TEST_SUITE_P( + zero_alpha, + DGemmTest, + ::testing::Combine( + // No condition based on storage scheme of matrices + ::testing::Values('c'), // storage format + // No conditions based on trans of matrices + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + + ::testing::Values(13, 25, 48, 60, 256, 512, 1000), // m + + ::testing::Values(8, 48, 72, 144, 237), // n + + ::testing::Values(16, 24, 48, 64, 128, 557), // k + + ::testing::Values( 0.0), // alpha + ::testing::Values(-1.0), // beta + ::testing::Values(0,3), // increment to the leading dim of a + ::testing::Values(0,3), // increment to the leading dim of b + ::testing::Values(0,3) // increment to the leading dim of c + ), + ::DGemmTestPrint() + ); + +//unit beta test case +INSTANTIATE_TEST_SUITE_P( + unit_beta, + DGemmTest, + ::testing::Combine( + // No condition based on storage scheme of matrices + ::testing::Values('c'), // storage format + // No conditions based on trans of matrices + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + + ::testing::Values(13, 25, 48, 60, 256, 512, 1000), // m + + ::testing::Values(8, 48, 72, 144, 237), // n + + ::testing::Values(16, 24, 48, 64, 128, 557), // k + + ::testing::Values( -1.0), // alpha + ::testing::Values(1.0), // beta + ::testing::Values(0,3), // increment to the leading dim of a + ::testing::Values(0,3), // increment to the leading dim of b + ::testing::Values(0,3) // increment to the leading dim of c + ), + ::DGemmTestPrint() + ); + +// Covers all corner cases of tiny dgemm kernel +INSTANTIATE_TEST_SUITE_P( + tiny_edge_kernels, + DGemmTest, + ::testing::Combine( + // To test col storage of C + // Storage of A and B is handled by packing + ::testing::Values('c'), // storage format + // Tests scalar code of 8xk and 6xk pack kernels for both storage formats + ::testing::Values('n','t'), // transa + ::testing::Values('n','t'), // transb + + ::testing::Range(gtint_t(1), gtint_t(23), 1), // m + ::testing::Range(gtint_t(1), gtint_t(7), 1), // n + + ::testing::Values(24), // k + // No condition based on alpha + ::testing::Values( -1.0, 1.0), // alpha + // checks for beta-zero and beta non-zero cases + ::testing::Values(0.0, 1.0, -1.0), // beta + ::testing::Values(23), // increment to the leading dim of a + ::testing::Values(23), // increment to the leading dim of b + ::testing::Values(23) // increment to the leading dim of c + ), + ::DGemmTestPrint() + ); + + +//m = 0, n = 0 k = 0 testcase +INSTANTIATE_TEST_SUITE_P( + mnkzero, + DGemmTest, + ::testing::Combine( + // No condition based on storage scheme of matrices + ::testing::Values('c'), // storage format + // No conditions based on 
trans of matrices + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + + ::testing::Values(0, 8, 24), // m + + ::testing::Values(0, 6, 8), // n + + ::testing::Values(3), // k + + ::testing::Values( -1.0), // alpha + ::testing::Values(1.0), // beta + ::testing::Values(0,3), // increment to the leading dim of a + ::testing::Values(0,3), // increment to the leading dim of b + ::testing::Values(0,3) // increment to the leading dim of c ), ::DGemmTestPrint() ); diff --git a/gtestsuite/testsuite/level3/gemm/sgemm_generic.cpp b/gtestsuite/testsuite/level3/gemm/sgemm_generic.cpp index f1f7bec8cf..2adbe2968a 100644 --- a/gtestsuite/testsuite/level3/gemm/sgemm_generic.cpp +++ b/gtestsuite/testsuite/level3/gemm/sgemm_generic.cpp @@ -46,10 +46,10 @@ class SGemmTest : float, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(SGemmTest, RandomData) { +TEST_P(SGemmTest, RandomData) +{ using T = float; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -77,8 +77,6 @@ TEST_P(SGemmTest, RandomData) { gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); gtint_t ldc_inc = std::get<10>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<11>(GetParam()); // Set the threshold for the errors: double thresh = 10*m*n*testinghelpers::getEpsilon(); @@ -86,13 +84,13 @@ TEST_P(SGemmTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_gemm(storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class SGemmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char tsa = std::get<1>(str.param); char tsb = std::get<2>(str.param); @@ -104,13 +102,12 @@ class SGemmTestPrint { gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); gtint_t ldc_inc = std::get<10>(str.param); - char datatype = std::get<11>(str.param); #ifdef TEST_BLAS std::string str_name = "sgemm_"; #elif TEST_CBLAS std::string str_name = "cblas_sgemm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_sgemm"; + std::string str_name = "bli_sgemm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + tsa + tsb; @@ -124,14 +121,13 @@ class SGemmTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; // Black box testing. INSTANTIATE_TEST_SUITE_P( - Blackbox, + sgemm_sup_10_30, SGemmTest, ::testing::Combine( ::testing::Values('c' @@ -148,8 +144,56 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(-1.0, 1.0), // beta ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(7)) // increment to the leading dim of c + ), + ::SGemmTestPrint() + ); + +// Black box testing. 
+INSTANTIATE_TEST_SUITE_P( + sgemm_sup_alpha_beta, + SGemmTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n','t'), // transa + ::testing::Values('n','t'), // transb + ::testing::Range(gtint_t(1), gtint_t(20), 1), // m + ::testing::Range(gtint_t(1), gtint_t(50), 1), // n + ::testing::Range(gtint_t(1), gtint_t(10), 1), // k + ::testing::Values(0.0, 1.0, -1.0, 5.3, -10.0), // alpha + ::testing::Values(0.0, 1.0, -1.0, 6.4, -19.0), // beta + ::testing::Values(gtint_t(2)), // increment to the leading dim of a + ::testing::Values(gtint_t(3)), // increment to the leading dim of b + ::testing::Values(gtint_t(7)) // increment to the leading dim of c + ), + ::SGemmTestPrint() + ); + + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + sgemm_sup_m_n_k_100, + SGemmTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Range(gtint_t(1), gtint_t(20), 1), // m + ::testing::Range(gtint_t(1), gtint_t(50), 1), // n + ::testing::Range(gtint_t(1), gtint_t(20), 1), // k + ::testing::Values( -2.0), // alpha + ::testing::Values( 5.0), // beta + ::testing::Values(gtint_t(2)), // increment to the leading dim of a + ::testing::Values(gtint_t(3)), // increment to the leading dim of b + ::testing::Values(gtint_t(7)) // increment to the leading dim of c ), ::SGemmTestPrint() ); diff --git a/gtestsuite/testsuite/level3/gemm/test_gemm.h b/gtestsuite/testsuite/level3/gemm/test_gemm.h index 3396ba2ce6..147bcdab50 100644 --- a/gtestsuite/testsuite/level3/gemm/test_gemm.h +++ b/gtestsuite/testsuite/level3/gemm/test_gemm.h @@ -40,23 +40,81 @@ #include #include - template void test_gemm( char storage, char trnsa, char trnsb, gtint_t m, gtint_t n, - gtint_t k, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, - T alpha, T beta, double thresh, char datatype ) { + gtint_t k, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, T alpha, + T beta, double thresh ) +{ + // Compute the leading dimensions of a, b, and c. + gtint_t lda = testinghelpers::get_leading_dimension( storage, trnsa, m, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, trnsb, k, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); + + //---------------------------------------------------------- + // Initialize matrics with random numbers + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, trnsa, m, k, lda ); + std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, trnsb, k, n, ldb ); + std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, 'n', m, n, ldc ); + + // Create a copy of c so that we can check reference results. + std::vector c_ref(c); + + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemm( storage, trnsa, trnsb, m, n, k, &alpha, a.data(), lda, + b.data(), ldb, &beta, c.data(), ldc ); + + //---------------------------------------------------------- + // Call reference implementation. 
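// Editorial sketch (not part of the patch): the operation the reference path verifies,
// C := alpha*op(A)*op(B) + beta*C, written as a naive column-major kernel for the
// transa = transb = 'n' case only. testinghelpers::ref_gemm covers all storage and
// transpose combinations; this is just to make the checked computation concrete.
#include <vector>
#include <cstddef>

template <typename T>
void naive_gemm_colmajor_nn( std::size_t m, std::size_t n, std::size_t k, T alpha,
                             const std::vector<T>& a, std::size_t lda,
                             const std::vector<T>& b, std::size_t ldb, T beta,
                             std::vector<T>& c, std::size_t ldc )
{
    for ( std::size_t j = 0; j < n; ++j )
        for ( std::size_t i = 0; i < m; ++i )
        {
            T acc{};                                    // dot product of row i of A with column j of B
            for ( std::size_t p = 0; p < k; ++p )
                acc += a[i + p*lda] * b[p + j*ldb];
            c[i + j*ldc] = alpha*acc + beta*c[i + j*ldc];
        }
}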
+ //---------------------------------------------------------- + testinghelpers::ref_gemm( storage, trnsa, trnsb, m, n, k, alpha, + a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( storage, m, n, c.data(), c_ref.data(), ldc, thresh ); +} +// Test body used for exception value testing, by inducing an exception value +// in the index that is passed for each of the matrices. +/* + (ai, aj) is the index with corresponding exception value aexval in matrix A. + The index is with respect to the assumption that the matrix is column stored, + without any transpose. In case of the row-storage and/or transpose, the index + is translated from its assumption accordingly. + Ex : (2, 3) with storage 'c' and transpose 'n' becomes (3, 2) if storage becomes + 'r' or transpose becomes 't'. +*/ +// (bi, bj) is the index with corresponding exception value bexval in matrix B. +// (ci, cj) is the index with corresponding exception value cexval in matrix C. +template +void test_gemm( char storage, char trnsa, char trnsb, gtint_t m, gtint_t n, + gtint_t k, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, T alpha, + T beta, gtint_t ai, gtint_t aj, T aexval, gtint_t bi, gtint_t bj, T bexval, + gtint_t ci, gtint_t cj, T cexval, double thresh ) +{ // Compute the leading dimensions of a, b, and c. - gtint_t lda = testinghelpers::get_leading_dimension(storage, trnsa, m, k, lda_inc); - gtint_t ldb = testinghelpers::get_leading_dimension(storage, trnsb, k, n, ldb_inc); - gtint_t ldc = testinghelpers::get_leading_dimension(storage, 'n', m, n, ldc_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, trnsa, m, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, trnsb, k, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); //---------------------------------------------------------- // Initialize matrics with random numbers //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(-2, 8, storage, trnsa, m, k, lda, datatype); - std::vector b = testinghelpers::get_random_matrix(-5, 2, storage, trnsb, k, n, ldb, datatype); - std::vector c = testinghelpers::get_random_matrix(-3, 5, storage, 'n', m, n, ldc, datatype); + std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, trnsa, m, k, lda ); + std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, trnsb, k, n, ldb ); + std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, 'n', m, n, ldc ); + + // Inducing exception values onto the matrices based on the indices passed as arguments. + // Assumption is that the indices are with respect to the matrices in column storage without + // any transpose. In case of difference in storage scheme or transposition, the row and column + // indices are appropriately swapped. + testinghelpers::set_ev_mat( storage, trnsa, lda, ai, aj, aexval, a.data() ); + testinghelpers::set_ev_mat( storage, trnsb, ldb, bi, bj, bexval, b.data() ); + testinghelpers::set_ev_mat( storage, 'n', ldc, ci, cj, cexval, c.data() ); // Create a copy of c so that we can check reference results. std::vector c_ref(c); @@ -76,5 +134,5 @@ void test_gemm( char storage, char trnsa, char trnsb, gtint_t m, gtint_t n, //---------------------------------------------------------- // check component-wise error. 
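// Editorial sketch (not part of the patch): the index convention the comment above
// describes for the exception-value overload. (i, j) is given assuming column storage
// with no transpose; if exactly one of "row storage" or "transpose" applies, row and
// column swap (if both apply, the two swaps cancel). This mirrors the documented
// behaviour of testinghelpers::set_ev_mat, not its actual source.
#include <utility>

inline std::pair<long, long> translate_ev_index( char storage, char trans, long i, long j )
{
    const bool row_stored = ( storage == 'r' || storage == 'R' );
    const bool transposed = ( trans != 'n' && trans != 'N' );
    if ( row_stored != transposed )   // exactly one of the two applies
        std::swap( i, j );
    return { i, j };                  // e.g. (2,3) -> (3,2) for storage 'r' with trans 'n'
}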
//---------------------------------------------------------- - computediff( storage, m, n, c.data(), c_ref.data(), ldc, thresh ); + computediff( storage, m, n, c.data(), c_ref.data(), ldc, thresh, true ); } diff --git a/gtestsuite/testsuite/level3/gemm/zgemm_evt_testing.cpp b/gtestsuite/testsuite/level3/gemm/zgemm_evt_testing.cpp new file mode 100644 index 0000000000..3b0f05ab9b --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm/zgemm_evt_testing.cpp @@ -0,0 +1,356 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +/* + The following file contains both the exception value testing(EVT) and the + positive accuracy testing of the bli_zgemm_4x4_avx2_k1_nn( ... ) computational + kernel. This kernel is invoked from the BLAS layer, and inputs are given + in a manner so as to avoid the other code-paths and test only the required + kernel. + +*/ + +#include +#include "test_gemm.h" + +class ZGemmEVTTest : + public ::testing::TestWithParam> {}; + +TEST_P(ZGemmEVTTest, Unit_Tester) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). 
+ //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether matrix b is n,c,t,h + char transb = std::get<2>(GetParam()); + // matrix size m + gtint_t m = std::get<3>(GetParam()); + // matrix size n + gtint_t n = std::get<4>(GetParam()); + // matrix size k + gtint_t k = std::get<5>(GetParam()); + + gtint_t ai, aj, bi, bj, ci, cj; + T aex, bex, cex; + ai = std::get<6>(GetParam()); + aj = std::get<7>(GetParam()); + aex = std::get<8>(GetParam()); + + bi = std::get<9>(GetParam()); + bj = std::get<10>(GetParam()); + bex = std::get<11>(GetParam()); + + ci = std::get<12>(GetParam()); + cj = std::get<13>(GetParam()); + cex = std::get<14>(GetParam()); + + // specifies alpha value + T alpha = std::get<15>(GetParam()); + // specifies beta value + T beta = std::get<16>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<17>(GetParam()); + gtint_t ldb_inc = std::get<18>(GetParam()); + gtint_t ldc_inc = std::get<19>(GetParam()); + + // Set the threshold for the errors: + double thresh = 10*m*n*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, + alpha, beta, ai, aj, aex, bi, bj, bex, ci, cj, cex, thresh ); +} + +// Helper classes for printing the test case parameters based on the instantiator +// These are mainly used to help with debugging, in case of failures + +// Utility to print the test-case in case of exception value on matrices +class ZGemmEVMatPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char sfm = std::get<0>(str.param); + char tsa = std::get<1>(str.param); + char tsb = std::get<2>(str.param); + gtint_t m = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + gtint_t k = std::get<5>(str.param); + gtint_t ai, aj, bi, bj, ci, cj; + dcomplex aex, bex, cex; + ai = std::get<6>(str.param); + aj = std::get<7>(str.param); + aex = std::get<8>(str.param); + + bi = std::get<9>(str.param); + bj = std::get<10>(str.param); + bex = std::get<11>(str.param); + + ci = std::get<12>(str.param); + cj = std::get<13>(str.param); + cex = std::get<14>(str.param); + + dcomplex alpha = std::get<15>(str.param); + dcomplex beta = std::get<16>(str.param); + gtint_t lda_inc = std::get<17>(str.param); + gtint_t ldb_inc = std::get<18>(str.param); + gtint_t ldc_inc = std::get<19>(str.param); + +#ifdef TEST_BLAS + std::string str_name = "zgemm_"; +#elif TEST_CBLAS + std::string str_name = "cblas_zgemm"; +#else //#elif TEST_BLIS_TYPED + std::string str_name = "blis_zgemm"; +#endif + str_name = str_name + "_" + sfm+sfm+sfm; + str_name = str_name + "_" + tsa + tsb; + str_name = str_name + "_" + std::to_string(m); + str_name = str_name + "_" + std::to_string(n); + str_name = str_name + "_" + std::to_string(k); + str_name = str_name + "_A" + std::to_string(ai) + std::to_string(aj); + str_name = str_name + "_" + testinghelpers::get_value_string(aex); + str_name = str_name + "_B" + std::to_string(bi) + std::to_string(bj); + str_name = str_name + "_" + testinghelpers::get_value_string(bex); + 
str_name = str_name + "_C" + std::to_string(ci) + std::to_string(cj); + str_name = str_name + "_" + testinghelpers::get_value_string(cex); + str_name = str_name + "_a" + testinghelpers::get_value_string(alpha); + str_name = str_name + "_b" + testinghelpers::get_value_string(beta); + str_name = str_name + "_" + std::to_string(lda_inc); + str_name = str_name + "_" + std::to_string(ldb_inc); + str_name = str_name + "_" + std::to_string(ldc_inc); + return str_name; + } +}; + +// Utility to print the test-case in case of exception value on matrices +class ZGemmEVAlphaBetaPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char sfm = std::get<0>(str.param); + char tsa = std::get<1>(str.param); + char tsb = std::get<2>(str.param); + gtint_t m = std::get<3>(str.param); + gtint_t n = std::get<4>(str.param); + gtint_t k = std::get<5>(str.param); + + dcomplex alpha = std::get<15>(str.param); + dcomplex beta = std::get<16>(str.param); + gtint_t lda_inc = std::get<17>(str.param); + gtint_t ldb_inc = std::get<18>(str.param); + gtint_t ldc_inc = std::get<19>(str.param); + +#ifdef TEST_BLAS + std::string str_name = "zgemm_"; +#elif TEST_CBLAS + std::string str_name = "cblas_zgemm"; +#else //#elif TEST_BLIS_TYPED + std::string str_name = "blis_zgemm"; +#endif + str_name = str_name + "_" + sfm+sfm+sfm; + str_name = str_name + "_" + tsa + tsb; + str_name = str_name + "_" + std::to_string(m); + str_name = str_name + "_" + std::to_string(n); + str_name = str_name + "_" + std::to_string(k); + str_name = str_name + "_a" + testinghelpers::get_value_string(alpha); + str_name = str_name + "_b" + testinghelpers::get_value_string(beta); + str_name = str_name + "_" + std::to_string(lda_inc); + str_name = str_name + "_" + std::to_string(ldb_inc); + str_name = str_name + "_" + std::to_string(ldc_inc); + return str_name; + } +}; + +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); + +// Exception value testing(on matrices) + +/* + For the bli_zgemm_4x4_avx2_k1_nn kernel, the main and fringe dimensions are as follows: + For m : Main = { 4 }, fringe = { 2, 1 } + For n : Main = { 4 }, fringe = { 2, 1 } + + Without any changes to the BLAS layer in BLIS, the fringe case of 1 cannot be touched + separately, since if m/n is 1, the inputs are redirected to ZGEMV. + +*/ + +// Testing for the main loop case for m and n +// The kernel uses 2 loads and 4 broadcasts. The exception values +// are induced at one index individually for each of the loads. +// They are also induced in the broadcast direction at two places. 
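// Editorial sketch (not part of the patch): why the EVT cases below seed NaN/Inf into
// individual real or imaginary parts. A kernel that multiplies complex values
// componentwise (ac - bd, ad + bc), as SIMD kernels typically do, lets a single special
// value contaminate both parts, since Inf * 0.0 is NaN under IEEE 754.
#include <cmath>
#include <cstdio>
#include <limits>

struct cplx { double real, imag; };

inline cplx naive_cmul( cplx x, cplx y )
{
    return { x.real*y.real - x.imag*y.imag,     // real part
             x.real*y.imag + x.imag*y.real };   // imaginary part
}

int main()
{
    const double Inf = std::numeric_limits<double>::infinity();
    cplx r = naive_cmul( {Inf, 0.0}, {0.0, 1.0} );
    // r.real = Inf*0 - 0*1 = NaN,  r.imag = Inf*1 + 0*0 = Inf
    std::printf( "real is nan: %d, imag is inf: %d\n",
                 (int)std::isnan(r.real), (int)std::isinf(r.imag) );
    return 0;
}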
+INSTANTIATE_TEST_SUITE_P( + bli_zgemm_4x4_avx2_k1_nn_evt_mat_main, + ZGemmEVTTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(4)), // m + ::testing::Values(gtint_t(4)), // n + ::testing::Values(gtint_t(1)), // k + ::testing::Values(gtint_t(1), gtint_t(3)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, + dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0), gtint_t(2)), // bj + ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, + dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // bexval + ::testing::Values(gtint_t(0), gtint_t(2)), // ci + ::testing::Values(gtint_t(1), gtint_t(3)), // cj + ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, + dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // cexval + ::testing::Values(dcomplex{-2.2, 3.3}), // alpha + ::testing::Values(dcomplex{1.2, -2.3}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::ZGemmEVMatPrint() + ); + +// Testing the fringe cases +// Fringe case minimum size is 2 along both m and n. +// Invloves only one load(AVX2 or (AVX2+SSE)). Thus, +// the exception values are induced at the first and second indices of the +// column vector A and row vector B. +INSTANTIATE_TEST_SUITE_P( + bli_zgemm_4x4_avx2_k1_nn_evt_mat_fringe, + ZGemmEVTTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(2), gtint_t(3)), // m + ::testing::Values(gtint_t(2), gtint_t(3)), // n + ::testing::Values(gtint_t(1)), // k + ::testing::Values(gtint_t(0), gtint_t(1)), // ai + ::testing::Values(gtint_t(0)), // aj + ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, + dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // aexval + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0), gtint_t(1)), // bj + ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, + dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // bexval + ::testing::Values(gtint_t(0), gtint_t(1)), // ci + ::testing::Values(gtint_t(0), gtint_t(1)), // cj + ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, + dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // cexval + ::testing::Values(dcomplex{-2.2, 3.3}), // alpha + ::testing::Values(dcomplex{1.2, -2.3}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::ZGemmEVMatPrint() + ); + +// Exception value testing(on alpha and beta) +// Alpha and beta are set to exception values +INSTANTIATE_TEST_SUITE_P( + bli_zgemm_4x4_avx2_k1_nn_evt_alphabeta, + ZGemmEVTTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values(gtint_t(2), gtint_t(3), gtint_t(4)), // m + ::testing::Values(gtint_t(2), gtint_t(3), gtint_t(4)), // n + ::testing::Values(gtint_t(1)), // k + ::testing::Values(gtint_t(0)), // ai + ::testing::Values(gtint_t(0)), // aj + 
::testing::Values(dcomplex{0.0, 0.0}), + ::testing::Values(gtint_t(0)), // bi + ::testing::Values(gtint_t(0)), // bj + ::testing::Values(dcomplex{0.0, 0.0}), + ::testing::Values(gtint_t(0)), // ci + ::testing::Values(gtint_t(0)), // cj + ::testing::Values(dcomplex{0.0, 0.0}), + ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, + dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // alpha + ::testing::Values(dcomplex{NaN, 2.3}, dcomplex{Inf, 0.0}, + dcomplex{3.4, NaN}, dcomplex{NaN, -Inf}), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::ZGemmEVAlphaBetaPrint() + ); diff --git a/gtestsuite/testsuite/level3/gemm/zgemm_generic.cpp b/gtestsuite/testsuite/level3/gemm/zgemm_generic.cpp index 0f4bb4783d..6bdb2d63e8 100644 --- a/gtestsuite/testsuite/level3/gemm/zgemm_generic.cpp +++ b/gtestsuite/testsuite/level3/gemm/zgemm_generic.cpp @@ -35,7 +35,7 @@ #include #include "test_gemm.h" -class ZGemmTest : +class ZGemmAccTest : public ::testing::TestWithParam> {}; + gtint_t>> {}; -TEST_P(ZGemmTest, RandomData) { +TEST_P(ZGemmAccTest, Unit_Tester) +{ using T = dcomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -77,8 +77,6 @@ TEST_P(ZGemmTest, RandomData) { gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); gtint_t ldc_inc = std::get<10>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<11>(GetParam()); // Set the threshold for the errors: double thresh = 10*m*n*testinghelpers::getEpsilon(); @@ -86,13 +84,13 @@ TEST_P(ZGemmTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_gemm(storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_gemm( storage, transa, transb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } -class ZGemmTestPrint { +class ZGemmAccPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char tsa = std::get<1>(str.param); char tsb = std::get<2>(str.param); @@ -104,37 +102,62 @@ class ZGemmTestPrint { gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); gtint_t ldc_inc = std::get<10>(str.param); - char datatype = std::get<11>(str.param); #ifdef TEST_BLAS std::string str_name = "zgemm_"; #elif TEST_CBLAS std::string str_name = "cblas_zgemm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_zgemm"; + std::string str_name = "bli_zgemm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + tsa + tsb; str_name = str_name + "_" + std::to_string(m); str_name = str_name + "_" + std::to_string(n); str_name = str_name + "_" + std::to_string(k); - std::string alpha_str = ( alpha.real > 0) ? std::to_string(int(alpha.real)) : ("m" + std::to_string(int(std::abs(alpha.real)))); - alpha_str = alpha_str + "pi" + (( alpha.imag > 0) ? std::to_string(int(alpha.imag)) : ("m" + std::to_string(int(std::abs(alpha.imag))))); - std::string beta_str = ( beta.real > 0) ? std::to_string(int(beta.real)) : ("m" + std::to_string(int(std::abs(beta.real)))); - beta_str = beta_str + "pi" + (( beta.imag > 0) ? 
std::to_string(int(beta.imag)) : ("m" + std::to_string(int(std::abs(beta.imag))))); - str_name = str_name + "_a" + alpha_str; - str_name = str_name + "_b" + beta_str; + str_name = str_name + "_a" + testinghelpers::get_value_string(alpha);; + str_name = str_name + "_b" + testinghelpers::get_value_string(beta);; str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; +// Unit testing for bli_zgemm_4x4_avx2_k1_nn kernel +/* From the BLAS layer(post parameter checking), the inputs will be redirected to this kernel + if m != 1, n !=1 and k == 1 */ + +INSTANTIATE_TEST_SUITE_P( + bli_zgemm_4x4_avx2_k1_nn, + ZGemmAccTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Range(gtint_t(2), gtint_t(8), 1), // m + ::testing::Range(gtint_t(2), gtint_t(8), 1), // n + ::testing::Values(gtint_t(1)), // k + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{2.1, -1.9}, + dcomplex{0.0, 0.0}), // alpha + ::testing::Values(dcomplex{1.0, 0.0}, dcomplex{-1.0, 0.0}, + dcomplex{0.0, 1.0}, dcomplex{2.1, -1.9}, + dcomplex{0.0, 0.0}), // beta + ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c + ), + ::ZGemmAccPrint() + ); + // Black box testing. INSTANTIATE_TEST_SUITE_P( Blackbox, - ZGemmTest, + ZGemmAccTest, ::testing::Combine( ::testing::Values('c' #ifndef TEST_BLAS @@ -150,8 +173,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{1.0,2.0}), // beta ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c ), - ::ZGemmTestPrint() + ::ZGemmAccPrint() ); diff --git a/gtestsuite/testsuite/level3/gemm_compute/dgemm_compute_generic.cpp b/gtestsuite/testsuite/level3/gemm_compute/dgemm_compute_generic.cpp new file mode 100644 index 0000000000..a648f53bc1 --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm_compute/dgemm_compute_generic.cpp @@ -0,0 +1,212 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_gemm_compute.h" + +class DGemmComputeTest : + public ::testing::TestWithParam> {}; + +TEST_P(DGemmComputeTest, RandomData) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t + char transa = std::get<1>(GetParam()); + // denotes whether matrix b is n,c,t + char transb = std::get<2>(GetParam()); + // denotes whether matrix a is packed (p) or unpacked (u) + char packa = std::get<3>(GetParam()); + // denotes whether matrix b is packed (p) or unpacked (u) + char packb = std::get<4>(GetParam()); + // matrix size m + gtint_t m = std::get<5>(GetParam()); + // matrix size n + gtint_t n = std::get<6>(GetParam()); + // matrix size k + gtint_t k = std::get<7>(GetParam()); + // specifies alpha value + T alpha = std::get<8>(GetParam()); + // specifies beta value + T beta = std::get<9>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. 
+ gtint_t lda_inc = std::get<10>(GetParam()); + gtint_t ldb_inc = std::get<11>(GetParam()); + gtint_t ldc_inc = std::get<12>(GetParam()); + + // Set the threshold for the errors: + double intermediate = (double)m*n*k; + double thresh = 10*intermediate*testinghelpers::getEpsilon(); + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemm_compute( storage, transa, transb, packa, packb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); +} + +class DGemmComputeTestPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char sfm = std::get<0>(str.param); + char tsa = std::get<1>(str.param); + char tsb = std::get<2>(str.param); + char pka = std::get<3>(str.param); + char pkb = std::get<4>(str.param); + gtint_t m = std::get<5>(str.param); + gtint_t n = std::get<6>(str.param); + gtint_t k = std::get<7>(str.param); + double alpha = std::get<8>(str.param); + double beta = std::get<9>(str.param); + gtint_t lda_inc = std::get<10>(str.param); + gtint_t ldb_inc = std::get<11>(str.param); + gtint_t ldc_inc = std::get<12>(str.param); +#ifdef TEST_BLAS + std::string str_name = "dgemm_compute_"; +#elif TEST_CBLAS + std::string str_name = "cblas_dgemm_compute"; +#else //#elif TEST_BLIS_TYPED + // BLIS interface not yet implemented for pack and compute APIs. + std::string str_name = "blis_dgemm_compute"; +#endif + str_name = str_name + "_" + sfm+sfm+sfm; + str_name = str_name + "_" + tsa + tsb; + str_name = str_name + "_" + pka + pkb; + str_name = str_name + "_" + std::to_string(m); + str_name = str_name + "_" + std::to_string(n); + str_name = str_name + "_" + std::to_string(k); + std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); + str_name = str_name + "_a" + alpha_str; + std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); + str_name = str_name + "_b" + beta_str; + str_name = str_name + "_" + std::to_string(lda_inc); + str_name = str_name + "_" + std::to_string(ldb_inc); + str_name = str_name + "_" + std::to_string(ldc_inc); + return str_name; + } +}; + +// Black box testing. 
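+// The instantiation below sweeps storage, transpose and pack options together
+// with m, n, k in {10, 20, 30} and a small set of alpha/beta values; every
+// combination runs the RandomData body above and is checked against the
+// reference implementation through test_gemm_compute().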
+INSTANTIATE_TEST_SUITE_P( + Blackbox, + DGemmComputeTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values('u', 'p'), // packa + ::testing::Values('u', 'p'), // packb + ::testing::Range(gtint_t(10), gtint_t(31), 10), // m + ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Range(gtint_t(10), gtint_t(31), 10), // k + ::testing::Values(0.0, 1.0, -1.2, 2.1), // alpha + ::testing::Values(0.0, 1.0, -1.2, 2.1), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::DGemmComputeTestPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + TinySizes, + DGemmComputeTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values('u', 'p'), // packa + ::testing::Values('u', 'p'), // packb + ::testing::Range(gtint_t(1), gtint_t(3), 1), // m + ::testing::Range(gtint_t(1), gtint_t(3), 1), // n + ::testing::Range(gtint_t(1), gtint_t(3), 1), // k + ::testing::Values(0.0, 1.0, -1.2, 2.1), // alpha + ::testing::Values(0.0, 1.0, -1.2, 2.1), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::DGemmComputeTestPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + DimensionsGtBlocksizes, // Dimensions > SUP Blocksizes + DGemmComputeTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values('u', 'p'), // packa + ::testing::Values('u', 'p'), // packb + ::testing::Values(71, 73), // m (MC - 1, MC + 1) + ::testing::Values(4079, 4081), // n (NC - 1, NC + 1) + ::testing::Values(255, 257), // k (KC - 1, KC + 1) + ::testing::Values(1.0), // alpha + ::testing::Values(1.0), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::DGemmComputeTestPrint() + ); diff --git a/gtestsuite/testsuite/level3/gemm_compute/gemm_compute.h b/gtestsuite/testsuite/level3/gemm_compute/gemm_compute.h new file mode 100644 index 0000000000..1d168df634 --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm_compute/gemm_compute.h @@ -0,0 +1,456 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "blis.h" +#include "common/testing_helpers.h" + +/** + * @brief Performs the operation: + * C := op( A )*op( B ) + beta*C, + * where op( A ) is one of + * op( A ) = alpha * A or op( A ) = alpha * A**T + * op( A ) = A or op( A ) = A**T + * op( B ) is one of + * op( B ) = alpha * B or op( B ) = alpha * B**T + * op( B ) = B or op( B ) = B**T + * @param[in] transa specifies the form of op( A ) to be used in + the matrix multiplication. + * @param[in] transb specifies the form of op( B ) to be used in + the matrix multiplication. + * @param[in] packa specifies whether to reorder op( A ). + * @param[in] packb specifies whether to reorder op( B ). + * @param[in] m specifies the number of rows of the matrix + op( A ) and of the matrix C. + * @param[in] n specifies the number of columns of the matrix + op( B ) and the number of columns of the matrix C. + * @param[in] k specifies the number of columns of the matrix + op( A ) and the number of rows of the matrix op( B ). + * @param[in] ap specifies pointer which points to the first element of ap. + * @param[in] lda specifies the leading dimension of ap. + * @param[in] bp specifies pointer which points to the first element of bp. + * @param[in] ldb specifies the leading dimension of bp. + * @param[in] beta specifies the scalar beta. + * @param[in,out] cp specifies pointer which points to the first element of cp. + * @param[in] ldc specifies the leading dimension of cp. 
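+ *
+ * A sketch of the call sequence these helpers use when only A is reordered
+ * (double precision shown; identifiers mirror the code below and the buffer
+ * size is whatever dgemm_pack_get_size_ reports -- illustrative, not a
+ * prescribed usage):
+ *
+ *   gtint_t bufSizeA = dgemm_pack_get_size_( &identifierA, &m, &n, &k );
+ *   double* aBuffer  = (double*) bli_malloc_user( bufSizeA, &err );
+ *   dgemm_pack_( &identifierA, &transa, &m, &n, &k, alpha, ap, &lda, aBuffer );
+ *   dgemm_compute_( &packa, &transb, &m, &n, &k, aBuffer, &lda, bp, &ldb, beta, cp, &ldc );
+ *   bli_free_user( aBuffer );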
+ */ + +template +static void gemm_compute_(char transa, char transb, char packa, char packb, gtint_t m, gtint_t n, gtint_t k, T* alpha, + T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) +{ + T unit_alpha = 1.0; + err_t err = BLIS_SUCCESS; + if constexpr (std::is_same::value) + { + if ( ( packa == 'P' || packa == 'p' ) && ( packb == 'P' || packb == 'p' ) ) + { + // Reorder A + char identifierA = 'A'; + gtint_t bufSizeA = sgemm_pack_get_size_( &identifierA, + &m, + &n, + &k ); + + float* aBuffer = (float*) bli_malloc_user( bufSizeA, &err ); + sgemm_pack_( &identifierA, + &transa, + &m, + &n, + &k, + &unit_alpha, + ap, + &lda, + aBuffer ); + + // Reorder B + char identifierB = 'B'; + gtint_t bufSizeB = sgemm_pack_get_size_( &identifierB, + &m, + &n, + &k ); + + float* bBuffer = (float*) bli_malloc_user( bufSizeB, &err ); + sgemm_pack_( &identifierB, + &transb, + &m, + &n, + &k, + alpha, + bp, + &ldb, + bBuffer ); + + sgemm_compute_( &packa, &packb, &m, &n, &k, aBuffer, &lda, bBuffer, &ldb, beta, cp, &ldc ); + + bli_free_user( aBuffer ); + bli_free_user( bBuffer ); + } + else if ( ( packa == 'P' || packa == 'p' ) ) + { + // Reorder A + char identifierA = 'A'; + gtint_t bufSizeA = sgemm_pack_get_size_( &identifierA, + &m, + &n, + &k ); + + float* aBuffer = (float*) bli_malloc_user( bufSizeA, &err ); + sgemm_pack_( &identifierA, + &transa, + &m, + &n, + &k, + alpha, + ap, + &lda, + aBuffer ); + + sgemm_compute_( &packa, &transb, &m, &n, &k, aBuffer, &lda, bp, &ldb, beta, cp, &ldc ); + bli_free_user( aBuffer ); + } + else if ( ( packb == 'P' || packb == 'p' ) ) + { + // Reorder B + char identifierB = 'B'; + gtint_t bufSizeB = sgemm_pack_get_size_( &identifierB, + &m, + &n, + &k ); + + float* bBuffer = (float*) bli_malloc_user( bufSizeB, &err ); + sgemm_pack_( &identifierB, + &transb, + &m, + &n, + &k, + alpha, + bp, + &ldb, + bBuffer ); + + sgemm_compute_( &transa, &packb, &m, &n, &k, ap, &lda, bBuffer, &ldb, beta, cp, &ldc ); + bli_free_user( bBuffer ); + } + else + { + sgemm_compute_( &transa, &transb, &m, &n, &k, ap, &lda, bp, &ldb, beta, cp, &ldc ); + } + } + else if constexpr (std::is_same::value) + { + if ( ( packa == 'P' || packa == 'p' ) && ( packb == 'P' || packb == 'p' ) ) + { + // Reorder A + char identifierA = 'A'; + gtint_t bufSizeA = dgemm_pack_get_size_( &identifierA, + &m, + &n, + &k ); + + double* aBuffer = (double*) bli_malloc_user( bufSizeA, &err ); + dgemm_pack_( &identifierA, + &transa, + &m, + &n, + &k, + &unit_alpha, + ap, + &lda, + aBuffer ); + + // Reorder B + char identifierB = 'B'; + gtint_t bufSizeB = dgemm_pack_get_size_( &identifierB, + &m, + &n, + &k ); + + double* bBuffer = (double*) bli_malloc_user( bufSizeB, &err ); + dgemm_pack_( &identifierB, + &transb, + &m, + &n, + &k, + alpha, + bp, + &ldb, + bBuffer ); + + dgemm_compute_( &packa, &packb, &m, &n, &k, aBuffer, &lda, bBuffer, &ldb, beta, cp, &ldc ); + bli_free_user( aBuffer ); + bli_free_user( bBuffer ); + } + else if ( ( packa == 'P' || packa == 'p' ) ) + { + // Reorder A + char identifierA = 'A'; + gtint_t bufSizeA = dgemm_pack_get_size_( &identifierA, + &m, + &n, + &k ); + + double* aBuffer = (double*) bli_malloc_user( bufSizeA, &err ); + dgemm_pack_( &identifierA, + &transa, + &m, + &n, + &k, + alpha, + ap, + &lda, + aBuffer ); + + dgemm_compute_( &packa, &transb, &m, &n, &k, aBuffer, &lda, bp, &ldb, beta, cp, &ldc ); + bli_free_user( aBuffer ); + } + else if ( ( packb == 'P' || packb == 'p' ) ) + { + // Reorder B + char identifierB = 'B'; + gtint_t bufSizeB = dgemm_pack_get_size_( 
&identifierB, + &m, + &n, + &k ); + + double* bBuffer = (double*) bli_malloc_user( bufSizeB, &err ); + dgemm_pack_( &identifierB, + &transb, + &m, + &n, + &k, + alpha, + bp, + &ldb, + bBuffer ); + + dgemm_compute_( &transa, &packb, &m, &n, &k, ap, &lda, bBuffer, &ldb, beta, cp, &ldc ); + bli_free_user( bBuffer ); + } + else + { + dgemm_compute_( &transa, &transb, &m, &n, &k, ap, &lda, bp, &ldb, beta, cp, &ldc ); + } + } + else + throw std::runtime_error("Error in testsuite/level3/gemm.h: Invalid typename in gemm_compute_()."); +} + +template +static void cblas_gemm_compute(char storage, char transa, char transb, char pcka, char pckb, + gtint_t m, gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, + T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc) +{ + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_transa; + enum CBLAS_TRANSPOSE cblas_transb; + + testinghelpers::char_to_cblas_order( storage, &cblas_order ); + testinghelpers::char_to_cblas_trans( transa, &cblas_transa ); + testinghelpers::char_to_cblas_trans( transb, &cblas_transb ); + + T unit_alpha = 1.0; + CBLAS_IDENTIFIER cblas_identifierA = CblasAMatrix; + CBLAS_IDENTIFIER cblas_identifierB = CblasBMatrix; + CBLAS_STORAGE cblas_packed = CblasPacked; + + err_t err = BLIS_SUCCESS; + + if constexpr (std::is_same::value) + { + if ( ( pcka == 'p' || pcka == 'P' ) && ( pckb == 'p' || pckb == 'P' ) ) + { + gtint_t bufSizeA = cblas_sgemm_pack_get_size( cblas_identifierA, + m, + n, + k ); + + T* aBuffer = (T*) bli_malloc_user( bufSizeA, &err ); + + cblas_sgemm_pack( cblas_order, cblas_identifierA, cblas_transa, + m, n, k, *alpha, ap, lda, aBuffer ); + + gtint_t bufSizeB = cblas_sgemm_pack_get_size( cblas_identifierB, + m, + n, + k ); + + T* bBuffer = (T*) bli_malloc_user( bufSizeB, &err ); + + cblas_sgemm_pack( cblas_order, cblas_identifierB, cblas_transb, + m, n, k, unit_alpha, bp, ldb, bBuffer ); + + cblas_sgemm_compute( cblas_order, cblas_packed, cblas_packed, + m, n, k, aBuffer, lda, bBuffer, ldb, *beta, cp, ldc ); + + bli_free_user( aBuffer ); + bli_free_user( bBuffer ); + } + else if ( pcka == 'p' || pcka == 'P' ) + { + gtint_t bufSizeA = cblas_sgemm_pack_get_size( cblas_identifierA, + m, + n, + k ); + + T* aBuffer = (T*) bli_malloc_user( bufSizeA, &err ); + + cblas_sgemm_pack( cblas_order, cblas_identifierA, cblas_transa, + m, n, k, *alpha, ap, lda, aBuffer ); + + + cblas_sgemm_compute( cblas_order, cblas_packed, cblas_transb, + m, n, k, aBuffer, lda, bp, ldb, *beta, cp, ldc ); + + bli_free_user( aBuffer ); + } + else if ( pckb == 'p' || pckb == 'P' ) + { + gtint_t bufSizeB = cblas_sgemm_pack_get_size( cblas_identifierB, + m, + n, + k ); + + T* bBuffer = (T*) bli_malloc_user( bufSizeB, &err ); + + cblas_sgemm_pack( cblas_order, cblas_identifierB, cblas_transb, + m, n, k, *alpha, bp, ldb, bBuffer ); + + cblas_sgemm_compute( cblas_order, cblas_transa, cblas_packed, + m, n, k, ap, lda, bBuffer, ldb, *beta, cp, ldc ); + + bli_free_user( bBuffer ); + } + else + { + cblas_sgemm_compute( cblas_order, cblas_transa, cblas_transb, + m, n, k, ap, lda, bp, ldb, *beta, cp, ldc ); + } + } + else if constexpr (std::is_same::value) + { + if ( ( pcka == 'p' || pcka == 'P' ) && ( pckb == 'p' || pckb == 'P' ) ) + { + gtint_t bufSizeA = cblas_dgemm_pack_get_size( cblas_identifierA, + m, + n, + k ); + + T* aBuffer = (T*) bli_malloc_user( bufSizeA, &err ); + + cblas_dgemm_pack( cblas_order, cblas_identifierA, cblas_transa, + m, n, k, *alpha, ap, lda, aBuffer ); + + gtint_t bufSizeB = cblas_dgemm_pack_get_size( cblas_identifierB, + m, + n, + k ); + 
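+            // In this both-packed path alpha has already been applied while packing A
+            // (see the cblas_dgemm_pack call above), so B is packed with unit_alpha below
+            // and the compute call applies no further scaling to A*B.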
+ T* bBuffer = (T*) bli_malloc_user( bufSizeB, &err ); + + cblas_dgemm_pack( cblas_order, cblas_identifierB, cblas_transb, + m, n, k, unit_alpha, bp, ldb, bBuffer ); + + cblas_dgemm_compute( cblas_order, cblas_packed, cblas_packed, + m, n, k, aBuffer, lda, bBuffer, ldb, *beta, cp, ldc ); + + bli_free_user( aBuffer ); + bli_free_user( bBuffer ); + } + else if ( pcka == 'p' || pcka == 'P' ) + { + gtint_t bufSizeA = cblas_dgemm_pack_get_size( cblas_identifierA, + m, + n, + k ); + + T* aBuffer = (T*) bli_malloc_user( bufSizeA, &err ); + + cblas_dgemm_pack( cblas_order, cblas_identifierA, cblas_transa, + m, n, k, *alpha, ap, lda, aBuffer ); + + + cblas_dgemm_compute( cblas_order, cblas_packed, cblas_transb, + m, n, k, aBuffer, lda, bp, ldb, *beta, cp, ldc ); + + bli_free_user( aBuffer ); + } + else if ( pckb == 'p' || pckb == 'P' ) + { + gtint_t bufSizeB = cblas_dgemm_pack_get_size( cblas_identifierB, + m, + n, + k ); + + T* bBuffer = (T*) bli_malloc_user( bufSizeB, &err ); + + cblas_dgemm_pack( cblas_order, cblas_identifierB, cblas_transb, + m, n, k, *alpha, bp, ldb, bBuffer ); + + cblas_dgemm_compute( cblas_order, cblas_transa, cblas_packed, + m, n, k, ap, lda, bBuffer, ldb, *beta, cp, ldc ); + + bli_free_user( bBuffer ); + } + else + { + cblas_dgemm_compute( cblas_order, cblas_transa, cblas_transb, + m, n, k, ap, lda, bp, ldb, *beta, cp, ldc ); + } + } + else + { + throw std::runtime_error("Error in testsuite/level3/gemm_compute.h: Invalid typename in cblas_gemm_compute()."); + } +} + +template +static void gemm_compute( char storage, char transa, char transb, char packa, char packb, gtint_t m, gtint_t n, gtint_t k, T* alpha, + T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) +{ +#ifdef TEST_BLAS + if( storage == 'c' || storage == 'C' ) + gemm_compute_( transa, transb, packa, packb, m, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); + else + throw std::runtime_error("Error in testsuite/level3/gemm_compute.h: BLAS interface cannot be tested for row-major order."); + +#elif TEST_CBLAS + cblas_gemm_compute( storage, transa, transb, packa, packb, m, n, k, alpha, ap, lda, bp, ldb, beta, cp, ldc ); +#elif TEST_BLIS_TYPED + throw std::runtime_error("Error in testsuite/level3/gemm_compute.h: BLIS interfaces not yet implemented for pack and compute BLAS extensions."); +#else + throw std::runtime_error("Error in testsuite/level3/gemm_compute.h: No interfaces are set to be tested."); +#endif +} diff --git a/gtestsuite/testsuite/level3/gemm_compute/gemm_compute_IIT_ERS.cpp b/gtestsuite/testsuite/level3/gemm_compute/gemm_compute_IIT_ERS.cpp new file mode 100644 index 0000000000..db293c0433 --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm_compute/gemm_compute_IIT_ERS.cpp @@ -0,0 +1,237 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include <gtest/gtest.h>
+#include "test_gemm_compute.h"
+#include "common/wrong_inputs_helpers.h"
+#include "common/testing_helpers.h"
+#include "inc/check_error.h"
+
+template <typename T>
+class GEMM_Compute_IIT_ERS_Test : public ::testing::Test {};
+typedef ::testing::Types<float, double> TypeParam;
+TYPED_TEST_SUITE(GEMM_Compute_IIT_ERS_Test, TypeParam);
+
+using namespace testinghelpers::IIT;
+
+#ifdef TEST_BLAS
+
+/*
+    Incorrect Input Testing (IIT)
+
+    BLAS exceptions get triggered in the following cases (for GEMM Compute):
+    1. When TRANSA is not one of 'N', 'T', 'C' or 'P'  (info = 1)
+    2. When TRANSB is not one of 'N', 'T', 'C' or 'P'  (info = 2)
+    3. When m < 0                                      (info = 3)
+    4. When n < 0                                      (info = 4)
+    5. When k < 0                                      (info = 5)
+    6. When lda < max(1, thresh) (info = 7), thresh set based on TRANSA value
+    7. When ldb < max(1, thresh) (info = 9), thresh set based on TRANSB value
+    8. When ldc < max(1, n)      (info = 12)
+*/
+
+// When info == 1
+TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_transa)
+{
+    using T = TypeParam;
+    // Defining the C matrix with values for debugging purposes
+    std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
+
+    // Copy so that we check that the elements of C are not modified.
+    std::vector<T> c_ref(c);
+    // Call gemm_compute with an invalid TRANS value for A.
+    gemm_compute<T>( STORAGE, 'x', TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
+    // Use bitwise comparison (no threshold).
+    computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
+}
+
+// When info == 2
+TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_transb)
+{
+    using T = TypeParam;
+    // Defining the C matrix with values for debugging purposes
+    std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
+
+    // Copy so that we check that the elements of C are not modified.
+    std::vector<T> c_ref(c);
+    // Call gemm_compute with an invalid TRANS value for B.
+    gemm_compute<T>( STORAGE, TRANS, 'x', 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
+    // Use bitwise comparison (no threshold).
+    computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
+}
+
+// When info == 3
+TYPED_TEST(GEMM_Compute_IIT_ERS_Test, m_lt_zero)
+{
+    using T = TypeParam;
+    // Defining the C matrix with values for debugging purposes
+    std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
+
+    // Copy so that we check that the elements of C are not modified.
+    std::vector<T> c_ref(c);
+    // Call gemm_compute with an invalid value for m.
+    gemm_compute<T>( STORAGE, TRANS, TRANS, 'U', 'U', -1, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
+    // Use bitwise comparison (no threshold).
+    computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
+}
+
+// When info == 4
+TYPED_TEST(GEMM_Compute_IIT_ERS_Test, n_lt_zero)
+{
+    using T = TypeParam;
+    // Defining the C matrix with values for debugging purposes
+    std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
+
+    // Copy so that we check that the elements of C are not modified.
+    std::vector<T> c_ref(c);
+    // Call gemm_compute with an invalid value for n.
+    gemm_compute<T>( STORAGE, TRANS, TRANS, 'U', 'U', M, -1, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
+    // Use bitwise comparison (no threshold).
+    computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
+}
+
+// When info == 5
+TYPED_TEST(GEMM_Compute_IIT_ERS_Test, k_lt_zero)
+{
+    using T = TypeParam;
+    // Defining the C matrix with values for debugging purposes
+    std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
+
+    // Copy so that we check that the elements of C are not modified.
+    std::vector<T> c_ref(c);
+    // Call gemm_compute with an invalid value for k.
+    gemm_compute<T>( STORAGE, TRANS, TRANS, 'U', 'U', M, N, -1, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
+    // Use bitwise comparison (no threshold).
+    computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
+}
+
+// When info == 7
+TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_lda)
+{
+    using T = TypeParam;
+    // Defining the C matrix with values for debugging purposes
+    std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
+
+    // Copy so that we check that the elements of C are not modified.
+    std::vector<T> c_ref(c);
+    // Call gemm_compute with an invalid value for lda.
+    gemm_compute<T>( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA - 1, nullptr, LDB, nullptr, nullptr, LDC );
+    // Use bitwise comparison (no threshold).
+    computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
+}
+
+// When info == 9
+TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_ldb)
+{
+    using T = TypeParam;
+    // Defining the C matrix with values for debugging purposes
+    std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
+
+    // Copy so that we check that the elements of C are not modified.
+    std::vector<T> c_ref(c);
+    // Call gemm_compute with an invalid value for ldb.
+    gemm_compute<T>( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB - 1, nullptr, nullptr, LDC );
+    // Use bitwise comparison (no threshold).
+    computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
+}
+
+// When info == 12
+TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_ldc_lt_zero)
+{
+    using T = TypeParam;
+    // Defining the C matrix with values for debugging purposes
+    std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
+
+    // Copy so that we check that the elements of C are not modified.
+    std::vector<T> c_ref(c);
+    // Call gemm_compute with an invalid (negative) value for ldc.
+    gemm_compute<T>( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, -1 );
+    // Use bitwise comparison (no threshold).
+    computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
+}
+
+// When info == 12
+TYPED_TEST(GEMM_Compute_IIT_ERS_Test, invalid_ldc)
+{
+    using T = TypeParam;
+    // Defining the C matrix with values for debugging purposes
+    std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
+
+    // Copy so that we check that the elements of C are not modified.
+    std::vector<T> c_ref(c);
+    // Call gemm_compute with an invalid value for ldc.
+    gemm_compute<T>( STORAGE, TRANS, TRANS, 'U', 'U', M, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC - 1 );
+    // Use bitwise comparison (no threshold).
+    computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
+}
+
+/*
+    Early Return Scenarios (ERS):
+
+    The GEMM Compute API is expected to return early in the following cases:
+
+    1. When m == 0.
+    2. When n == 0.
+*/
+
+// When m = 0
+TYPED_TEST(GEMM_Compute_IIT_ERS_Test, m_eq_zero)
+{
+    using T = TypeParam;
+    // Defining the C matrix with values for debugging purposes
+    std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
+
+    // Copy so that we check that the elements of C are not modified.
+    std::vector<T> c_ref(c);
+    // Call gemm_compute with m equal to zero.
+    gemm_compute<T>( STORAGE, TRANS, TRANS, 'U', 'U', 0, N, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
+    // Use bitwise comparison (no threshold).
+    computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
+}
+
+// When n = 0
+TYPED_TEST(GEMM_Compute_IIT_ERS_Test, n_eq_zero)
+{
+    using T = TypeParam;
+    // Defining the C matrix with values for debugging purposes
+    std::vector<T> c = testinghelpers::get_random_matrix<T>(-10, 10, STORAGE, 'N', N, N, LDC, 'f');
+
+    // Copy so that we check that the elements of C are not modified.
+    std::vector<T> c_ref(c);
+    // Call gemm_compute with n equal to zero.
+    gemm_compute<T>( STORAGE, TRANS, TRANS, 'U', 'U', M, 0, K, nullptr, nullptr, LDA, nullptr, LDB, nullptr, nullptr, LDC );
+    // Use bitwise comparison (no threshold).
+    computediff<T>( STORAGE, N, N, c.data(), c_ref.data(), LDC);
+}
+#endif
diff --git a/gtestsuite/testsuite/level3/gemm_compute/sgemm_compute_generic.cpp b/gtestsuite/testsuite/level3/gemm_compute/sgemm_compute_generic.cpp
new file mode 100644
index 0000000000..ea574eb723
--- /dev/null
+++ b/gtestsuite/testsuite/level3/gemm_compute/sgemm_compute_generic.cpp
@@ -0,0 +1,214 @@
+/*
+
+   BLIS
+   An object-based framework for developing high-performance BLAS-like
+   libraries.
+
+   Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_gemm_compute.h" + +class SGemmComputeTest : + public ::testing::TestWithParam> {}; + +TEST_P(SGemmComputeTest, RandomData) +{ +// printf("SGemmCompute_test!!\n"); + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // matrix storage format(row major, column major) + char storage = std::get<0>(GetParam()); + // denotes whether matrix a is n,c,t,h + char transa = std::get<1>(GetParam()); + // denotes whether matrix b is n,c,t,h + char transb = std::get<2>(GetParam()); + // denotes whether matrix a is packed (p) or unpacked (u) + char packa = std::get<3>(GetParam()); + // denotes whether matrix b is packed (p) or unpacked (u) + char packb = std::get<4>(GetParam()); + // matrix size m + gtint_t m = std::get<5>(GetParam()); + // matrix size n + gtint_t n = std::get<6>(GetParam()); + // matrix size k + gtint_t k = std::get<7>(GetParam()); + // specifies alpha value + T alpha = std::get<8>(GetParam()); + // specifies beta value + T beta = std::get<9>(GetParam()); + // lda, ldb, ldc increments. + // If increments are zero, then the array size matches the matrix size. + // If increments are nonnegative, the array size is bigger than the matrix size. + gtint_t lda_inc = std::get<10>(GetParam()); + gtint_t ldb_inc = std::get<11>(GetParam()); + gtint_t ldc_inc = std::get<12>(GetParam()); + + // Set the threshold for the errors: + float intermediate = (float)m*n*k; + float thresh = 10*intermediate*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_gemm_compute( storage, transa, transb, packa, packb, m, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); +} + +class SGemmComputeTestPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + char sfm = std::get<0>(str.param); + char tsa = std::get<1>(str.param); + char tsb = std::get<2>(str.param); + char pka = std::get<3>(str.param); + char pkb = std::get<4>(str.param); + gtint_t m = std::get<5>(str.param); + gtint_t n = std::get<6>(str.param); + gtint_t k = std::get<7>(str.param); + float alpha = std::get<8>(str.param); + float beta = std::get<9>(str.param); + gtint_t lda_inc = std::get<10>(str.param); + gtint_t ldb_inc = std::get<11>(str.param); + gtint_t ldc_inc = std::get<12>(str.param); +#ifdef TEST_BLAS + std::string str_name = "sgemm_compute_"; +#elif TEST_CBLAS + std::string str_name = "cblas_sgemm_compute"; +#else //#elif TEST_BLIS_TYPED + // BLIS interface not yet implemented for pack and compute APIs. 
+ std::string str_name = "blis_sgemm_compute"; +#endif + str_name = str_name + "_" + sfm+sfm+sfm; + str_name = str_name + "_" + tsa + tsb; + str_name = str_name + "_" + pka + pkb; + str_name = str_name + "_" + std::to_string(m); + str_name = str_name + "_" + std::to_string(n); + str_name = str_name + "_" + std::to_string(k); + std::string alpha_str = ( alpha > 0) ? std::to_string(int(alpha)) : "m" + std::to_string(int(std::abs(alpha))); + str_name = str_name + "_a" + alpha_str; + std::string beta_str = ( beta > 0) ? std::to_string(int(beta)) : "m" + std::to_string(int(std::abs(beta))); + str_name = str_name + "_b" + beta_str; + str_name = str_name + "_" + std::to_string(lda_inc); + str_name = str_name + "_" + std::to_string(ldb_inc); + str_name = str_name + "_" + std::to_string(ldc_inc); + return str_name; + } +}; + +// Black box testing. +INSTANTIATE_TEST_SUITE_P( + Blackbox, + SGemmComputeTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values('u', 'p'), // packa + ::testing::Values('u', 'p'), // packb + ::testing::Range(gtint_t(10), gtint_t(31), 10), // m + ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Range(gtint_t(10), gtint_t(31), 10), // k + ::testing::Values(0.0, 1.0, -1.2, 2.1), // alpha + ::testing::Values(0.0, 1.0, -1.2, 2.1), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::SGemmComputeTestPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + TinySizes, + SGemmComputeTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n', 't', 'c'), // transa + ::testing::Values('n', 't', 'c'), // transb + ::testing::Values('u', 'p'), // packa + ::testing::Values('u', 'p'), // packb + ::testing::Range(gtint_t(1), gtint_t(3), 1), // m + ::testing::Range(gtint_t(1), gtint_t(3), 1), // n + ::testing::Range(gtint_t(1), gtint_t(3), 1), // k + ::testing::Values(0.0, 1.0, -1.2, 2.1), // alpha + ::testing::Values(0.0, 1.0, -1.2, 2.1), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::SGemmComputeTestPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + DimensionsGtBlocksizes, // Dimensions > SUP Blocksizes + SGemmComputeTest, + ::testing::Combine( + ::testing::Values('c' +#ifndef TEST_BLAS + ,'r' +#endif + ), // storage format + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values('u', 'p'), // packa + ::testing::Values('u', 'p'), // packb + ::testing::Values(143, 145), // m (MC - 1, MC + 1) + ::testing::Values(8159, 8161), // n (NC - 1, NC + 1) + ::testing::Values(511, 513), // k (KC - 1, KC + 1) + ::testing::Values(1.0), // alpha + ::testing::Values(1.0), // beta + ::testing::Values(gtint_t(0)), // increment to the leading dim of a + ::testing::Values(gtint_t(0)), // increment to the leading dim of b + ::testing::Values(gtint_t(0)) // increment to the leading dim of c + ), + ::SGemmComputeTestPrint() + ); diff --git a/gtestsuite/testsuite/level3/gemm_compute/test_gemm_compute.h b/gtestsuite/testsuite/level3/gemm_compute/test_gemm_compute.h new file mode 100644 index 
0000000000..a9109d5abc --- /dev/null +++ b/gtestsuite/testsuite/level3/gemm_compute/test_gemm_compute.h @@ -0,0 +1,79 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#pragma once + +#include "gemm_compute.h" +#include "level3/ref_gemm_compute.h" +#include "inc/check_error.h" +#include +#include + +template +void test_gemm_compute( char storage, char trnsa, char trnsb, char pcka, char pckb, + gtint_t m, gtint_t n, gtint_t k, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, + T alpha, T beta, double thresh ) +{ + // Compute the leading dimensions of a, b, and c. + gtint_t lda = testinghelpers::get_leading_dimension( storage, trnsa, m, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, trnsb, k, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); + + //---------------------------------------------------------- + // Initialize matrics with random numbers + //---------------------------------------------------------- + std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, trnsa, m, k, lda ); + std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, trnsb, k, n, ldb ); + std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, 'n', m, n, ldc ); + + // Create a copy of c so that we can check reference results. + std::vector c_ref(c); + + //---------------------------------------------------------- + // Call BLIS function + //---------------------------------------------------------- + gemm_compute( storage, trnsa, trnsb, pcka, pckb, m, n, k, &alpha, a.data(), lda, + b.data(), ldb, &beta, c.data(), ldc ); + + //---------------------------------------------------------- + // Call reference implementation. 
+ //---------------------------------------------------------- + testinghelpers::ref_gemm_compute( storage, trnsa, trnsb, pcka, pckb, m, n, k, alpha, + a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc ); + + //---------------------------------------------------------- + // check component-wise error. + //---------------------------------------------------------- + computediff( storage, m, n, c.data(), c_ref.data(), ldc, thresh ); +} diff --git a/gtestsuite/testsuite/level3/gemmt/cgemmt_generic.cpp b/gtestsuite/testsuite/level3/gemmt/cgemmt_generic.cpp index f15fc50619..07aed996bb 100644 --- a/gtestsuite/testsuite/level3/gemmt/cgemmt_generic.cpp +++ b/gtestsuite/testsuite/level3/gemmt/cgemmt_generic.cpp @@ -46,12 +46,12 @@ class cgemmtTest : scomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(cgemmtTest); -TEST_P(cgemmtTest, RandomData) { +TEST_P(cgemmtTest, RandomData) +{ using T = scomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -79,8 +79,6 @@ TEST_P(cgemmtTest, RandomData) { gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); gtint_t ldc_inc = std::get<10>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<11>(GetParam()); // Set the threshold for the errors: double thresh = 10*n*k*testinghelpers::getEpsilon(); @@ -88,13 +86,13 @@ TEST_P(cgemmtTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_gemmt(storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_gemmt( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class cgemmtTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uplo = std::get<1>(str.param); char tsa = std::get<2>(str.param); @@ -106,7 +104,6 @@ class cgemmtTestPrint { gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); gtint_t ldc_inc = std::get<10>(str.param); - char datatype = std::get<11>(str.param); #ifdef TEST_BLAS std::string str_name = "cgemmt_"; #elif TEST_CBLAS @@ -128,10 +125,10 @@ class cgemmtTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; + // Disable tests for BLIS_TYPED case due to compiler errors. #ifndef TEST_BLIS_TYPED // Black box testing. 
@@ -153,8 +150,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(scomplex{1.0,2.0}), // beta ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c ), ::cgemmtTestPrint() ); diff --git a/gtestsuite/testsuite/level3/gemmt/dgemmt_generic.cpp b/gtestsuite/testsuite/level3/gemmt/dgemmt_generic.cpp index b27b6c66b9..c31260def4 100644 --- a/gtestsuite/testsuite/level3/gemmt/dgemmt_generic.cpp +++ b/gtestsuite/testsuite/level3/gemmt/dgemmt_generic.cpp @@ -46,12 +46,12 @@ class dgemmtTest : double, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dgemmtTest); -TEST_P(dgemmtTest, RandomData) { +TEST_P(dgemmtTest, RandomData) +{ using T = double; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -79,8 +79,6 @@ TEST_P(dgemmtTest, RandomData) { gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); gtint_t ldc_inc = std::get<10>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<11>(GetParam()); // Set the threshold for the errors: double thresh = 10*n*k*testinghelpers::getEpsilon(); @@ -88,13 +86,13 @@ TEST_P(dgemmtTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_gemmt(storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_gemmt( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class dgemmtTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char tsa = std::get<1>(str.param); char tsb = std::get<2>(str.param); @@ -106,7 +104,6 @@ class dgemmtTestPrint { gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); gtint_t ldc_inc = std::get<10>(str.param); - char datatype = std::get<11>(str.param); #ifdef TEST_BLAS std::string str_name = "dgemmt_"; #elif TEST_CBLAS @@ -126,7 +123,6 @@ class dgemmtTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -151,8 +147,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(3.0), // beta ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c ), ::dgemmtTestPrint() ); diff --git a/gtestsuite/testsuite/level3/gemmt/gemmt.h b/gtestsuite/testsuite/level3/gemmt/gemmt.h index 217cd5bcd0..a9a92821e0 100644 --- a/gtestsuite/testsuite/level3/gemmt/gemmt.h +++ b/gtestsuite/testsuite/level3/gemmt/gemmt.h @@ -154,6 +154,7 @@ static void typed_gemmt(char storage, char uplo, 
char trnsa, char trnsb, throw std::runtime_error("Error in testsuite/level3/gemmt.h: Invalid typename in typed_gemmt()."); } #endif + template static void gemmt( char storage, char uplo, char transa, char transb, gtint_t n, gtint_t k, T* alpha, T* ap, gtint_t lda, T* bp, gtint_t ldb, T* beta, T* cp, gtint_t ldc ) diff --git a/gtestsuite/testsuite/level3/gemmt/sgemmt_generic.cpp b/gtestsuite/testsuite/level3/gemmt/sgemmt_generic.cpp index c9686e84bb..e067a684e7 100644 --- a/gtestsuite/testsuite/level3/gemmt/sgemmt_generic.cpp +++ b/gtestsuite/testsuite/level3/gemmt/sgemmt_generic.cpp @@ -46,12 +46,12 @@ class sgemmtTest : float, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(sgemmtTest); -TEST_P(sgemmtTest, RandomData) { +TEST_P(sgemmtTest, RandomData) +{ using T = float; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -79,8 +79,6 @@ TEST_P(sgemmtTest, RandomData) { gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); gtint_t ldc_inc = std::get<10>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<11>(GetParam()); // Set the threshold for the errors: double thresh = 10*n*k*testinghelpers::getEpsilon(); @@ -88,13 +86,13 @@ TEST_P(sgemmtTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_gemmt(storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_gemmt( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class sgemmtTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char tsa = std::get<1>(str.param); char tsb = std::get<2>(str.param); @@ -106,7 +104,6 @@ class sgemmtTestPrint { gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); gtint_t ldc_inc = std::get<10>(str.param); - char datatype = std::get<11>(str.param); #ifdef TEST_BLAS std::string str_name = "sgemmt_"; #elif TEST_CBLAS @@ -126,7 +123,6 @@ class sgemmtTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -152,8 +148,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(3.0), // beta ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c ), ::sgemmtTestPrint() ); diff --git a/gtestsuite/testsuite/level3/gemmt/test_gemmt.h b/gtestsuite/testsuite/level3/gemmt/test_gemmt.h index 9087c9fa81..2afaba222d 100644 --- a/gtestsuite/testsuite/level3/gemmt/test_gemmt.h +++ b/gtestsuite/testsuite/level3/gemmt/test_gemmt.h @@ -42,20 +42,20 @@ template void test_gemmt( char storage, char uplo, char trnsa, char trnsb, gtint_t n, - gtint_t k, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, - T alpha, T beta, double thresh, char datatype ) { - + gtint_t k, gtint_t lda_inc, gtint_t 
ldb_inc, gtint_t ldc_inc, T alpha, + T beta, double thresh ) +{ // Compute the leading dimensions of a, b, and c. - gtint_t lda = testinghelpers::get_leading_dimension(storage, trnsa, n, k, lda_inc); - gtint_t ldb = testinghelpers::get_leading_dimension(storage, trnsb, k, n, ldb_inc); - gtint_t ldc = testinghelpers::get_leading_dimension(storage, 'n', n, n, ldc_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, trnsa, n, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, trnsb, k, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', n, n, ldc_inc ); //---------------------------------------------------------- // Initialize matrics with random numbers //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(-2, 8, storage, trnsa, n, k, lda, datatype); - std::vector b = testinghelpers::get_random_matrix(-5, 2, storage, trnsb, k, n, ldb, datatype); - std::vector c = testinghelpers::get_random_matrix(-3, 5, storage, 'n', n, n, ldc, datatype); + std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, trnsa, n, k, lda ); + std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, trnsb, k, n, ldb ); + std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, 'n', n, n, ldc ); // Create a copy of c so that we can check reference results. std::vector c_ref(c); @@ -69,7 +69,7 @@ void test_gemmt( char storage, char uplo, char trnsa, char trnsb, gtint_t n, //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_gemmt( storage, uplo, trnsa, trnsb, n, k, alpha, + testinghelpers::ref_gemmt( storage, uplo, trnsa, trnsb, n, k, alpha, a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc ); //---------------------------------------------------------- diff --git a/gtestsuite/testsuite/level3/gemmt/zgemmt_generic.cpp b/gtestsuite/testsuite/level3/gemmt/zgemmt_generic.cpp index d5ddd84276..7c8a4c8ecf 100644 --- a/gtestsuite/testsuite/level3/gemmt/zgemmt_generic.cpp +++ b/gtestsuite/testsuite/level3/gemmt/zgemmt_generic.cpp @@ -46,12 +46,12 @@ class zgemmtTest : dcomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(zgemmtTest); -TEST_P(zgemmtTest, RandomData) { +TEST_P(zgemmtTest, RandomData) +{ using T = dcomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -79,22 +79,20 @@ TEST_P(zgemmtTest, RandomData) { gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); gtint_t ldc_inc = std::get<10>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<11>(GetParam()); // Set the threshold for the errors: - double thresh = std::max(n,k)*testinghelpers::getEpsilon(); + double thresh = (std::max)(n,k)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_gemmt(storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_gemmt( storage, uplo, transa, transb, n, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class zgemmtTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = 
std::get<0>(str.param); char uplo = std::get<1>(str.param); char tsa = std::get<2>(str.param); @@ -106,7 +104,6 @@ class zgemmtTestPrint { gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); gtint_t ldc_inc = std::get<10>(str.param); - char datatype = std::get<11>(str.param); #ifdef TEST_BLAS std::string str_name = "zgemmt_"; #elif TEST_CBLAS @@ -128,7 +125,6 @@ class zgemmtTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -154,8 +150,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{1.0,2.0}), // beta ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(9)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(9)) // increment to the leading dim of c ), ::zgemmtTestPrint() ); diff --git a/gtestsuite/testsuite/level3/hemm/chemm_generic.cpp b/gtestsuite/testsuite/level3/hemm/chemm_generic.cpp index 4a1221c4b4..173aa8777b 100644 --- a/gtestsuite/testsuite/level3/hemm/chemm_generic.cpp +++ b/gtestsuite/testsuite/level3/hemm/chemm_generic.cpp @@ -47,10 +47,10 @@ class chemmTest : scomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(chemmTest, RandomData) { +TEST_P(chemmTest, RandomData) +{ using T = scomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -81,8 +81,6 @@ TEST_P(chemmTest, RandomData) { gtint_t lda_inc = std::get<9>(GetParam()); gtint_t ldb_inc = std::get<10>(GetParam()); gtint_t ldc_inc = std::get<11>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<12>(GetParam()); // Set the threshold for the errors: double thresh = 10*m*n*testinghelpers::getEpsilon(); @@ -90,13 +88,13 @@ TEST_P(chemmTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_hemm(storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_hemm( storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class chemmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uplo = std::get<2>(str.param); @@ -109,13 +107,12 @@ class chemmTestPrint { gtint_t lda_inc = std::get<9>(str.param); gtint_t ldb_inc = std::get<10>(str.param); gtint_t ldc_inc = std::get<11>(str.param); - char datatype = std::get<12>(str.param); #ifdef TEST_BLAS std::string str_name = "chemm_"; #elif TEST_CBLAS std::string str_name = "cblas_chemm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_chemm"; + std::string str_name = "bli_chemm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uplo; @@ -130,7 +127,6 @@ class chemmTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return 
str_name; } }; @@ -155,8 +151,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(scomplex{-3.0, 2.0}), // beta ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c ), ::chemmTestPrint() ); diff --git a/gtestsuite/testsuite/level3/hemm/test_hemm.h b/gtestsuite/testsuite/level3/hemm/test_hemm.h index bae4756f6b..a55510bf04 100644 --- a/gtestsuite/testsuite/level3/hemm/test_hemm.h +++ b/gtestsuite/testsuite/level3/hemm/test_hemm.h @@ -42,17 +42,15 @@ template void test_hemm( char storage, char side, char uplo, char conja, char transb, - gtint_t m, gtint_t n, - gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, - T alpha, T beta, - double thresh, char datatype -) { + gtint_t m, gtint_t n, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, + T alpha, T beta, double thresh ) +{ // Set the dimension for row/col of A, depending on the value of side. gtint_t k = ((side == 'l')||(side == 'L'))? m : n; // Compute the leading dimensions of a, b, and c. - gtint_t lda = testinghelpers::get_leading_dimension(storage, 'n', k, k, lda_inc); - gtint_t ldb = testinghelpers::get_leading_dimension(storage, 'n', m, n, ldb_inc); - gtint_t ldc = testinghelpers::get_leading_dimension(storage, 'n', m, n, ldc_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, 'n', k, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. @@ -60,9 +58,9 @@ void test_hemm( char storage, char side, char uplo, char conja, char transb, // Since matrix A, stored in a, is symmetric and we only use the upper or lower // part in the computation of hemm and zero-out the rest to ensure // that code operates as expected. - std::vector a = testinghelpers::get_random_matrix(-5, 2, storage, uplo, k, lda, datatype); - std::vector b = testinghelpers::get_random_matrix(-5, 2, storage, transb, m, n, ldb, datatype); - std::vector c = testinghelpers::get_random_matrix(-3, 5, storage, 'n', m, n, ldc, datatype); + std::vector a = testinghelpers::get_random_matrix( -5, 2, storage, uplo, k, lda ); + std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, transb, m, n, ldb ); + std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, 'n', m, n, ldc ); // Create a copy of c so that we can check reference results. std::vector c_ref(c); @@ -75,7 +73,7 @@ void test_hemm( char storage, char side, char uplo, char conja, char transb, //---------------------------------------------------------- // Call reference implementation. 
//---------------------------------------------------------- - testinghelpers::ref_hemm( storage, side, uplo, conja, transb, m, n, alpha, + testinghelpers::ref_hemm( storage, side, uplo, conja, transb, m, n, alpha, a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc ); //---------------------------------------------------------- diff --git a/gtestsuite/testsuite/level3/hemm/zhemm_generic.cpp b/gtestsuite/testsuite/level3/hemm/zhemm_generic.cpp index 4ebc75ef2c..f509cb8881 100644 --- a/gtestsuite/testsuite/level3/hemm/zhemm_generic.cpp +++ b/gtestsuite/testsuite/level3/hemm/zhemm_generic.cpp @@ -47,10 +47,10 @@ class zhemmTest : dcomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(zhemmTest, RandomData) { +TEST_P(zhemmTest, RandomData) +{ using T = dcomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -81,8 +81,6 @@ TEST_P(zhemmTest, RandomData) { gtint_t lda_inc = std::get<9>(GetParam()); gtint_t ldb_inc = std::get<10>(GetParam()); gtint_t ldc_inc = std::get<11>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<12>(GetParam()); // Set the threshold for the errors: double thresh = 10*m*n*testinghelpers::getEpsilon(); @@ -90,13 +88,13 @@ TEST_P(zhemmTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_hemm(storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_hemm( storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class zhemmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uplo = std::get<2>(str.param); @@ -109,13 +107,12 @@ class zhemmTestPrint { gtint_t lda_inc = std::get<9>(str.param); gtint_t ldb_inc = std::get<10>(str.param); gtint_t ldc_inc = std::get<11>(str.param); - char datatype = std::get<12>(str.param); #ifdef TEST_BLAS std::string str_name = "zhemm_"; #elif TEST_CBLAS std::string str_name = "cblas_zhemm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_zhemm"; + std::string str_name = "bli_zhemm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uplo; @@ -130,7 +127,6 @@ class zhemmTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -155,8 +151,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{4.0, -1.0}), // beta ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(6)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(6)) // increment to the leading dim of c ), ::zhemmTestPrint() ); diff --git a/gtestsuite/testsuite/level3/her2k/cher2k_generic.cpp b/gtestsuite/testsuite/level3/her2k/cher2k_generic.cpp index b33db3a187..b87a833950 100644 --- a/gtestsuite/testsuite/level3/her2k/cher2k_generic.cpp +++ b/gtestsuite/testsuite/level3/her2k/cher2k_generic.cpp @@ -46,10 +46,10 @@ class 
cher2kTest : float, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(cher2kTest, RandomData) { +TEST_P(cher2kTest, RandomData) +{ using T = scomplex; using RT = typename testinghelpers::type_info::real_type; //---------------------------------------------------------- @@ -78,8 +78,6 @@ TEST_P(cher2kTest, RandomData) { gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); gtint_t ldc_inc = std::get<10>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<11>(GetParam()); // Set the threshold for the errors: double thresh = 2*m*k*testinghelpers::getEpsilon(); @@ -87,13 +85,13 @@ TEST_P(cher2kTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_her2k(storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_her2k( storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class cher2kTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uplo = std::get<1>(str.param); char tsa = std::get<2>(str.param); @@ -105,13 +103,12 @@ class cher2kTestPrint { gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); gtint_t ldc_inc = std::get<10>(str.param); - char datatype = std::get<11>(str.param); #ifdef TEST_BLAS std::string str_name = "cher2k_"; #elif TEST_CBLAS std::string str_name = "cblas_cher2k"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_cher2k"; + std::string str_name = "bli_cher2k"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + uplo; @@ -126,7 +123,6 @@ class cher2kTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -150,8 +146,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(-3.0, 2.0), // beta ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of c ), ::cher2kTestPrint() ); diff --git a/gtestsuite/testsuite/level3/her2k/test_her2k.h b/gtestsuite/testsuite/level3/her2k/test_her2k.h index 60c1f1c2f0..18ab391cd7 100644 --- a/gtestsuite/testsuite/level3/her2k/test_her2k.h +++ b/gtestsuite/testsuite/level3/her2k/test_her2k.h @@ -42,25 +42,23 @@ template::real_type> void test_her2k( char storage, char uplo, char transa, char transb, - gtint_t m, gtint_t k, - gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, - T alpha, RT beta, - double thresh, char datatype -) { + gtint_t m, gtint_t k, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, + T alpha, RT beta, double thresh ) +{ // Compute the leading dimensions of a, b, and c. 
- gtint_t lda = testinghelpers::get_leading_dimension(storage, transa, m, k, lda_inc); - gtint_t ldb = testinghelpers::get_leading_dimension(storage, transb, m, k, ldb_inc); - gtint_t ldc = testinghelpers::get_leading_dimension(storage, 'n', m, m, ldc_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, m, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, m, k, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, m, ldc_inc ); //---------------------------------------------------------- // Initialize matrics with random numbers //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(-2, 8, storage, transa, m, k, lda, datatype); - std::vector b = testinghelpers::get_random_matrix(-5, 2, storage, transb, m, k, ldb, datatype); + std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, m, k, lda ); + std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, transb, m, k, ldb ); // Since matrix C, stored in c, is symmetric and we only use the upper or lower // part in the computation of her2k and zero-out the rest to ensure // that code operates as expected. - std::vector c = testinghelpers::get_random_matrix(-3, 5, storage, uplo, m, ldc, datatype); + std::vector c = testinghelpers::get_random_matrix(-3, 5, storage, uplo, m, ldc ); // Create a copy of c so that we can check reference results. std::vector c_ref(c); @@ -74,7 +72,7 @@ void test_her2k( char storage, char uplo, char transa, char transb, //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_her2k( storage, uplo, transa, transb, m, k, &alpha, + testinghelpers::ref_her2k( storage, uplo, transa, transb, m, k, &alpha, a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc ); //---------------------------------------------------------- diff --git a/gtestsuite/testsuite/level3/her2k/zher2k_generic.cpp b/gtestsuite/testsuite/level3/her2k/zher2k_generic.cpp index 95301a291b..2ae305c086 100644 --- a/gtestsuite/testsuite/level3/her2k/zher2k_generic.cpp +++ b/gtestsuite/testsuite/level3/her2k/zher2k_generic.cpp @@ -46,10 +46,10 @@ class zher2kTest : double, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(zher2kTest, RandomData) { +TEST_P(zher2kTest, RandomData) +{ using T = dcomplex; using RT = typename testinghelpers::type_info::real_type; //---------------------------------------------------------- @@ -78,8 +78,6 @@ TEST_P(zher2kTest, RandomData) { gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); gtint_t ldc_inc = std::get<10>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<11>(GetParam()); // Set the threshold for the errors: double thresh = 2*m*k*testinghelpers::getEpsilon(); @@ -87,13 +85,13 @@ TEST_P(zher2kTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_her2k(storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_her2k( storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class zher2kTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = 
std::get<0>(str.param); char uplo = std::get<1>(str.param); char tsa = std::get<2>(str.param); @@ -105,13 +103,12 @@ class zher2kTestPrint { gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); gtint_t ldc_inc = std::get<10>(str.param); - char datatype = std::get<11>(str.param); #ifdef TEST_BLAS std::string str_name = "zher2k_"; #elif TEST_CBLAS std::string str_name = "cblas_zher2k"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_zher2k"; + std::string str_name = "bli_zher2k"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + uplo; @@ -126,7 +123,6 @@ class zher2kTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -150,8 +146,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(4.0, -1.0), // beta ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of c ), ::zher2kTestPrint() ); diff --git a/gtestsuite/testsuite/level3/herk/cherk_generic.cpp b/gtestsuite/testsuite/level3/herk/cherk_generic.cpp index 13252de9cd..868b637d3a 100644 --- a/gtestsuite/testsuite/level3/herk/cherk_generic.cpp +++ b/gtestsuite/testsuite/level3/herk/cherk_generic.cpp @@ -44,10 +44,10 @@ class cherkTest : float, float, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(cherkTest, RandomData) { +TEST_P(cherkTest, RandomData) +{ using T = scomplex; using RT = typename testinghelpers::type_info::real_type; //---------------------------------------------------------- @@ -73,8 +73,6 @@ TEST_P(cherkTest, RandomData) { // If increments are nonnegative, the array size is bigger than the matrix size. 
gtint_t lda_inc = std::get<7>(GetParam()); gtint_t ldc_inc = std::get<8>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<9>(GetParam()); // Set the threshold for the errors: double thresh = m*k*testinghelpers::getEpsilon(); @@ -82,13 +80,13 @@ TEST_P(cherkTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_herk(storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh, datatype); + test_herk( storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh ); } class cherkTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uplo = std::get<1>(str.param); char tsa = std::get<2>(str.param); @@ -98,13 +96,12 @@ class cherkTestPrint { float beta = std::get<6>(str.param); gtint_t lda_inc = std::get<7>(str.param); gtint_t ldc_inc = std::get<8>(str.param); - char datatype = std::get<9>(str.param); #ifdef TEST_BLAS std::string str_name = "cherk_"; #elif TEST_CBLAS std::string str_name = "cblas_cherk"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_cherk"; + std::string str_name = "bli_cherk"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + uplo; @@ -117,7 +114,6 @@ class cherkTestPrint { str_name = str_name + "_b" + beta_str; str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -139,8 +135,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(-2.0, 3.0), // alpha ::testing::Values(4.0, -1.0), // beta ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of b - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(4)) // increment to the leading dim of c ), ::cherkTestPrint() ); diff --git a/gtestsuite/testsuite/level3/herk/test_herk.h b/gtestsuite/testsuite/level3/herk/test_herk.h index 355b514ec4..a283366566 100644 --- a/gtestsuite/testsuite/level3/herk/test_herk.h +++ b/gtestsuite/testsuite/level3/herk/test_herk.h @@ -41,25 +41,22 @@ #include template::real_type> -void test_herk( char storage, char uplo, char transa, - gtint_t m, gtint_t k, - gtint_t lda_inc, gtint_t ldc_inc, - RT alpha, RT beta, - double thresh, char datatype -) { +void test_herk( char storage, char uplo, char transa, gtint_t m, gtint_t k, + gtint_t lda_inc, gtint_t ldc_inc, RT alpha, RT beta, double thresh ) +{ // Compute the leading dimensions of a, b, and c. - gtint_t lda = testinghelpers::get_leading_dimension(storage, transa, m, k, lda_inc); - gtint_t ldc = testinghelpers::get_leading_dimension(storage, 'n', m, m, ldc_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, m, k, lda_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, m, ldc_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. 
//---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(-5, 2, storage, transa, m, k, lda, datatype); + std::vector a = testinghelpers::get_random_matrix( -5, 2, storage, transa, m, k, lda ); // Since matrix C, stored in c, is symmetric, we only use the upper or lower // part in the computation of herk and zero-out the rest to ensure // that code operates as expected. - std::vector c = testinghelpers::get_random_matrix(-8, 12, storage, uplo, m, ldc, datatype); + std::vector c = testinghelpers::get_random_matrix( -8, 12, storage, uplo, m, ldc ); // Create a copy of c so that we can check reference results. std::vector c_ref(c); diff --git a/gtestsuite/testsuite/level3/herk/zherk_generic.cpp b/gtestsuite/testsuite/level3/herk/zherk_generic.cpp index 3bbe6cf334..b3d89854c6 100644 --- a/gtestsuite/testsuite/level3/herk/zherk_generic.cpp +++ b/gtestsuite/testsuite/level3/herk/zherk_generic.cpp @@ -44,10 +44,10 @@ class zherkTest : double, double, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(zherkTest, RandomData) { +TEST_P(zherkTest, RandomData) +{ using T = dcomplex; using RT = typename testinghelpers::type_info::real_type; //---------------------------------------------------------- @@ -73,8 +73,6 @@ TEST_P(zherkTest, RandomData) { // If increments are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<7>(GetParam()); gtint_t ldc_inc = std::get<8>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<9>(GetParam()); // Set the threshold for the errors: double thresh = m*k*testinghelpers::getEpsilon(); @@ -82,13 +80,13 @@ TEST_P(zherkTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_herk(storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh, datatype); + test_herk( storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh ); } class zherkTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uplo = std::get<1>(str.param); char tsa = std::get<2>(str.param); @@ -98,13 +96,12 @@ class zherkTestPrint { double beta = std::get<6>(str.param); gtint_t lda_inc = std::get<7>(str.param); gtint_t ldc_inc = std::get<8>(str.param); - char datatype = std::get<9>(str.param); #ifdef TEST_BLAS std::string str_name = "zherk_"; #elif TEST_CBLAS std::string str_name = "cblas_zherk"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_zherk"; + std::string str_name = "bli_zherk"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + uplo; @@ -117,7 +114,6 @@ class zherkTestPrint { str_name = str_name + "_b" + beta_str; str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -131,16 +127,15 @@ INSTANTIATE_TEST_SUITE_P( #ifndef TEST_BLAS ,'r' #endif - ), - ::testing::Values('u','l'), // storage format - ::testing::Values('n','c'), // u:upper, l:lower - ::testing::Range(gtint_t(10), gtint_t(31), 10), // transa + ), // storage format + ::testing::Values('u','l'), // u:upper, l:lower + ::testing::Values('n','c'), // transa ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Values(2.0, -1.0), // n - ::testing::Values(-3.0, 2.0), // alpha - 
::testing::Values(gtint_t(0), gtint_t(4)), // beta - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // increment to the leading dim of b - ), // i : integer, f : float datatype type tested + ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Values(2.0, -1.0), // alpha + ::testing::Values(-3.0, 2.0), // beta + ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c + ), ::zherkTestPrint() ); diff --git a/gtestsuite/testsuite/level3/symm/csymm_generic.cpp b/gtestsuite/testsuite/level3/symm/csymm_generic.cpp index 96c53c63df..72e84c9069 100644 --- a/gtestsuite/testsuite/level3/symm/csymm_generic.cpp +++ b/gtestsuite/testsuite/level3/symm/csymm_generic.cpp @@ -47,10 +47,10 @@ class csymmTest : scomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(csymmTest, RandomData) { +TEST_P(csymmTest, RandomData) +{ using T = scomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -81,8 +81,6 @@ TEST_P(csymmTest, RandomData) { gtint_t lda_inc = std::get<9>(GetParam()); gtint_t ldb_inc = std::get<10>(GetParam()); gtint_t ldc_inc = std::get<11>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<12>(GetParam()); // Set the threshold for the errors: double thresh = m*n*testinghelpers::getEpsilon(); @@ -90,13 +88,13 @@ TEST_P(csymmTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_symm(storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_symm( storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class csymmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uplo = std::get<2>(str.param); @@ -109,13 +107,12 @@ class csymmTestPrint { gtint_t lda_inc = std::get<9>(str.param); gtint_t ldb_inc = std::get<10>(str.param); gtint_t ldc_inc = std::get<11>(str.param); - char datatype = std::get<12>(str.param); #ifdef TEST_BLAS std::string str_name = "csymm_"; #elif TEST_CBLAS std::string str_name = "cblas_csymm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_csymm"; + std::string str_name = "bli_csymm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uplo; @@ -131,7 +128,6 @@ class csymmTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -156,8 +152,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(scomplex{-3.0, 2.0}, scomplex{4.0, -1.0}), // beta ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : dcomplex datatype type tested + ::testing::Values(gtint_t(0), gtint_t(4)) // increment to the leading dim of c ), ::csymmTestPrint() ); diff --git 
a/gtestsuite/testsuite/level3/symm/dsymm_generic.cpp b/gtestsuite/testsuite/level3/symm/dsymm_generic.cpp index 9217152a22..34d4fdb474 100644 --- a/gtestsuite/testsuite/level3/symm/dsymm_generic.cpp +++ b/gtestsuite/testsuite/level3/symm/dsymm_generic.cpp @@ -47,10 +47,10 @@ class dsymmTest : double, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(dsymmTest, RandomData) { +TEST_P(dsymmTest, RandomData) +{ using T = double; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -81,8 +81,6 @@ TEST_P(dsymmTest, RandomData) { gtint_t lda_inc = std::get<9>(GetParam()); gtint_t ldb_inc = std::get<10>(GetParam()); gtint_t ldc_inc = std::get<11>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<12>(GetParam()); // Set the threshold for the errors: double thresh = 30*m*n*testinghelpers::getEpsilon(); @@ -90,13 +88,13 @@ TEST_P(dsymmTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_symm(storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_symm( storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class dsymmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uplo = std::get<2>(str.param); @@ -109,13 +107,12 @@ class dsymmTestPrint { gtint_t lda_inc = std::get<9>(str.param); gtint_t ldb_inc = std::get<10>(str.param); gtint_t ldc_inc = std::get<11>(str.param); - char datatype = std::get<12>(str.param); #ifdef TEST_BLAS std::string str_name = "dsymm_"; #elif TEST_CBLAS std::string str_name = "cblas_dsymm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_dsymm"; + std::string str_name = "bli_dsymm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uplo; @@ -129,7 +126,6 @@ class dsymmTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -154,8 +150,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(-1.0, 1.0), // beta ::testing::Values(gtint_t(0), gtint_t(6)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : dcomplex datatype type tested + ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of c ), ::dsymmTestPrint() ); diff --git a/gtestsuite/testsuite/level3/symm/ssymm_generic.cpp b/gtestsuite/testsuite/level3/symm/ssymm_generic.cpp index 1fca984ee7..749b7a7fce 100644 --- a/gtestsuite/testsuite/level3/symm/ssymm_generic.cpp +++ b/gtestsuite/testsuite/level3/symm/ssymm_generic.cpp @@ -47,10 +47,10 @@ class ssymmTest : float, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(ssymmTest, RandomData) { +TEST_P(ssymmTest, RandomData) +{ using T = float; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -81,8 +81,6 @@ TEST_P(ssymmTest, RandomData) { gtint_t lda_inc = 
std::get<9>(GetParam()); gtint_t ldb_inc = std::get<10>(GetParam()); gtint_t ldc_inc = std::get<11>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<12>(GetParam()); // Set the threshold for the errors: double thresh = 8*m*n*testinghelpers::getEpsilon(); @@ -90,13 +88,13 @@ TEST_P(ssymmTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_symm(storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_symm( storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class ssymmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uplo = std::get<2>(str.param); @@ -109,13 +107,12 @@ class ssymmTestPrint { gtint_t lda_inc = std::get<9>(str.param); gtint_t ldb_inc = std::get<10>(str.param); gtint_t ldc_inc = std::get<11>(str.param); - char datatype = std::get<12>(str.param); #ifdef TEST_BLAS std::string str_name = "ssymm_"; #elif TEST_CBLAS std::string str_name = "cblas_ssymm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_ssymm"; + std::string str_name = "bli_ssymm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uplo; @@ -129,7 +126,6 @@ class ssymmTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -154,8 +150,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(-1.0, 1.0), // beta ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(9)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : dcomplex datatype type tested + ::testing::Values(gtint_t(0), gtint_t(9)) // increment to the leading dim of c ), ::ssymmTestPrint() ); diff --git a/gtestsuite/testsuite/level3/symm/test_symm.h b/gtestsuite/testsuite/level3/symm/test_symm.h index 4274067b72..cc90d7f52a 100644 --- a/gtestsuite/testsuite/level3/symm/test_symm.h +++ b/gtestsuite/testsuite/level3/symm/test_symm.h @@ -42,18 +42,15 @@ template void test_symm( char storage, char side, char uplo, char conja, char transb, - gtint_t m, gtint_t n, - gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, - T alpha, T beta, - double thresh, char datatype -) { - + gtint_t m, gtint_t n, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, + T alpha, T beta, double thresh ) +{ // Set the dimension for row/col of A, depending on the value of side. gtint_t k = ((side == 'l')||(side == 'L'))? m : n; // Compute the leading dimensions of a, b, and c. 
- gtint_t lda = testinghelpers::get_leading_dimension(storage, conja, k, k, lda_inc); - gtint_t ldb = testinghelpers::get_leading_dimension(storage, transb, m, n, ldb_inc); - gtint_t ldc = testinghelpers::get_leading_dimension(storage, 'n', m, n, ldc_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, conja, k, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, m, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. @@ -61,9 +58,9 @@ void test_symm( char storage, char side, char uplo, char conja, char transb, // Since matrix A, stored in a, is symmetric and we only use the upper or lower // part in the computation of hemm and zero-out the rest to ensure // that code operates as expected. - std::vector a = testinghelpers::get_random_matrix(-5, 2, storage, uplo, k, lda, datatype); - std::vector b = testinghelpers::get_random_matrix(-5, 2, storage, transb, m, n, ldb, datatype); - std::vector c = testinghelpers::get_random_matrix(-3, 5, storage, 'n', m, n, ldc, datatype); + std::vector a = testinghelpers::get_random_matrix( -5, 2, storage, uplo, k, lda ); + std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, transb, m, n, ldb ); + std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, 'n', m, n, ldc ); // Create a copy of c so that we can check reference results. std::vector c_ref(c); @@ -77,7 +74,7 @@ void test_symm( char storage, char side, char uplo, char conja, char transb, //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_symm( storage, side, uplo, conja, transb, m, n, alpha, + testinghelpers::ref_symm( storage, side, uplo, conja, transb, m, n, alpha, a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc ); //---------------------------------------------------------- diff --git a/gtestsuite/testsuite/level3/symm/zsymm_generic.cpp b/gtestsuite/testsuite/level3/symm/zsymm_generic.cpp index 9585a8915b..a6c163816a 100644 --- a/gtestsuite/testsuite/level3/symm/zsymm_generic.cpp +++ b/gtestsuite/testsuite/level3/symm/zsymm_generic.cpp @@ -47,10 +47,10 @@ class zsymmTest : dcomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(zsymmTest, RandomData) { +TEST_P(zsymmTest, RandomData) +{ using T = dcomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -81,8 +81,6 @@ TEST_P(zsymmTest, RandomData) { gtint_t lda_inc = std::get<9>(GetParam()); gtint_t ldb_inc = std::get<10>(GetParam()); gtint_t ldc_inc = std::get<11>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<12>(GetParam()); // Set the threshold for the errors: double thresh = m*n*testinghelpers::getEpsilon(); @@ -90,13 +88,13 @@ TEST_P(zsymmTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_symm(storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_symm( storage, side, uplo, conja, transb, m, n, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class zsymmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + 
testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uplo = std::get<2>(str.param); @@ -109,13 +107,12 @@ class zsymmTestPrint { gtint_t lda_inc = std::get<9>(str.param); gtint_t ldb_inc = std::get<10>(str.param); gtint_t ldc_inc = std::get<11>(str.param); - char datatype = std::get<12>(str.param); #ifdef TEST_BLAS std::string str_name = "zsymm_"; #elif TEST_CBLAS std::string str_name = "cblas_zsymm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_zsymm"; + std::string str_name = "bli_zsymm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uplo; @@ -131,7 +128,6 @@ class zsymmTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -156,8 +152,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{-3.0, 2.0}, dcomplex{4.0, -1.0}), // beta ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : dcomplex datatype type tested + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c ), ::zsymmTestPrint() ); diff --git a/gtestsuite/testsuite/level3/syr2k/csyr2k_generic.cpp b/gtestsuite/testsuite/level3/syr2k/csyr2k_generic.cpp index 6b359496a3..2ee7903302 100644 --- a/gtestsuite/testsuite/level3/syr2k/csyr2k_generic.cpp +++ b/gtestsuite/testsuite/level3/syr2k/csyr2k_generic.cpp @@ -46,10 +46,10 @@ class csyr2kTest : scomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(csyr2kTest, RandomData) { +TEST_P(csyr2kTest, RandomData) +{ using T = scomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -77,8 +77,6 @@ TEST_P(csyr2kTest, RandomData) { gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); gtint_t ldc_inc = std::get<10>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<11>(GetParam()); // Set the threshold for the errors: double thresh = m*k*testinghelpers::getEpsilon(); @@ -86,13 +84,13 @@ TEST_P(csyr2kTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syr2k(storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_syr2k( storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class csyr2kTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uplo = std::get<1>(str.param); char tsa = std::get<2>(str.param); @@ -104,13 +102,12 @@ class csyr2kTestPrint { gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); gtint_t ldc_inc = std::get<10>(str.param); - char datatype = std::get<11>(str.param); #ifdef TEST_BLAS std::string str_name = "csyr2k_"; #elif TEST_CBLAS std::string str_name = "cblas_csyr2k"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_csyr2k"; + std::string str_name = "bli_csyr2k"; #endif str_name = str_name + "_" 
+ sfm+sfm+sfm; str_name = str_name + "_" + uplo; @@ -126,7 +123,6 @@ class csyr2kTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -140,18 +136,17 @@ INSTANTIATE_TEST_SUITE_P( #ifndef TEST_BLAS ,'r' #endif - ), - ::testing::Values('u','l'), // storage format - ::testing::Values('n'), // u:upper, l:lower + ), // storage format + ::testing::Values('u','l'), // u:upper, l:lower ::testing::Values('n'), // transa - ::testing::Range(gtint_t(10), gtint_t(31), 10), // transb + ::testing::Values('n'), // transb ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}), // n - ::testing::Values(scomplex{-3.0, 2.0}, scomplex{4.0, -1.0}), // alpha - ::testing::Values(gtint_t(0), gtint_t(2)), // beta - ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of b - ::testing::Values(ELEMENT_TYPE) // increment to the leading dim of c - ), // i : integer, f : dcomplex datatype type tested + ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}), // alpha + ::testing::Values(scomplex{-3.0, 2.0}, scomplex{4.0, -1.0}), // beta + ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of c + ), ::csyr2kTestPrint() ); diff --git a/gtestsuite/testsuite/level3/syr2k/dsyr2k_generic.cpp b/gtestsuite/testsuite/level3/syr2k/dsyr2k_generic.cpp index 39110773f3..f990ef6ac3 100644 --- a/gtestsuite/testsuite/level3/syr2k/dsyr2k_generic.cpp +++ b/gtestsuite/testsuite/level3/syr2k/dsyr2k_generic.cpp @@ -46,10 +46,10 @@ class dsyr2kTest : double, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(dsyr2kTest, RandomData) { +TEST_P(dsyr2kTest, RandomData) +{ using T = double; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -77,8 +77,6 @@ TEST_P(dsyr2kTest, RandomData) { gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); gtint_t ldc_inc = std::get<10>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<11>(GetParam()); // Set the threshold for the errors: double thresh = m*k*testinghelpers::getEpsilon(); @@ -86,13 +84,13 @@ TEST_P(dsyr2kTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syr2k(storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_syr2k( storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class dsyr2kTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uplo = std::get<1>(str.param); char tsa = std::get<2>(str.param); @@ -104,13 +102,12 @@ class dsyr2kTestPrint { gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); gtint_t ldc_inc = std::get<10>(str.param); - char datatype = std::get<11>(str.param); #ifdef TEST_BLAS 
std::string str_name = "dsyr2k_"; #elif TEST_CBLAS std::string str_name = "cblas_dsyr2k"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_dsyr2k"; + std::string str_name = "bli_dsyr2k"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + uplo; @@ -124,7 +121,6 @@ class dsyr2kTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -138,18 +134,17 @@ INSTANTIATE_TEST_SUITE_P( #ifndef TEST_BLAS ,'r' #endif - ), - ::testing::Values('u','l'), // storage format - ::testing::Values('n'), // u:upper, l:lower + ), // storage format + ::testing::Values('u','l'), // u:upper, l:lower ::testing::Values('n'), // transa - ::testing::Range(gtint_t(10), gtint_t(31), 10), // transb + ::testing::Values('n'), // transb ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Values( 1.0, -2.0), // n - ::testing::Values(-1.0, 1.0), // alpha - ::testing::Values(gtint_t(0), gtint_t(4)), // beta - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of b - ::testing::Values(ELEMENT_TYPE) // increment to the leading dim of c - ), // i : integer, f : dcomplex datatype type tested + ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Values( 1.0, -2.0), // alpha + ::testing::Values(-1.0, 1.0), // beta + ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(7)) // increment to the leading dim of c + ), ::dsyr2kTestPrint() ); diff --git a/gtestsuite/testsuite/level3/syr2k/ssyr2k_generic.cpp b/gtestsuite/testsuite/level3/syr2k/ssyr2k_generic.cpp index ad6f883606..4b4cc8ccdd 100644 --- a/gtestsuite/testsuite/level3/syr2k/ssyr2k_generic.cpp +++ b/gtestsuite/testsuite/level3/syr2k/ssyr2k_generic.cpp @@ -46,10 +46,10 @@ class ssyr2kTest : float, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(ssyr2kTest, RandomData) { +TEST_P(ssyr2kTest, RandomData) +{ using T = float; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -77,8 +77,6 @@ TEST_P(ssyr2kTest, RandomData) { gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); gtint_t ldc_inc = std::get<10>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<11>(GetParam()); // Set the threshold for the errors: double thresh = 10*m*k*testinghelpers::getEpsilon(); @@ -86,13 +84,13 @@ TEST_P(ssyr2kTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syr2k(storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_syr2k( storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class ssyr2kTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uplo = std::get<1>(str.param); char tsa = std::get<2>(str.param); @@ -104,13 +102,12 @@ class ssyr2kTestPrint { gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); 
gtint_t ldc_inc = std::get<10>(str.param); - char datatype = std::get<11>(str.param); #ifdef TEST_BLAS std::string str_name = "ssyr2k_"; #elif TEST_CBLAS std::string str_name = "cblas_ssyr2k"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_ssyr2k"; + std::string str_name = "bli_ssyr2k"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + uplo; @@ -124,7 +121,6 @@ class ssyr2kTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -138,18 +134,17 @@ INSTANTIATE_TEST_SUITE_P( #ifndef TEST_BLAS ,'r' #endif - ), - ::testing::Values('u','l'), // storage format - ::testing::Values('n'), // u:upper, l:lower - ::testing::Values('n'), // transa - ::testing::Range(gtint_t(10), gtint_t(31), 10), // transb - ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Values( 1.0, -2.0), // n - ::testing::Values(-1.0, 1.0), // alpha - ::testing::Values(gtint_t(0), gtint_t(7)), // beta - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of b - ::testing::Values(ELEMENT_TYPE) // increment to the leading dim of c - ), // i : integer, f : dcomplex datatype type tested + ), // storage format + ::testing::Values('u','l'), // u:upper, l:lower + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Range(gtint_t(10), gtint_t(31), 10), // m + ::testing::Range(gtint_t(10), gtint_t(31), 10), // n + ::testing::Values( 1.0, -2.0), // alpha + ::testing::Values(-1.0, 1.0), // beta + ::testing::Values(gtint_t(0), gtint_t(7)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of b + ::testing::Values(gtint_t(0), gtint_t(4)) // increment to the leading dim of c + ), ::ssyr2kTestPrint() ); diff --git a/gtestsuite/testsuite/level3/syr2k/test_syr2k.h b/gtestsuite/testsuite/level3/syr2k/test_syr2k.h index 9a7fb82b6f..da2dabb0a9 100644 --- a/gtestsuite/testsuite/level3/syr2k/test_syr2k.h +++ b/gtestsuite/testsuite/level3/syr2k/test_syr2k.h @@ -41,26 +41,24 @@ #include template -void test_syr2k( char storage, char uplo, char transa, char transb, - gtint_t m, gtint_t k, - gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, - T alpha, T beta, - double thresh, char datatype -) { +void test_syr2k( char storage, char uplo, char transa, char transb, gtint_t m, + gtint_t k, gtint_t lda_inc, gtint_t ldb_inc, gtint_t ldc_inc, T alpha, + T beta, double thresh ) +{ // Compute the leading dimensions of a, b, and c. - gtint_t lda = testinghelpers::get_leading_dimension(storage, transa, m, k, lda_inc); - gtint_t ldb = testinghelpers::get_leading_dimension(storage, transb, m, k, ldb_inc); - gtint_t ldc = testinghelpers::get_leading_dimension(storage, 'n', m, m, ldc_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, m, k, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, m, k, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, m, ldc_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. 
//---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(-2, 8, storage, transa, m, k, lda, datatype); - std::vector b = testinghelpers::get_random_matrix(-5, 2, storage, transb, m, k, ldb, datatype); + std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, m, k, lda ); + std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, transb, m, k, ldb ); // Since matrix C, stored in c, is symmetric and we only use the upper or lower // part in the computation of her2k and zero-out the rest to ensure // that code operates as expected. - std::vector c = testinghelpers::get_random_matrix(-3, 5, storage, uplo, m, ldc, datatype); + std::vector c = testinghelpers::get_random_matrix(-3, 5, storage, uplo, m, ldc ); // Create a copy of c so that we can check reference results. std::vector c_ref(c); @@ -74,7 +72,7 @@ void test_syr2k( char storage, char uplo, char transa, char transb, //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_syr2k( storage, uplo, transa, transb, m, k, alpha, + testinghelpers::ref_syr2k( storage, uplo, transa, transb, m, k, alpha, a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc ); //---------------------------------------------------------- diff --git a/gtestsuite/testsuite/level3/syr2k/zsyr2k_generic.cpp b/gtestsuite/testsuite/level3/syr2k/zsyr2k_generic.cpp index 9b0d018768..3600872367 100644 --- a/gtestsuite/testsuite/level3/syr2k/zsyr2k_generic.cpp +++ b/gtestsuite/testsuite/level3/syr2k/zsyr2k_generic.cpp @@ -46,10 +46,10 @@ class zsyr2kTest : dcomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(zsyr2kTest, RandomData) { +TEST_P(zsyr2kTest, RandomData) +{ using T = dcomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -77,8 +77,6 @@ TEST_P(zsyr2kTest, RandomData) { gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); gtint_t ldc_inc = std::get<10>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<11>(GetParam()); // Set the threshold for the errors: double thresh = m*k*testinghelpers::getEpsilon(); @@ -86,13 +84,13 @@ TEST_P(zsyr2kTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syr2k(storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh, datatype); + test_syr2k( storage, uplo, transa, transb, m, k, lda_inc, ldb_inc, ldc_inc, alpha, beta, thresh ); } class zsyr2kTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uplo = std::get<1>(str.param); char tsa = std::get<2>(str.param); @@ -104,13 +102,12 @@ class zsyr2kTestPrint { gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); gtint_t ldc_inc = std::get<10>(str.param); - char datatype = std::get<11>(str.param); #ifdef TEST_BLAS std::string str_name = "zsyr2k_"; #elif TEST_CBLAS std::string str_name = "cblas_zsyr2k"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_zsyr2k"; + std::string str_name = "bli_zsyr2k"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + uplo; @@ -126,7 +123,6 @@ class 
zsyr2kTestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -150,8 +146,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{-3.0, 2.0}, dcomplex{4.0, -1.0}), // beta ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of b - ::testing::Values(gtint_t(0), gtint_t(6)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : dcomplex datatype type tested + ::testing::Values(gtint_t(0), gtint_t(6)) // increment to the leading dim of c ), ::zsyr2kTestPrint() ); diff --git a/gtestsuite/testsuite/level3/syrk/csyrk_generic.cpp b/gtestsuite/testsuite/level3/syrk/csyrk_generic.cpp index 092235019e..c876843931 100644 --- a/gtestsuite/testsuite/level3/syrk/csyrk_generic.cpp +++ b/gtestsuite/testsuite/level3/syrk/csyrk_generic.cpp @@ -44,10 +44,10 @@ class csyrkTest : scomplex, scomplex, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(csyrkTest, RandomData) { +TEST_P(csyrkTest, RandomData) +{ using T = scomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -72,8 +72,6 @@ TEST_P(csyrkTest, RandomData) { // If increments are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<7>(GetParam()); gtint_t ldc_inc = std::get<8>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<9>(GetParam()); // Set the threshold for the errors: double thresh = m*k*testinghelpers::getEpsilon(); @@ -81,13 +79,13 @@ TEST_P(csyrkTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syrk(storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh, datatype); + test_syrk( storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh ); } class csyrkTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uplo = std::get<1>(str.param); char tsa = std::get<2>(str.param); @@ -97,13 +95,12 @@ class csyrkTestPrint { scomplex beta = std::get<6>(str.param); gtint_t lda_inc = std::get<7>(str.param); gtint_t ldc_inc = std::get<8>(str.param); - char datatype = std::get<9>(str.param); #ifdef TEST_BLAS std::string str_name = "csyrk_"; #elif TEST_CBLAS std::string str_name = "cblas_csyrk"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_csyrk"; + std::string str_name = "bli_csyrk"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + uplo; @@ -118,7 +115,6 @@ class csyrkTestPrint { str_name = str_name + "_a" + beta_str; str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -140,8 +136,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(scomplex{2.0, -1.0}, scomplex{-2.0, 3.0}), // alpha ::testing::Values(scomplex{-3.0, 2.0}, scomplex{4.0, -1.0}), // beta ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : dcomplex 
datatype type tested + ::testing::Values(gtint_t(0), gtint_t(2)) // increment to the leading dim of c ), ::csyrkTestPrint() ); diff --git a/gtestsuite/testsuite/level3/syrk/dsyrk_generic.cpp b/gtestsuite/testsuite/level3/syrk/dsyrk_generic.cpp index af5d263e5c..05f1dc0229 100644 --- a/gtestsuite/testsuite/level3/syrk/dsyrk_generic.cpp +++ b/gtestsuite/testsuite/level3/syrk/dsyrk_generic.cpp @@ -44,10 +44,10 @@ class dsyrkTest : double, double, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(dsyrkTest, RandomData) { +TEST_P(dsyrkTest, RandomData) +{ using T = double; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -72,8 +72,6 @@ TEST_P(dsyrkTest, RandomData) { // If increments are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<7>(GetParam()); gtint_t ldc_inc = std::get<8>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<9>(GetParam()); // Set the threshold for the errors: double thresh = m*k*testinghelpers::getEpsilon(); @@ -81,13 +79,13 @@ TEST_P(dsyrkTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syrk(storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh, datatype); + test_syrk( storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh ); } class dsyrkTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uplo = std::get<1>(str.param); char tsa = std::get<2>(str.param); @@ -97,13 +95,12 @@ class dsyrkTestPrint { double beta = std::get<6>(str.param); gtint_t lda_inc = std::get<7>(str.param); gtint_t ldc_inc = std::get<8>(str.param); - char datatype = std::get<9>(str.param); #ifdef TEST_BLAS std::string str_name = "dsyrk_"; #elif TEST_CBLAS std::string str_name = "cblas_dsyrk"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_dsyrk"; + std::string str_name = "bli_dsyrk"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + uplo; @@ -116,7 +113,6 @@ class dsyrkTestPrint { str_name = str_name + "_b" + beta_str; str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -138,8 +134,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values( 1.0, -2.0), // alpha ::testing::Values(-1.0, 1.0), // beta ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(9)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : dcomplex datatype type tested + ::testing::Values(gtint_t(0), gtint_t(9)) // increment to the leading dim of c ), ::dsyrkTestPrint() ); diff --git a/gtestsuite/testsuite/level3/syrk/ssyrk_generic.cpp b/gtestsuite/testsuite/level3/syrk/ssyrk_generic.cpp index a413c6f15c..6ce9ab89bf 100644 --- a/gtestsuite/testsuite/level3/syrk/ssyrk_generic.cpp +++ b/gtestsuite/testsuite/level3/syrk/ssyrk_generic.cpp @@ -44,10 +44,10 @@ class ssyrkTest : float, float, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(ssyrkTest, RandomData) { +TEST_P(ssyrkTest, RandomData) +{ using T = float; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -72,8 +72,6 @@ 
TEST_P(ssyrkTest, RandomData) { // If increments are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<7>(GetParam()); gtint_t ldc_inc = std::get<8>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<9>(GetParam()); // Set the threshold for the errors: double thresh = m*k*testinghelpers::getEpsilon(); @@ -81,13 +79,13 @@ TEST_P(ssyrkTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syrk(storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh, datatype); + test_syrk( storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh ); } class ssyrkTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uplo = std::get<1>(str.param); char tsa = std::get<2>(str.param); @@ -97,13 +95,12 @@ class ssyrkTestPrint { float beta = std::get<6>(str.param); gtint_t lda_inc = std::get<7>(str.param); gtint_t ldc_inc = std::get<8>(str.param); - char datatype = std::get<9>(str.param); #ifdef TEST_BLAS std::string str_name = "ssyrk_"; #elif TEST_CBLAS std::string str_name = "cblas_ssyrk"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_ssyrk"; + std::string str_name = "bli_ssyrk"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + uplo; @@ -116,7 +113,6 @@ class ssyrkTestPrint { str_name = str_name + "_b" + beta_str; str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -130,16 +126,15 @@ INSTANTIATE_TEST_SUITE_P( #ifndef TEST_BLAS ,'r' #endif - ), - ::testing::Values('u','l'), // storage format - ::testing::Values('n','t','c'), // u:upper, l:lower - ::testing::Range(gtint_t(10), gtint_t(31), 10), // transa + ), // storage format + ::testing::Values('u','l'), // u:upper, l:lower + ::testing::Values('n','t','c'), // transa ::testing::Range(gtint_t(10), gtint_t(31), 10), // m - ::testing::Values( 1.0, -2.0), // k - ::testing::Values(-1.0, 1.0), // alpha - ::testing::Values(gtint_t(0), gtint_t(3)), // beta - ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of a - ::testing::Values(ELEMENT_TYPE) // increment to the leading dim of c - ), // i : integer, f : dcomplex datatype type tested + ::testing::Range(gtint_t(10), gtint_t(31), 10), // k + ::testing::Values( 1.0, -2.0), // alpha + ::testing::Values(-1.0, 1.0), // beta + ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a + ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of c + ), ::ssyrkTestPrint() ); diff --git a/gtestsuite/testsuite/level3/syrk/test_syrk.h b/gtestsuite/testsuite/level3/syrk/test_syrk.h index 9c8585e64a..464f608827 100644 --- a/gtestsuite/testsuite/level3/syrk/test_syrk.h +++ b/gtestsuite/testsuite/level3/syrk/test_syrk.h @@ -41,24 +41,21 @@ #include template -void test_syrk( char storage, char uplo, char transa, - gtint_t m, gtint_t k, - gtint_t lda_inc, gtint_t ldc_inc, - T alpha, T beta, - double thresh, char datatype -) { +void test_syrk( char storage, char uplo, char transa, gtint_t m, gtint_t k, + gtint_t lda_inc, gtint_t ldc_inc, T alpha, T beta, double thresh ) +{ // Compute the leading dimensions of a, b, and c. 
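// (Presumably get_leading_dimension() returns the minimal leading dimension for the given
//  storage scheme and transpose plus the *_inc padding, so an increment of 0 exercises
//  tightly packed matrices while a nonzero increment exercises padded storage.)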
- gtint_t lda = testinghelpers::get_leading_dimension(storage, transa, m, k, lda_inc); - gtint_t ldc = testinghelpers::get_leading_dimension(storage, 'n', m, m, ldc_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, m, k, lda_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, m, ldc_inc ); //---------------------------------------------------------- // Initialize matrics with random integer numbers. //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, m, k, lda, datatype ); + std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, m, k, lda ); // Since matrix C, stored in c, is symmetric, we only use the upper or lower // part in the computation of syrk and zero-out the rest to ensure // that code operates as expected. - std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, uplo, m, ldc, datatype ); + std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, uplo, m, ldc ); // Create a copy of c so that we can check reference results. std::vector c_ref(c); diff --git a/gtestsuite/testsuite/level3/syrk/zsyrk_generic.cpp b/gtestsuite/testsuite/level3/syrk/zsyrk_generic.cpp index 7bb7d9cedf..406d137d43 100644 --- a/gtestsuite/testsuite/level3/syrk/zsyrk_generic.cpp +++ b/gtestsuite/testsuite/level3/syrk/zsyrk_generic.cpp @@ -44,10 +44,10 @@ class zsyrkTest : dcomplex, dcomplex, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(zsyrkTest, RandomData) { +TEST_P(zsyrkTest, RandomData) +{ using T = dcomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -72,8 +72,6 @@ TEST_P(zsyrkTest, RandomData) { // If increments are nonnegative, the array size is bigger than the matrix size. 
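// The error threshold set a few lines below, thresh = m*k*eps, presumably scales with k
// because each element of C accumulates k multiply-adds in C := alpha*op(A)*op(A)**T + beta*C;
// the additional factor of m appears to provide headroom rather than a tight bound.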
gtint_t lda_inc = std::get<7>(GetParam()); gtint_t ldc_inc = std::get<8>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<9>(GetParam()); // Set the threshold for the errors: double thresh = m*k*testinghelpers::getEpsilon(); @@ -81,13 +79,13 @@ TEST_P(zsyrkTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_syrk(storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh, datatype); + test_syrk( storage, uplo, transa, m, k, lda_inc, ldc_inc, alpha, beta, thresh ); } class zsyrkTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char uplo = std::get<1>(str.param); char tsa = std::get<2>(str.param); @@ -97,13 +95,12 @@ class zsyrkTestPrint { dcomplex beta = std::get<6>(str.param); gtint_t lda_inc = std::get<7>(str.param); gtint_t ldc_inc = std::get<8>(str.param); - char datatype = std::get<9>(str.param); #ifdef TEST_BLAS std::string str_name = "zsyrk_"; #elif TEST_CBLAS std::string str_name = "cblas_zsyrk"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_zsyrk"; + std::string str_name = "bli_zsyrk"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + uplo; @@ -118,7 +115,6 @@ class zsyrkTestPrint { str_name = str_name + "_a" + beta_str; str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -140,8 +136,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{2.0, -1.0}, dcomplex{-2.0, 3.0}), // alpha ::testing::Values(dcomplex{-3.0, 2.0}, dcomplex{4.0, -1.0}), // beta ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : dcomplex datatype type tested + ::testing::Values(gtint_t(0), gtint_t(5)) // increment to the leading dim of c ), ::zsyrkTestPrint() ); diff --git a/gtestsuite/testsuite/level3/trmm/ctrmm_generic.cpp b/gtestsuite/testsuite/level3/trmm/ctrmm_generic.cpp index a875f77282..5887027a58 100644 --- a/gtestsuite/testsuite/level3/trmm/ctrmm_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm/ctrmm_generic.cpp @@ -45,10 +45,10 @@ class ctrmmTest : gtint_t, scomplex, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(ctrmmTest, RandomData) { +TEST_P(ctrmmTest, RandomData) +{ using T = scomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -76,8 +76,6 @@ TEST_P(ctrmmTest, RandomData) { // If increments are nonnegative, the array size is bigger than the matrix size. 
gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: double thresh = m*n*testinghelpers::getEpsilon(); @@ -85,13 +83,13 @@ TEST_P(ctrmmTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trmm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh, datatype ); + test_trmm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); } class ctrmmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uploa = std::get<2>(str.param); @@ -102,13 +100,12 @@ class ctrmmTestPrint { scomplex alpha = std::get<7>(str.param); gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "ctrmm_"; #elif TEST_CBLAS std::string str_name = "cblas_ctrmm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_ctrmm"; + std::string str_name = "bli_ctrmm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uploa + transa; @@ -120,7 +117,6 @@ class ctrmmTestPrint { str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -143,8 +139,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(31), 10), // n ::testing::Values(scomplex{2.0,-1.0}), // alpha ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of b ), ::ctrmmTestPrint() ); diff --git a/gtestsuite/testsuite/level3/trmm/dtrmm_generic.cpp b/gtestsuite/testsuite/level3/trmm/dtrmm_generic.cpp index 94fb07ba3c..1c9c251bdf 100644 --- a/gtestsuite/testsuite/level3/trmm/dtrmm_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm/dtrmm_generic.cpp @@ -45,10 +45,10 @@ class dtrmmTest : gtint_t, double, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(dtrmmTest, RandomData) { +TEST_P(dtrmmTest, RandomData) +{ using T = double; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -76,8 +76,6 @@ TEST_P(dtrmmTest, RandomData) { // If increments are nonnegative, the array size is bigger than the matrix size. 
gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: double thresh = m*n*testinghelpers::getEpsilon(); @@ -85,13 +83,13 @@ TEST_P(dtrmmTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trmm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh, datatype ); + test_trmm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); } class dtrmmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uploa = std::get<2>(str.param); @@ -102,13 +100,12 @@ class dtrmmTestPrint { double alpha = std::get<7>(str.param); gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "dtrmm_"; #elif TEST_CBLAS std::string str_name = "cblas_dtrmm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_dtrmm"; + std::string str_name = "bli_dtrmm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uploa + transa; @@ -119,7 +116,6 @@ class dtrmmTestPrint { str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -142,8 +138,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(31), 10), // n ::testing::Values( 1.0, -2.0), // alpha ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of b ), ::dtrmmTestPrint() ); diff --git a/gtestsuite/testsuite/level3/trmm/strmm_generic.cpp b/gtestsuite/testsuite/level3/trmm/strmm_generic.cpp index df2287c90a..6851e1f52c 100644 --- a/gtestsuite/testsuite/level3/trmm/strmm_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm/strmm_generic.cpp @@ -45,10 +45,10 @@ class strmmTest : gtint_t, float, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(strmmTest, RandomData) { +TEST_P(strmmTest, RandomData) +{ using T = float; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -76,8 +76,6 @@ TEST_P(strmmTest, RandomData) { // If increments are nonnegative, the array size is bigger than the matrix size. 
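// Note: unlike the double/complex trmm tests, which use thresh = m*n*eps, the single-precision
// test below scales the bound by an extra factor of 20, presumably as headroom for the larger
// relative rounding error accumulated in float.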
gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: double thresh = 20*m*n*testinghelpers::getEpsilon(); @@ -85,13 +83,13 @@ TEST_P(strmmTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trmm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh, datatype ); + test_trmm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); } class strmmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uploa = std::get<2>(str.param); @@ -102,13 +100,12 @@ class strmmTestPrint { float alpha = std::get<7>(str.param); gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "strmm_"; #elif TEST_CBLAS std::string str_name = "cblas_strmm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_strmm"; + std::string str_name = "bli_strmm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uploa + transa; @@ -119,7 +116,6 @@ class strmmTestPrint { str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -142,8 +138,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(31), 10), // n ::testing::Values( 1.0, -2.0), // alpha ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of b - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(4)) // increment to the leading dim of b ), ::strmmTestPrint() ); diff --git a/gtestsuite/testsuite/level3/trmm/test_trmm.h b/gtestsuite/testsuite/level3/trmm/test_trmm.h index 1993127bae..4ba801d937 100644 --- a/gtestsuite/testsuite/level3/trmm/test_trmm.h +++ b/gtestsuite/testsuite/level3/trmm/test_trmm.h @@ -37,30 +37,28 @@ #include "trmm.h" #include "level3/ref_trmm.h" #include "inc/check_error.h" -#include "inc/utils.h" #include #include template -void test_trmm( char storage, char side, char uploa, char transa, - char diaga, gtint_t m, gtint_t n, T alpha, gtint_t lda_inc, - gtint_t ldb_inc, double thresh, char datatype ) { - +void test_trmm( char storage, char side, char uploa, char transa, char diaga, + gtint_t m, gtint_t n, T alpha, gtint_t lda_inc, gtint_t ldb_inc, double thresh ) +{ gtint_t mn; testinghelpers::set_dim_with_side( side, m, n, &mn ); - gtint_t lda = testinghelpers::get_leading_dimension(storage, transa, mn, mn, lda_inc); - gtint_t ldb = testinghelpers::get_leading_dimension(storage, 'n', m, n, ldb_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, mn, mn, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldb_inc ); //---------------------------------------------------------- // Initialize matrics with random values. 
//---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(-2, 8, storage, transa, mn, mn, lda, datatype); - std::vector b = testinghelpers::get_random_matrix(-5, 2, storage, 'n', m, n, ldb, datatype); + std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, mn, mn, lda ); + std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, 'n', m, n, ldb ); // Create a copy of v so that we can check reference results. std::vector b_ref(b); - mktrim( storage, uploa, mn, a.data(), lda ); + testinghelpers::make_triangular( storage, uploa, mn, a.data(), lda ); //---------------------------------------------------------- // Call BLIS function //---------------------------------------------------------- @@ -69,7 +67,7 @@ void test_trmm( char storage, char side, char uploa, char transa, //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_trmm( storage, side, uploa, transa, diaga, m, n, alpha, a.data(), lda, b_ref.data(), ldb ); + testinghelpers::ref_trmm( storage, side, uploa, transa, diaga, m, n, alpha, a.data(), lda, b_ref.data(), ldb ); //---------------------------------------------------------- // check component-wise error. diff --git a/gtestsuite/testsuite/level3/trmm/ztrmm_generic.cpp b/gtestsuite/testsuite/level3/trmm/ztrmm_generic.cpp index 823f9fdcf3..d6ad3e02ca 100644 --- a/gtestsuite/testsuite/level3/trmm/ztrmm_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm/ztrmm_generic.cpp @@ -45,10 +45,10 @@ class ztrmmTest : gtint_t, dcomplex, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(ztrmmTest, RandomData) { +TEST_P(ztrmmTest, RandomData) +{ using T = dcomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -76,8 +76,6 @@ TEST_P(ztrmmTest, RandomData) { // If increments are nonnegative, the array size is bigger than the matrix size. 
gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: double thresh = m*n*testinghelpers::getEpsilon(); @@ -85,13 +83,13 @@ TEST_P(ztrmmTest, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trmm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh, datatype ); + test_trmm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); } class ztrmmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uploa = std::get<2>(str.param); @@ -102,13 +100,12 @@ class ztrmmTestPrint { dcomplex alpha = std::get<7>(str.param); gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "ztrmm_"; #elif TEST_CBLAS std::string str_name = "cblas_ztrmm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_ztrmm"; + std::string str_name = "bli_ztrmm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uploa + transa; @@ -120,7 +117,6 @@ class ztrmmTestPrint { str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -143,8 +139,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(31), 10), // n ::testing::Values(dcomplex{1.0,2.0}), // alpha ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(1)), // increment to the leading dim of b - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(1)) // increment to the leading dim of b ), ::ztrmmTestPrint() ); diff --git a/gtestsuite/testsuite/level3/trmm3/ctrmm3_generic.cpp b/gtestsuite/testsuite/level3/trmm3/ctrmm3_generic.cpp index a10d9866ef..839c472988 100644 --- a/gtestsuite/testsuite/level3/trmm3/ctrmm3_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm3/ctrmm3_generic.cpp @@ -48,12 +48,12 @@ class ctrmm3Test : scomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ctrmm3Test); -TEST_P(ctrmm3Test, RandomData) { +TEST_P(ctrmm3Test, RandomData) +{ using T = scomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -86,8 +86,6 @@ TEST_P(ctrmm3Test, RandomData) { gtint_t lda_inc = std::get<10>(GetParam()); gtint_t ldb_inc = std::get<11>(GetParam()); gtint_t ldc_inc = std::get<12>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<13>(GetParam()); // Set the threshold for the errors: double thresh = m*n*testinghelpers::getEpsilon(); @@ -95,13 +93,13 @@ TEST_P(ctrmm3Test, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trmm3( storage, side, uploa, transa, diaga, transb, m, n, alpha, lda_inc, ldb_inc, beta, ldc_inc, thresh, datatype ); + test_trmm3( 
storage, side, uploa, transa, diaga, transb, m, n, alpha, lda_inc, ldb_inc, beta, ldc_inc, thresh ); } class ctrmm3TestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uploa = std::get<2>(str.param); @@ -115,8 +113,7 @@ class ctrmm3TestPrint { gtint_t lda_inc = std::get<10>(str.param); gtint_t ldb_inc = std::get<11>(str.param); gtint_t ldc_inc = std::get<12>(str.param); - char datatype = std::get<13>(str.param); - std::string str_name = "blis_ctrmm3"; + std::string str_name = "bli_ctrmm3"; str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uploa + transa + transb; str_name = str_name + "_d" + diaga; @@ -131,7 +128,6 @@ class ctrmm3TestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -154,8 +150,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(scomplex{-1.0,1.0}), // beta ::testing::Values(gtint_t(0)), // increment to the leading dim of a ::testing::Values(gtint_t(0)), // increment to the leading dim of b - ::testing::Values(gtint_t(0)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0)) // increment to the leading dim of c ), ::ctrmm3TestPrint() ); diff --git a/gtestsuite/testsuite/level3/trmm3/dtrmm3_generic.cpp b/gtestsuite/testsuite/level3/trmm3/dtrmm3_generic.cpp index 222d70604e..343a573666 100644 --- a/gtestsuite/testsuite/level3/trmm3/dtrmm3_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm3/dtrmm3_generic.cpp @@ -48,12 +48,12 @@ class dtrmm3Test : double, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(dtrmm3Test); -TEST_P(dtrmm3Test, RandomData) { +TEST_P(dtrmm3Test, RandomData) +{ using T = double; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -86,8 +86,6 @@ TEST_P(dtrmm3Test, RandomData) { gtint_t lda_inc = std::get<10>(GetParam()); gtint_t ldb_inc = std::get<11>(GetParam()); gtint_t ldc_inc = std::get<12>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<13>(GetParam()); // Set the threshold for the errors: double thresh = m*n*testinghelpers::getEpsilon(); @@ -95,13 +93,13 @@ TEST_P(dtrmm3Test, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trmm3( storage, side, uploa, transa, diaga, transb, m, n, alpha, lda_inc, ldb_inc, beta, ldc_inc, thresh, datatype ); + test_trmm3( storage, side, uploa, transa, diaga, transb, m, n, alpha, lda_inc, ldb_inc, beta, ldc_inc, thresh ); } class dtrmm3TestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uploa = std::get<2>(str.param); @@ -115,8 +113,7 @@ class dtrmm3TestPrint { gtint_t lda_inc = std::get<10>(str.param); gtint_t ldb_inc = std::get<11>(str.param); gtint_t ldc_inc = std::get<12>(str.param); - char datatype = std::get<13>(str.param); - std::string str_name = "blis_dtrmm3"; + std::string str_name = "bli_dtrmm3"; str_name = str_name + "_" + 
sfm+sfm+sfm; str_name = str_name + "_" + side + uploa + transa + transb; str_name = str_name + "_d" + diaga; @@ -129,7 +126,6 @@ class dtrmm3TestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -152,8 +148,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(-1.0, 2.0), // beta ::testing::Values(gtint_t(0)), // increment to the leading dim of a ::testing::Values(gtint_t(0)), // increment to the leading dim of b - ::testing::Values(gtint_t(0)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0)) // increment to the leading dim of c ), ::dtrmm3TestPrint() ); diff --git a/gtestsuite/testsuite/level3/trmm3/strmm3_generic.cpp b/gtestsuite/testsuite/level3/trmm3/strmm3_generic.cpp index df6e4e9bee..2d52b620e8 100644 --- a/gtestsuite/testsuite/level3/trmm3/strmm3_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm3/strmm3_generic.cpp @@ -48,12 +48,12 @@ class strmm3Test : float, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(strmm3Test); -TEST_P(strmm3Test, RandomData) { +TEST_P(strmm3Test, RandomData) +{ using T = float; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -86,8 +86,6 @@ TEST_P(strmm3Test, RandomData) { gtint_t lda_inc = std::get<10>(GetParam()); gtint_t ldb_inc = std::get<11>(GetParam()); gtint_t ldc_inc = std::get<12>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<13>(GetParam()); // Set the threshold for the errors: double thresh = m*n*testinghelpers::getEpsilon(); @@ -95,13 +93,13 @@ TEST_P(strmm3Test, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trmm3( storage, side, uploa, transa, diaga, transb, m, n, alpha, lda_inc, ldb_inc, beta, ldc_inc, thresh, datatype ); + test_trmm3( storage, side, uploa, transa, diaga, transb, m, n, alpha, lda_inc, ldb_inc, beta, ldc_inc, thresh ); } class strmm3TestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uploa = std::get<2>(str.param); @@ -115,8 +113,7 @@ class strmm3TestPrint { gtint_t lda_inc = std::get<10>(str.param); gtint_t ldb_inc = std::get<11>(str.param); gtint_t ldc_inc = std::get<12>(str.param); - char datatype = std::get<13>(str.param); - std::string str_name = "blis_strmm3"; + std::string str_name = "bli_strmm3"; str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uploa + transa + transb; str_name = str_name + "_d" + diaga; @@ -129,7 +126,6 @@ class strmm3TestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -152,8 +148,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(-1.0, 2.0), // beta ::testing::Values(gtint_t(0)), // increment to the leading dim of a ::testing::Values(gtint_t(0)), // increment to the leading dim of b - ::testing::Values(gtint_t(0)), // increment to the leading dim of c - 
::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0)) // increment to the leading dim of c ), ::strmm3TestPrint() ); diff --git a/gtestsuite/testsuite/level3/trmm3/test_trmm3.h b/gtestsuite/testsuite/level3/trmm3/test_trmm3.h index 779f2fef50..8203a0cb6b 100644 --- a/gtestsuite/testsuite/level3/trmm3/test_trmm3.h +++ b/gtestsuite/testsuite/level3/trmm3/test_trmm3.h @@ -37,32 +37,31 @@ #include "trmm3.h" #include "level3/ref_trmm3.h" #include "inc/check_error.h" -#include "inc/utils.h" #include #include template void test_trmm3( char storage, char side, char uploa, char transa, char diaga, - char transb, gtint_t m, gtint_t n, T alpha, gtint_t lda_inc, gtint_t ldb_inc, - T beta, gtint_t ldc_inc, double thresh, char datatype ) { - + char transb, gtint_t m, gtint_t n, T alpha, gtint_t lda_inc, gtint_t ldb_inc, + T beta, gtint_t ldc_inc, double thresh ) +{ gtint_t mn; testinghelpers::set_dim_with_side( side, m, n, &mn ); - gtint_t lda = testinghelpers::get_leading_dimension(storage, transa, mn, mn, lda_inc); - gtint_t ldb = testinghelpers::get_leading_dimension(storage, transb, m, n, ldb_inc); - gtint_t ldc = testinghelpers::get_leading_dimension(storage, 'n', m, n, ldc_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, mn, mn, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, transb, m, n, ldb_inc ); + gtint_t ldc = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldc_inc ); //---------------------------------------------------------- // Initialize matrics with random values. //---------------------------------------------------------- - std::vector a = testinghelpers::get_random_matrix(-2, 8, storage, transa, mn, mn, lda, datatype); - std::vector b = testinghelpers::get_random_matrix(-5, 2, storage, transb, m, n, ldb, datatype); - std::vector c = testinghelpers::get_random_matrix(-3, 5, storage, 'n', m, n, ldc, datatype); + std::vector a = testinghelpers::get_random_matrix( -2, 8, storage, transa, mn, mn, lda ); + std::vector b = testinghelpers::get_random_matrix( -5, 2, storage, transb, m, n, ldb ); + std::vector c = testinghelpers::get_random_matrix( -3, 5, storage, 'n', m, n, ldc ); // Create a copy of v so that we can check reference results. std::vector c_ref(c); - mktrim( storage, uploa, mn, a.data(), lda ); + testinghelpers::make_triangular( storage, uploa, mn, a.data(), lda ); //---------------------------------------------------------- // Call BLIS function //---------------------------------------------------------- @@ -71,7 +70,7 @@ void test_trmm3( char storage, char side, char uploa, char transa, char diaga, //---------------------------------------------------------- // Call reference implementation. 
//---------------------------------------------------------- - testinghelpers::ref_trmm3( storage, side, uploa, transa, diaga, transb, + testinghelpers::ref_trmm3( storage, side, uploa, transa, diaga, transb, m, n, alpha, a.data(), lda, b.data(), ldb, beta, c_ref.data(), ldc ); //---------------------------------------------------------- diff --git a/gtestsuite/testsuite/level3/trmm3/ztrmm3_generic.cpp b/gtestsuite/testsuite/level3/trmm3/ztrmm3_generic.cpp index f32c5caab8..6ef3931d72 100644 --- a/gtestsuite/testsuite/level3/trmm3/ztrmm3_generic.cpp +++ b/gtestsuite/testsuite/level3/trmm3/ztrmm3_generic.cpp @@ -48,12 +48,12 @@ class ztrmm3Test : dcomplex, gtint_t, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ztrmm3Test); -TEST_P(ztrmm3Test, RandomData) { +TEST_P(ztrmm3Test, RandomData) +{ using T = dcomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -86,8 +86,6 @@ TEST_P(ztrmm3Test, RandomData) { gtint_t lda_inc = std::get<10>(GetParam()); gtint_t ldb_inc = std::get<11>(GetParam()); gtint_t ldc_inc = std::get<12>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<13>(GetParam()); // Set the threshold for the errors: double thresh = m*n*testinghelpers::getEpsilon(); @@ -95,13 +93,13 @@ TEST_P(ztrmm3Test, RandomData) { //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trmm3( storage, side, uploa, transa, diaga, transb, m, n, alpha, lda_inc, ldb_inc, beta, ldc_inc, thresh, datatype ); + test_trmm3( storage, side, uploa, transa, diaga, transb, m, n, alpha, lda_inc, ldb_inc, beta, ldc_inc, thresh ); } class ztrmm3TestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uploa = std::get<2>(str.param); @@ -115,8 +113,7 @@ class ztrmm3TestPrint { gtint_t lda_inc = std::get<10>(str.param); gtint_t ldb_inc = std::get<11>(str.param); gtint_t ldc_inc = std::get<12>(str.param); - char datatype = std::get<13>(str.param); - std::string str_name = "blis_ztrmm3"; + std::string str_name = "bli_ztrmm3"; str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uploa + transa + transb; str_name = str_name + "_d" + diaga; @@ -131,7 +128,6 @@ class ztrmm3TestPrint { str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); str_name = str_name + "_" + std::to_string(ldc_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -154,8 +150,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(dcomplex{2.0,-1.0}), // beta ::testing::Values(gtint_t(0)), // increment to the leading dim of a ::testing::Values(gtint_t(0)), // increment to the leading dim of b - ::testing::Values(gtint_t(0)), // increment to the leading dim of c - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0)) // increment to the leading dim of c ), ::ztrmm3TestPrint() ); diff --git a/gtestsuite/testsuite/level3/trsm/ctrsm_generic.cpp b/gtestsuite/testsuite/level3/trsm/ctrsm_generic.cpp index d4644da077..85c3917a39 100644 --- a/gtestsuite/testsuite/level3/trsm/ctrsm_generic.cpp +++ b/gtestsuite/testsuite/level3/trsm/ctrsm_generic.cpp @@ -45,10 +45,10 @@ class ctrsmTest : gtint_t, scomplex, gtint_t, 
- gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(ctrsmTest, RandomData) { +TEST_P(ctrsmTest, RandomData) +{ using T = scomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -76,22 +76,20 @@ TEST_P(ctrsmTest, RandomData) { // If increments are nonnegative, the array size is bigger than the matrix size. gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = std::max(m, n)*testinghelpers::getEpsilon(); + double thresh = (std::max)(m, n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh, datatype ); + test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); } class ctrsmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uploa = std::get<2>(str.param); @@ -102,13 +100,12 @@ class ctrsmTestPrint { scomplex alpha = std::get<7>(str.param); gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "ctrsm_"; #elif TEST_CBLAS std::string str_name = "cblas_ctrsm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_ctrsm"; + std::string str_name = "bli_ctrsm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uploa + transa; @@ -120,7 +117,6 @@ class ctrsmTestPrint { str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -143,8 +139,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(31), 10), // n ::testing::Values(scomplex{2.0,-1.0}), // alpha ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of b - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(4)) // increment to the leading dim of b ), ::ctrsmTestPrint() ); diff --git a/gtestsuite/testsuite/level3/trsm/dtrsm_generic.cpp b/gtestsuite/testsuite/level3/trsm/dtrsm_generic.cpp index 9995ca3c6c..87b841defd 100644 --- a/gtestsuite/testsuite/level3/trsm/dtrsm_generic.cpp +++ b/gtestsuite/testsuite/level3/trsm/dtrsm_generic.cpp @@ -45,10 +45,10 @@ class dtrsmTest : gtint_t, double, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(dtrsmTest, RandomData) { +TEST_P(dtrsmTest, RandomData) +{ using T = double; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -76,22 +76,20 @@ TEST_P(dtrsmTest, RandomData) { // If increments are nonnegative, the array size is bigger than the matrix size. 
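// In the threshold computation below, std::max is presumably wrapped in parentheses,
// (std::max)(m, n), so that a function-like max() macro (for example the one defined by
// <windows.h> when NOMINMAX is not set) cannot expand it; the computed value,
// max(m, n)*eps, is unchanged.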
gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = std::max(m, n)*testinghelpers::getEpsilon(); + double thresh = (std::max)(m, n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh, datatype ); + test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); } class dtrsmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uploa = std::get<2>(str.param); @@ -102,13 +100,12 @@ class dtrsmTestPrint { double alpha = std::get<7>(str.param); gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "dtrsm_"; #elif TEST_CBLAS std::string str_name = "cblas_dtrsm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_dtrsm"; + std::string str_name = "bli_dtrsm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uploa + transa; @@ -119,7 +116,6 @@ class dtrsmTestPrint { str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -142,8 +138,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(11), 10), // n ::testing::Values( 1.0, -2.0), // alpha ::testing::Values(gtint_t(0), gtint_t(5)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of b ), ::dtrsmTestPrint() ); diff --git a/gtestsuite/testsuite/level3/trsm/strsm_generic.cpp b/gtestsuite/testsuite/level3/trsm/strsm_generic.cpp index aa69d719ac..2e197c104f 100644 --- a/gtestsuite/testsuite/level3/trsm/strsm_generic.cpp +++ b/gtestsuite/testsuite/level3/trsm/strsm_generic.cpp @@ -45,10 +45,10 @@ class strsmTest : gtint_t, float, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(strsmTest, RandomData) { +TEST_P(strsmTest, RandomData) +{ using T = float; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -76,22 +76,20 @@ TEST_P(strsmTest, RandomData) { // If increments are nonnegative, the array size is bigger than the matrix size. 
gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = std::max(m, n)*testinghelpers::getEpsilon(); + double thresh = (std::max)(m, n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh, datatype ); + test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); } class strsmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uploa = std::get<2>(str.param); @@ -102,13 +100,12 @@ class strsmTestPrint { float alpha = std::get<7>(str.param); gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "strsm_"; #elif TEST_CBLAS std::string str_name = "cblas_strsm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_strsm"; + std::string str_name = "bli_strsm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uploa + transa; @@ -119,7 +116,6 @@ class strsmTestPrint { str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -142,8 +138,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(11), 10), // n ::testing::Values( 1.0, -2.0), // alpha ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(4)), // increment to the leading dim of b - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(4)) // increment to the leading dim of b ), ::strsmTestPrint() ); diff --git a/gtestsuite/testsuite/level3/trsm/test_trsm.h b/gtestsuite/testsuite/level3/trsm/test_trsm.h index 7145a92156..df0502b060 100644 --- a/gtestsuite/testsuite/level3/trsm/test_trsm.h +++ b/gtestsuite/testsuite/level3/trsm/test_trsm.h @@ -37,27 +37,25 @@ #include "trsm.h" #include "level3/ref_trsm.h" #include "inc/check_error.h" -#include "inc/utils.h" #include #include template -void test_trsm( char storage, char side, char uploa, char transa, - char diaga, gtint_t m, gtint_t n, T alpha, gtint_t lda_inc, - gtint_t ldb_inc, double thresh, char datatype ) { - +void test_trsm( char storage, char side, char uploa, char transa, char diaga, + gtint_t m, gtint_t n, T alpha, gtint_t lda_inc, gtint_t ldb_inc, double thresh ) +{ gtint_t mn; testinghelpers::set_dim_with_side( side, m, n, &mn ); - gtint_t lda = testinghelpers::get_leading_dimension(storage, transa, mn, mn, lda_inc); - gtint_t ldb = testinghelpers::get_leading_dimension(storage, 'n', m, n, ldb_inc); + gtint_t lda = testinghelpers::get_leading_dimension( storage, transa, mn, mn, lda_inc ); + gtint_t ldb = testinghelpers::get_leading_dimension( storage, 'n', m, n, ldb_inc ); //---------------------------------------------------------- // Initialize matrics with random values. 
//---------------------------------------------------------- gtint_t lower = (diaga == 'n')||(diaga == 'N') ? 3 : 0; gtint_t upper = (diaga == 'n')||(diaga == 'N') ? 10 : 1; - std::vector a = testinghelpers::get_random_matrix(lower, upper, storage, transa, mn, mn, lda, datatype); - std::vector b = testinghelpers::get_random_matrix(3, 10, storage, 'n', m, n, ldb, datatype); + std::vector a = testinghelpers::get_random_matrix( lower, upper, storage, transa, mn, mn, lda ); + std::vector b = testinghelpers::get_random_matrix( 3, 10, storage, 'n', m, n, ldb ); // Making A diagonally dominant so that the condition number is good and // the algorithm doesn't diverge. @@ -68,7 +66,7 @@ void test_trsm( char storage, char side, char uploa, char transa, // Create a copy of b so that we can check reference results. std::vector b_ref(b); - mktrim( storage, uploa, mn, a.data(), lda ); + testinghelpers::make_triangular( storage, uploa, mn, a.data(), lda ); //---------------------------------------------------------- // Call BLIS function //---------------------------------------------------------- @@ -77,7 +75,8 @@ void test_trsm( char storage, char side, char uploa, char transa, //---------------------------------------------------------- // Call reference implementation. //---------------------------------------------------------- - testinghelpers::ref_trsm( storage, side, uploa, transa, diaga, m, n, alpha, a.data(), lda, b_ref.data(), ldb ); + testinghelpers::ref_trsm( storage, side, uploa, transa, diaga, m, n, alpha, a.data(), + lda, b_ref.data(), ldb ); //---------------------------------------------------------- // check component-wise error. diff --git a/gtestsuite/testsuite/level3/trsm/ztrsm_generic.cpp b/gtestsuite/testsuite/level3/trsm/ztrsm_generic.cpp index 1987251fc2..830b9081b5 100644 --- a/gtestsuite/testsuite/level3/trsm/ztrsm_generic.cpp +++ b/gtestsuite/testsuite/level3/trsm/ztrsm_generic.cpp @@ -45,10 +45,10 @@ class ztrsmTest : gtint_t, dcomplex, gtint_t, - gtint_t, - char>> {}; + gtint_t>> {}; -TEST_P(ztrsmTest, RandomData) { +TEST_P(ztrsmTest, RandomData) +{ using T = dcomplex; //---------------------------------------------------------- // Initialize values from the parameters passed through @@ -76,22 +76,20 @@ TEST_P(ztrsmTest, RandomData) { // If increments are nonnegative, the array size is bigger than the matrix size.
gtint_t lda_inc = std::get<8>(GetParam()); gtint_t ldb_inc = std::get<9>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<10>(GetParam()); // Set the threshold for the errors: - double thresh = std::max(m, n)*testinghelpers::getEpsilon(); + double thresh = (std::max)(m, n)*testinghelpers::getEpsilon(); //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh, datatype ); + test_trsm( storage, side, uploa, transa, diaga, m, n, alpha, lda_inc, ldb_inc, thresh ); } class ztrsmTestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { char sfm = std::get<0>(str.param); char side = std::get<1>(str.param); char uploa = std::get<2>(str.param); @@ -102,13 +100,12 @@ class ztrsmTestPrint { dcomplex alpha = std::get<7>(str.param); gtint_t lda_inc = std::get<8>(str.param); gtint_t ldb_inc = std::get<9>(str.param); - char datatype = std::get<10>(str.param); #ifdef TEST_BLAS std::string str_name = "ztrsm_"; #elif TEST_CBLAS std::string str_name = "cblas_ztrsm"; #else //#elif TEST_BLIS_TYPED - std::string str_name = "blis_ztrsm"; + std::string str_name = "bli_ztrsm"; #endif str_name = str_name + "_" + sfm+sfm+sfm; str_name = str_name + "_" + side + uploa + transa; @@ -120,7 +117,6 @@ class ztrsmTestPrint { str_name = str_name + "_a" + alpha_str; str_name = str_name + "_" + std::to_string(lda_inc); str_name = str_name + "_" + std::to_string(ldb_inc); - str_name = str_name + "_" + datatype; return str_name; } }; @@ -143,8 +139,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Range(gtint_t(10), gtint_t(11), 10), // n ::testing::Values(dcomplex{1.0,2.0}), // alpha ::testing::Values(gtint_t(0), gtint_t(2)), // increment to the leading dim of a - ::testing::Values(gtint_t(0), gtint_t(3)), // increment to the leading dim of b - ::testing::Values(ELEMENT_TYPE) // i : integer, f : float datatype type tested + ::testing::Values(gtint_t(0), gtint_t(3)) // increment to the leading dim of b ), ::ztrsmTestPrint() ); diff --git a/gtestsuite/testsuite/util/nrm2/dnrm2_extreme_values.cpp b/gtestsuite/testsuite/util/nrm2/dnrm2_extreme_values.cpp new file mode 100644 index 0000000000..32386593d0 --- /dev/null +++ b/gtestsuite/testsuite/util/nrm2/dnrm2_extreme_values.cpp @@ -0,0 +1,266 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_nrm2.h" + +class dnrm2_EVT : + public ::testing::TestWithParam> {}; + +TEST_P( dnrm2_EVT, EVT ) +{ + using T = double; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length: + gtint_t n = std::get<0>(GetParam()); + // stride size for x: + gtint_t incx = std::get<1>(GetParam()); + // index with extreme value iexval. + gtint_t i = std::get<2>(GetParam()); + T iexval = std::get<3>(GetParam()); + // index with extreme value jexval. + gtint_t j = std::get<4>(GetParam()); + T jexval = std::get<5>(GetParam()); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_nrm2(n, incx, i, iexval, j, jexval); +} + +// Prints the test case combination +class dnrm2_TestPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + // vector length: + gtint_t n = std::get<0>(str.param); + // stride size for x: + gtint_t incx = std::get<1>(str.param); + // index with extreme value iexval. + gtint_t i = std::get<2>(str.param); + double iexval = std::get<3>(str.param); + // index with extreme value jexval. + gtint_t j = std::get<4>(str.param); + double jexval = std::get<5>(str.param); +#ifdef TEST_BLAS + std::string str_name = "dnrm2_"; +#elif TEST_CBLAS + std::string str_name = "cblas_dnrm2"; +#else //#elif TEST_BLIS_TYPED + std::string str_name = "bli_dnormfv"; +#endif + str_name = str_name + "_" + std::to_string(n); + std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); + str_name = str_name + "_" + incx_str; + str_name = str_name + "_i" + std::to_string(i); + std::string iexval_str = testinghelpers::get_value_string(iexval); + str_name = str_name + "_" + iexval_str; + str_name = str_name + "_j" + std::to_string(j); + std::string jexval_str = testinghelpers::get_value_string(jexval); + str_name = str_name + "_" + jexval_str; + return str_name; + } +}; + +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); + +/** + * dnrm2 implementation is composed by two parts: + * - vectorized path for n>4 + * - for-loop for multiples of 8 (F8) + * - for-loop for multiples of 4 (F4) + * - scalar path for n<=4 (S) + */ + +// Test for scalar path. +// Testing for jexval=1.0, means that we test only one NaN/Inf value. +// for jexval also being an extreme value, we test all combinations +// of having first a NaN and then an Inf and so on. 
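// The Combine() below expands to 3 values of iexval times 4 values of jexval = 12 cases
// for this "scalar" suite. As a rough sketch (an assumption, since the actual test body
// lives in test_nrm2.h and is not part of this diff), each case is expected to do
// something along these lines:
//
//   std::vector<double> x = testinghelpers::get_random_vector<double>( -10, 10, n, incx );
//   x.at( i*std::abs(incx) ) = iexval;   // plant the first extreme value
//   x.at( j*std::abs(incx) ) = jexval;   // plant the second extreme value
//   double norm = nrm2<double>( n, x.data(), incx );
//   // ...and then verify that the NaN/Inf inputs propagate to the result instead of
//   // being flushed by the blocked (F8/F4) or scalar code paths.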
+INSTANTIATE_TEST_SUITE_P( + scalar, + dnrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(3)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(0), + // iexval + ::testing::Values(NaN, Inf, -Inf), + ::testing::Values(2), + ::testing::Values(1.0, NaN, Inf, -Inf) + ), + ::dnrm2_TestPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + vector_F8, + dnrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(8)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(3), + // iexval + ::testing::Values(NaN, Inf, -Inf), + ::testing::Values(6), + ::testing::Values(1.0, NaN, Inf, -Inf) + ), + ::dnrm2_TestPrint() + ); + +// To test the second for-loop (F4), we use n = 12 +// and ensure that the extreme values are on or after index 8. +INSTANTIATE_TEST_SUITE_P( + vector_F4, + dnrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(12)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(9), + // iexval + ::testing::Values(NaN, Inf, -Inf), + ::testing::Values(11), + ::testing::Values(1.0, NaN, Inf, -Inf) + ), + ::dnrm2_TestPrint() + ); + +// Now let's check the combination of a vectorized path and +// the scalar path, by putting an extreme value in each +// to check that the checks are integrated correctly. +INSTANTIATE_TEST_SUITE_P( + vector_scalar, + dnrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(10)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(5), + // iexval + ::testing::Values(NaN, Inf, -Inf), + ::testing::Values(8), + ::testing::Values(1.0, NaN, Inf, -Inf) + ), + ::dnrm2_TestPrint() + ); + +// Multithreading unit tester +/* + The following instantiator has data points that would suffice + the unit testing with <= 64 threads. + + Sizes from 256 to 259 ensure that each thread gets a minimum + size of 4, with some sizes inducing fringe cases. + + Sizes from 512 to 515 ensure that each thread gets a minimum + size of 8, with some sizes inducing fringe cases. + + Sizes from 768 to 771 ensure that each thread gets a minimum + size of 12, with some sizes inducing fringe cases. + + NOTE : Extreme values are induced at indices that are valid + for all the listed sizes in the instantiator. + + Non-unit strides are also tested, since they might get packed. 
+*/ +INSTANTIATE_TEST_SUITE_P( + EVT_MT_Unit_Tester, + dnrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(256), + gtint_t(257), + gtint_t(258), + gtint_t(259), + gtint_t(512), + gtint_t(513), + gtint_t(514), + gtint_t(515), + gtint_t(768), + gtint_t(769), + gtint_t(770), + gtint_t(771)), + // stride size for x + ::testing::Values(gtint_t(1), gtint_t(5)), + // i : index of x that has value iexval + ::testing::Values(0, 5, 100, 255), + // iexval + ::testing::Values(NaN, Inf, -Inf), + ::testing::Values(4, 17, 125, 201), + ::testing::Values(1.0, NaN, Inf, -Inf) + ), + ::dnrm2_TestPrint() + ); + +// Instantiator if AOCL_DYNAMIC is enabled +/* + The instantiator here checks for correctness of + the compute with sizes large enough to bypass + the thread setting logic with AOCL_DYNAMIC enabled +*/ +INSTANTIATE_TEST_SUITE_P( + EVT_MT_AOCL_DYNAMIC, + dnrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(2950000), + gtint_t(2950001), + gtint_t(2950002), + gtint_t(2950003) + ), + // stride size for x + ::testing::Values(gtint_t(1), gtint_t(5)), + // i : index of x that has value iexval + ::testing::Values(1000000, 2000000), + // iexval + ::testing::Values(NaN, Inf), + ::testing::Values(1500000, 2500000), + ::testing::Values(-Inf, NaN) + ), + ::dnrm2_TestPrint() + ); diff --git a/gtestsuite/testsuite/util/nrm2/dnrm2_generic.cpp b/gtestsuite/testsuite/util/nrm2/dnrm2_generic.cpp index 245b5f49ac..422f5bfe76 100644 --- a/gtestsuite/testsuite/util/nrm2/dnrm2_generic.cpp +++ b/gtestsuite/testsuite/util/nrm2/dnrm2_generic.cpp @@ -9,14 +9,14 @@ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -36,7 +36,7 @@ #include "test_nrm2.h" class dnrm2Test : - public ::testing::TestWithParam> {}; + public ::testing::TestWithParam> {}; TEST_P( dnrm2Test, RandomData ) { @@ -49,8 +49,6 @@ TEST_P( dnrm2Test, RandomData ) gtint_t n = std::get<0>(GetParam()); // stride size for x: gtint_t incx = std::get<1>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<2>(GetParam()); // Set the threshold for the errors: double thresh = std::sqrt(n)*testinghelpers::getEpsilon(); @@ -58,17 +56,16 @@ TEST_P( dnrm2Test, RandomData ) //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_nrm2(n, incx, thresh, datatype); + test_nrm2( n, incx, thresh ); } // Prints the test case combination class dnrm2TestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { gtint_t n = std::get<0>(str.param); gtint_t incx = std::get<1>(str.param); - char datatype = std::get<2>(str.param); #ifdef TEST_BLAS std::string str_name = "dnrm2_"; #elif TEST_CBLAS @@ -79,23 +76,117 @@ class dnrm2TestPrint { str_name = str_name + "_" + std::to_string(n); std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + datatype; return str_name; } }; -// Black box testing. +/** + * dnrm2 implementation is composed by two parts: + * - vectorized path for n>4 + * - for-loop for multiples of 8 (F8) + * - for-loop for multiples of 4 (F4) + * - scalar path for n<=4 (S) +*/ + +INSTANTIATE_TEST_SUITE_P( + AT_1T, + dnrm2Test, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(1), // trivial case n=1 + gtint_t(3), // will go through SSE and scalar + gtint_t(8), // 1*8 - will only go through F8 + gtint_t(24), // 3*8 - will go through F8 + gtint_t(34), // 4*8 + 2 - will go through F8 & S + gtint_t(52), // 6*8 + 4 - will go through F8 & F4 + gtint_t(71), // 8*8 + 4 + 3 - will go through F8 & F4 & S + gtint_t(89), // a few bigger numbers + gtint_t(122), + gtint_t(185), + gtint_t(217) + ), + // stride size for x + ::testing::Values(gtint_t(1), gtint_t(3) +#ifndef TEST_BLIS_TYPED + , gtint_t(-1), gtint_t(-7) +#endif + ) + ), + ::dnrm2TestPrint() + ); + +// Multithreading unit tester +/* + NOTE : The following instantiator is the most useful if BLIS + configured with aocl-dynamic disabled, since then it + would be sufficient to verify functionality upto 64 + threads. + + The following instantiator has data points that would suffice + the extreme value testing with <= 64 threads. + + Sizes from 256 to 259 ensure that each thread gets a minimum + size of 4, with some sizes inducing fringe cases. + + Sizes from 512 to 515 ensure that each thread gets a minimum + size of 8, with some sizes inducing fringe cases. + + Sizes from 768 to 771 ensure that each thread gets a minimum + size of 12( i.e 8-block loop + 4-block loop), with some sizes + inducing fringe cases. + + Non-unit strides are also tested, since they might get packed. 
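These AT_MT suites only stress the multithreaded code path when the library actually spawns threads, so it can help to pin the thread count explicitly when running them. The snippet below is a hedged sketch: it assumes the build under test honours the BLIS_NUM_THREADS environment variable (as upstream BLIS documents) and uses POSIX setenv; it is not part of the test suite itself.

#include <cstdlib>

// Assumed helper, not part of the suite: request a fixed BLIS thread count
// before any BLIS routine is called (POSIX; use _putenv_s on Windows).
static void request_blis_threads( const char* n_threads )
{
    setenv( "BLIS_NUM_THREADS", n_threads, /*overwrite=*/1 );
}

Equivalently, the variable can be exported on the command line before launching the gtest binary.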
+*/ +INSTANTIATE_TEST_SUITE_P( + AT_MT_Unit_Tester, + dnrm2Test, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(256), + gtint_t(257), + gtint_t(258), + gtint_t(259), + gtint_t(512), + gtint_t(513), + gtint_t(514), + gtint_t(515), + gtint_t(768), + gtint_t(769), + gtint_t(770), + gtint_t(771) + ), + // stride size for x + ::testing::Values(gtint_t(1), gtint_t(3) +#ifndef TEST_BLIS_TYPED + , gtint_t(-1), gtint_t(-7) +#endif + ) + ), + ::dnrm2TestPrint() + ); + +// Instantiator if AOCL_DYNAMIC is enabled +/* + The instantiator here checks for correctness of + the compute with sizes large enough to bypass + the thread setting logic with AOCL_DYNAMIC enabled +*/ INSTANTIATE_TEST_SUITE_P( - Blackbox, + AT_MT_AOCL_DYNAMIC, dnrm2Test, ::testing::Combine( - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1), gtint_t(2) + // m size of vector + ::testing::Values(gtint_t(2950000), + gtint_t(2950001), + gtint_t(2950002), + gtint_t(2950003) + ), + // stride size for x + ::testing::Values(gtint_t(1), gtint_t(3) #ifndef TEST_BLIS_TYPED - ,gtint_t(-1), gtint_t(-2) + , gtint_t(-1), gtint_t(-7) #endif - ), // stride size for x - ::testing::Values('i') // i : integer, f : float datatype type tested + ) ), ::dnrm2TestPrint() ); diff --git a/gtestsuite/testsuite/util/nrm2/dznrm2_extreme_values.cpp b/gtestsuite/testsuite/util/nrm2/dznrm2_extreme_values.cpp new file mode 100644 index 0000000000..993859265c --- /dev/null +++ b/gtestsuite/testsuite/util/nrm2/dznrm2_extreme_values.cpp @@ -0,0 +1,264 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include +#include "test_nrm2.h" + +class dznrm2_EVT : + public ::testing::TestWithParam>{}; + +TEST_P( dznrm2_EVT, EVT ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length: + gtint_t n = std::get<0>(GetParam()); + // stride size for x: + gtint_t incx = std::get<1>(GetParam()); + // index with extreme value iexval. + gtint_t i = std::get<2>(GetParam()); + T iexval = std::get<3>(GetParam()); + // index with extreme value jexval. + gtint_t j = std::get<4>(GetParam()); + T jexval = std::get<5>(GetParam()); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_nrm2(n, incx, i, iexval, j, jexval); +} + +// Prints the test case combination +class dznrm2_TestPrint{ +public: + std::string operator()( + testing::TestParamInfo> str) const { + // vector length: + gtint_t n = std::get<0>(str.param); + // stride size for x: + gtint_t incx = std::get<1>(str.param); + // index with extreme value iexval. + gtint_t i = std::get<2>(str.param); + dcomplex iexval = std::get<3>(str.param); + // index with extreme value jexval. + gtint_t j = std::get<4>(str.param); + dcomplex jexval = std::get<5>(str.param); +#ifdef TEST_BLAS + std::string str_name = "dznrm2_"; +#elif TEST_CBLAS + std::string str_name = "cblas_dznrm2"; +#else //#elif TEST_BLIS_TYPED + std::string str_name = "bli_znormfv"; +#endif + str_name = str_name + "_" + std::to_string(n); + std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); + str_name = str_name + "_" + incx_str; + str_name = str_name + "_i" + std::to_string(i); + std::string iexval_str = "_Re_" + testinghelpers::get_value_string(iexval.real) + "_Im_" + testinghelpers::get_value_string(iexval.imag); + str_name = str_name + iexval_str; + str_name = str_name + "_j" + std::to_string(j); + std::string jexval_str = "_Re_" + testinghelpers::get_value_string(jexval.real) + "_Im_" + testinghelpers::get_value_string(jexval.imag); + str_name = str_name + jexval_str; + return str_name; + } +}; + +static double NaN = std::numeric_limits::quiet_NaN(); +static double Inf = std::numeric_limits::infinity(); +/** + * dznrm2 implementation is composed by two parts: + * - vectorized path for n>2 + * - for-loop for multiples of 4 (F4) + * - for-loop for multiples of 2 (F2) + * - scalar path for n<=2 (S) +*/ + +// Test for scalar path. +// Testing for jexval=(1.0, 2.0), means that we test only one NaN/Inf value. +// for jexval also being an extreme value, we test all combinations +// of having first a NaN and then an Inf and so on. 
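Mathematically, the norm of an n-element dcomplex vector equals the real 2-norm of its 2n interleaved real and imaginary parts, which is why an extreme value in either component must propagate exactly as in the dnrm2 tests. A minimal illustration (standard C++ types, no overflow scaling, not the BLIS kernel):

#include <cmath>
#include <complex>
#include <vector>

// ||z||_2 over complex z equals the real 2-norm of all real/imaginary parts.
double naive_dznrm2_reference( const std::vector<std::complex<double>>& z )
{
    double sumsq = 0.0;
    for ( const auto& zi : z )
        sumsq += zi.real() * zi.real() + zi.imag() * zi.imag();
    return std::sqrt( sumsq );
}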
+INSTANTIATE_TEST_SUITE_P( + scalar, + dznrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(2)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(0), + // iexval + ::testing::Values(dcomplex{NaN, 1.0}, dcomplex{Inf, 9.0}, dcomplex{-1.0, -Inf}, dcomplex{2.0, NaN}, dcomplex{NaN, Inf}, dcomplex{Inf, NaN}), + ::testing::Values(1), + ::testing::Values(dcomplex{1.0, 2.0}, dcomplex{NaN, 1.0}, dcomplex{Inf, 9.0}, dcomplex{-1.0, -Inf}, dcomplex{2.0, NaN}) + ), + ::dznrm2_TestPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + vector_F4, + dznrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(4)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(1), + // iexval + ::testing::Values(dcomplex{NaN, 1.0}, dcomplex{Inf, 9.0}, dcomplex{-1.0, -Inf}, dcomplex{2.0, NaN}, dcomplex{NaN, Inf}, dcomplex{Inf, NaN}), + ::testing::Values(3), + ::testing::Values(dcomplex{1.0, 2.0}, dcomplex{NaN, 1.0}, dcomplex{Inf, 9.0}, dcomplex{-1.0, -Inf}, dcomplex{2.0, NaN}) + ), + ::dznrm2_TestPrint() + ); + +// To test the second for-loop (F2), we use n = 6 +// and ensure that the extreme values are on or after index 4. +INSTANTIATE_TEST_SUITE_P( + vector_F2, + dznrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(6)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(4), + // iexval + ::testing::Values(dcomplex{NaN, 1.0}, dcomplex{Inf, 9.0}, dcomplex{-1.0, -Inf}, dcomplex{2.0, NaN}, dcomplex{NaN, Inf}, dcomplex{Inf, NaN}), + ::testing::Values(5), + ::testing::Values(dcomplex{1.0, 2.0}, dcomplex{NaN, 1.0}, dcomplex{Inf, 9.0}, dcomplex{-1.0, -Inf}, dcomplex{2.0, NaN}) + ), + ::dznrm2_TestPrint() + ); + +// Now let's check the combination of a vectorized path and +// the scalar path, by putting an extreme value in each +// to check that the checks are integrated correctly. +INSTANTIATE_TEST_SUITE_P( + vector_scalar, + dznrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(7)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(2), + // iexval + ::testing::Values(dcomplex{NaN, 1.0}, dcomplex{Inf, 9.0}, dcomplex{-1.0, -Inf}, dcomplex{2.0, NaN}, dcomplex{NaN, Inf}, dcomplex{Inf, NaN}), + ::testing::Values(6), + ::testing::Values(dcomplex{NaN, 1.0}, dcomplex{Inf, 9.0}, dcomplex{-1.0, -Inf}, dcomplex{2.0, NaN}) + ), + ::dznrm2_TestPrint() + ); + +// Mutlthreading Unit Tester +/* + The following instantiator has data points that would suffice + the extreme value testing with 64 threads. + + Sizes 128 and 129 ensure that each thread gets size 2, with + the first thread dealing with fringe case also, if required. + + Sizes 256, 257 and 259 ensure that each thread gets size 4, with + the first two threads dealing wtih extra AVX and SSE cases also, + if required. + + Sizes from 384 to 389 ensure that each thread gets size 6, with + the first few threads dealing with extra AVX and SSE cases if needed. + + NOTE : Extreme values are induced at indices that are valid + for all the listed sizes in the instantiator. 
+ + Non-unit strides are also tested, since they might get packed +*/ +INSTANTIATE_TEST_SUITE_P( + EVT_MT_Unit_Tester, + dznrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(128), + gtint_t(129), + gtint_t(256), + gtint_t(257), + gtint_t(259), + gtint_t(384), + gtint_t(385), + gtint_t(386), + gtint_t(387), + gtint_t(388), + gtint_t(389) + ), + // stride size for x + ::testing::Values(gtint_t(1), gtint_t(3)), + // i : index of x that has value iexval + ::testing::Values(2, 17, 65, 110), + // iexval + ::testing::Values(dcomplex{NaN, 1.0}, dcomplex{Inf, 9.0}, dcomplex{-1.0, -Inf}, dcomplex{2.0, NaN}, dcomplex{NaN, Inf}, dcomplex{Inf, NaN}), + ::testing::Values(6, 25, 64, 127), + ::testing::Values(dcomplex{NaN, 1.0}, dcomplex{Inf, 9.0}, dcomplex{-1.0, -Inf}, dcomplex{2.0, NaN}) + ), + ::dznrm2_TestPrint() + ); + +// Instantiator if AOCL_DYNAMIC is enabled +/* + The instantiator here checks for correctness of + the compute with sizes large enough to bypass + the thread setting logic with AOCL_DYNAMIC enabled +*/ +INSTANTIATE_TEST_SUITE_P( + EVT_MT_AOCL_DYNAMIC, + dznrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(1530000), + gtint_t(1530001) + ), + // stride size for x + ::testing::Values(gtint_t(1), gtint_t(5)), + // i : index of x that has value iexval + ::testing::Values(800000, 1000000), + // iexval + ::testing::Values(dcomplex{NaN, Inf}, dcomplex{-Inf, NaN}, dcomplex{Inf, 0.0}), + ::testing::Values(1100000, 1500000), + ::testing::Values(dcomplex{NaN, Inf}, dcomplex{-Inf, NaN}, dcomplex{Inf, 0.0}) + ), + ::dznrm2_TestPrint() + ); diff --git a/gtestsuite/testsuite/util/nrm2/dznrm2_generic.cpp b/gtestsuite/testsuite/util/nrm2/dznrm2_generic.cpp new file mode 100644 index 0000000000..a0fb186ccc --- /dev/null +++ b/gtestsuite/testsuite/util/nrm2/dznrm2_generic.cpp @@ -0,0 +1,183 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include +#include "test_nrm2.h" + +class dznrm2Test : + public ::testing::TestWithParam> {}; + +TEST_P( dznrm2Test, RandomData ) +{ + using T = dcomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length: + gtint_t n = std::get<0>(GetParam()); + // stride size for x: + gtint_t incx = std::get<1>(GetParam()); + + // Set the threshold for the errors: + double thresh = 3*testinghelpers::getEpsilon(); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_nrm2(n, incx, thresh); +} + +// Prints the test case combination +class dznrm2TestPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + gtint_t n = std::get<0>(str.param); + gtint_t incx = std::get<1>(str.param); +#ifdef TEST_BLAS + std::string str_name = "dznrm2_"; +#elif TEST_CBLAS + std::string str_name = "cblas_dznrm2"; +#else //#elif TEST_BLIS_TYPED + std::string str_name = "bli_znormfv"; +#endif + str_name = str_name + "_" + std::to_string(n); + std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); + str_name = str_name + "_" + incx_str; + return str_name; + } +}; + +/** + * dznrm2 implementation is composed by two parts: + * - vectorized path for n>2 + * - for-loop for multiples of 4 (F4) + * - for-loop for multiples of 2 (F2) + * - scalar path for n<=2 (S) +*/ +INSTANTIATE_TEST_SUITE_P( + AT_1T, + dznrm2Test, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(1), // trivial case n=1 + gtint_t(2), // 1*2 - will only go through F2 + gtint_t(4), // 1*4 - will only go through F4 + gtint_t(12), // 3*4 - will go through F4 + gtint_t(17), // 4*4 + 1 - will go through F4 & S + gtint_t(22), // 5*4 + 2 - will go through F4 & F2 + gtint_t(35), // 8*4 + 2 + 1 - will go through F4 & F2 & S + gtint_t(78), // a few bigger numbers + gtint_t(112), + gtint_t(187), + gtint_t(213) + ), + // stride size for x + ::testing::Values(gtint_t(1), gtint_t(3) +#ifndef TEST_BLIS_TYPED + , gtint_t(-1), gtint_t(-7) +#endif + ) + ), + ::dznrm2TestPrint() + ); + +// Multithreading unit tester +/* + The following instantiator has data points that would suffice + the unit testing with 64 threads. + + Sizes 128 and 129 ensure that each thread gets a minimum + size of 2, with some sizes inducing fringe cases. + + Sizes 256, 257 and 259 ensure that each thread gets a minimum + size of 4, with some sizes inducing fringe cases. + + Sizes from 384 to 389 ensure that each thread gets a minimum + size of 6( 4-block loop + 2-block loop), with some sizes inducing + fringe cases. + + Non-unit strides are also tested, since they might get packed. 
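If the multithreaded implementation lets each thread accumulate a partial norm over its own chunk (as the MT corner-case test further below assumes), those partial results still have to be combined into a single norm. A hedged sketch of one combine step that keeps the NaN-over-Inf convention exercised by the EVT suites and avoids overflow when squaring the partials:

#include <algorithm>
#include <cmath>
#include <limits>

// Combine the 2-norms of two disjoint chunks into the 2-norm of their union:
// sqrt(a^2 + b^2), computed with scaling. Illustrative only.
double combine_partial_norms( double a, double b )
{
    if ( std::isnan( a ) || std::isnan( b ) )
        return std::numeric_limits<double>::quiet_NaN();   // NaN dominates
    if ( std::isinf( a ) || std::isinf( b ) )
        return std::numeric_limits<double>::infinity();    // then Inf
    double hi = std::max( a, b );
    double lo = std::min( a, b );
    if ( hi == 0.0 ) return 0.0;
    double r = lo / hi;
    return hi * std::sqrt( 1.0 + r * r );
}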
+*/ +INSTANTIATE_TEST_SUITE_P( + AT_MT_Unit_Tester, + dznrm2Test, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(128), + gtint_t(129), + gtint_t(256), + gtint_t(257), + gtint_t(259), + gtint_t(384), + gtint_t(385), + gtint_t(386), + gtint_t(387), + gtint_t(388), + gtint_t(389) + ), + // stride size for x + ::testing::Values(gtint_t(1), gtint_t(3) +#ifndef TEST_BLIS_TYPED + , gtint_t(-1), gtint_t(-7) +#endif + ) + ), + ::dznrm2TestPrint() + ); + +// Instantiator if AOCL_DYNAMIC is enabled +/* + The instantiator here checks for correctness of + the compute with sizes large enough to bypass + the thread setting logic with AOCL_DYNAMIC enabled +*/ +INSTANTIATE_TEST_SUITE_P( + AT_MT_AOCL_DYNAMIC, + dznrm2Test, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(1530000), + gtint_t(1530001) + ), + // stride size for x + ::testing::Values(gtint_t(1), gtint_t(3) +#ifndef TEST_BLIS_TYPED + , gtint_t(-1), gtint_t(-7) +#endif + ) + ), + ::dznrm2TestPrint() + ); diff --git a/gtestsuite/testsuite/util/nrm2/nrm2.h b/gtestsuite/testsuite/util/nrm2/nrm2.h index 9d54d51f65..9693a70aa0 100644 --- a/gtestsuite/testsuite/util/nrm2/nrm2.h +++ b/gtestsuite/testsuite/util/nrm2/nrm2.h @@ -38,17 +38,24 @@ #include "common/testing_helpers.h" /** - * @brief Overload bli_*normfv() functions using typed_nrm2. - * Will be used in testing and especially in TYPED_TESTs. - * Computes the Euclidean norm of x. + * @brief Computes the Euclidean norm of x. + * + * Euclidean norm of a vector x is defined as nrm2 = sqrt(x'*x). + * In case a vector element is NaN, nrm2 must be NaN. + * In case a vector element is inf, and there is no element which is NaN, nrm2 must be inf. + * If n <= 0, nrm2 returns zero. + * If incx = 0, nrm2 returns sqrt(n*abs(x[0])**2). 
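The contract documented above is small enough to state as code. The following reference expresses only the documented semantics (NaN dominates, then Inf, n <= 0 gives zero, incx = 0 collapses to sqrt(n)*|x[0]|); it deliberately ignores the overflow/underflow protection a real implementation needs and is not the BLIS implementation.

#include <cmath>
#include <cstdlib>

// Reference semantics only; no scaling, so large finite inputs would overflow.
double nrm2_contract( int n, const double* x, int incx )
{
    if ( n <= 0 )    return 0.0;
    if ( incx == 0 ) return std::sqrt( (double)n ) * std::fabs( x[0] );
    double sumsq = 0.0;
    // The 2-norm does not depend on traversal order, so only |incx| matters here.
    for ( int k = 0; k < n; ++k )
    {
        double xi = x[ k * std::abs( incx ) ];
        sumsq += xi * xi;   // NaN/Inf propagate through the accumulation
    }
    return std::sqrt( sumsq );
}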
+ * * @param[in] n vector length * @param[in] x pointer which points to the first element of x * @param[in] incx increment of x * @return the Euclidean norm of x + * + * */ -template -static Treal nrm2_(gtint_t n, T* x, gtint_t incx){ +template::real_type> +static RT nrm2_(gtint_t n, T* x, gtint_t incx){ if constexpr (std::is_same::value) return snrm2_( &n, x, &incx ); else if constexpr (std::is_same::value) @@ -61,8 +68,8 @@ static Treal nrm2_(gtint_t n, T* x, gtint_t incx){ throw std::runtime_error("Error in testsuite/level1/nrm2.h: Invalid typename in nrm2_()."); } -template -static Treal cblas_nrm2(gtint_t n, T* x, gtint_t incx){ +template::real_type> +static RT cblas_nrm2(gtint_t n, T* x, gtint_t incx){ if constexpr (std::is_same::value) return cblas_snrm2( n, x, incx ); else if constexpr (std::is_same::value) @@ -75,9 +82,9 @@ static Treal cblas_nrm2(gtint_t n, T* x, gtint_t incx){ throw std::runtime_error("Error in testsuite/level1/nrm2.h: Invalid typename in cblas_nrm2()."); } -template -static Treal typed_nrm2(gtint_t n, T* x, gtint_t incx){ - Treal nrm; +template::real_type> +static RT typed_nrm2(gtint_t n, T* x, gtint_t incx){ + RT nrm; if constexpr (std::is_same::value) bli_snormfv(n, x, incx, &nrm); else if constexpr (std::is_same::value) @@ -91,15 +98,15 @@ static Treal typed_nrm2(gtint_t n, T* x, gtint_t incx){ return nrm; } -template -static Treal nrm2(gtint_t n, T* x, gtint_t incx) +template::real_type> +static RT nrm2(gtint_t n, T* x, gtint_t incx) { #ifdef TEST_BLAS - return nrm2_(n, x, incx); + return nrm2_(n, x, incx); #elif TEST_CBLAS - return cblas_nrm2(n, x, incx); + return cblas_nrm2(n, x, incx); #elif TEST_BLIS_TYPED - return typed_nrm2(n, x, incx); + return typed_nrm2(n, x, incx); #else throw std::runtime_error("Error in testsuite/level1/axpyv.h: No interfaces are set to be tested."); #endif diff --git a/gtestsuite/testsuite/util/nrm2/nrm2_corner_cases.cpp b/gtestsuite/testsuite/util/nrm2/nrm2_corner_cases.cpp new file mode 100644 index 0000000000..c4e09cd83e --- /dev/null +++ b/gtestsuite/testsuite/util/nrm2/nrm2_corner_cases.cpp @@ -0,0 +1,130 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_nrm2.h" + +/** + * Testing edge input parameters. + * + * zero n should return 0. + * zero incx should return sqrt(n*abs(x[0])**2). +*/ + +// Early return. +template +class nrm2_ERS : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(nrm2_ERS, TypeParam); + +TYPED_TEST(nrm2_ERS, zero_n) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t n = 0; + gtint_t incx = 1; + // initialize norm to ensure that it is set to zero from nrm2 and it does not simply return. + RT blis_norm = 19.0; + // using nullptr since x should not be accessed anyway. + // If "x" is accessed before return then nrm2 would segfault. + blis_norm = nrm2(n, nullptr, incx); + RT ref_norm = testinghelpers::ref_nrm2(n, nullptr, incx); + computediff(blis_norm, ref_norm); +} + +// Edge case where it actually does not return early. +// Since there are 2 different paths, vectorized and scalar, +// we break this into 2 tests, once for each case. +template +class nrm2_EIC : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(nrm2_EIC, TypeParam); + +TYPED_TEST(nrm2_EIC, zero_incx_scalar) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t n = 2; + gtint_t incx = 0; + std::vector x(n); + for (auto &xi : x) + testinghelpers::initone(xi); + // For incx=0, nrm2 iterates through the first element n-times. + // So, we initialize x[0] with a different value than the rest + // of the elements. + x[0] = T{2.0}*x[0]; + RT blis_norm = 19.0; + blis_norm = nrm2(n, x.data(), incx); + RT ref_norm = testinghelpers::ref_nrm2(n, x.data(), incx); + computediff(blis_norm, ref_norm); +} + +TYPED_TEST(nrm2_EIC, zero_incx_vectorized) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t n = 64; + gtint_t incx = 0; + std::vector x(n); + for (auto &xi : x) + testinghelpers::initone(xi); + // For incx=0, nrm2 iterates through the first element n-times. + // So, we initialize x[0] with a different value than the rest + // of the elements. + x[0] = T{2.0}*x[0]; + RT blis_norm = 19.0; + blis_norm = nrm2(n, x.data(), incx); + RT ref_norm = testinghelpers::ref_nrm2(n, x.data(), incx); + computediff(blis_norm, ref_norm); +} + +/* + The following test is specific to dnrm2 and dznrm2 apis. + In case of multithreading, each thread will calculate its + norm based on the data it operates on. All these norms will + be reduced post the parallel region. +*/ +TYPED_TEST( nrm2_EIC, zero_incx_MT ) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t n = 2950000; + gtint_t incx = 0; + std::vector x(n); + for (auto &xi : x) + testinghelpers::initone(xi); + // For incx=0, nrm2 iterates through the first element n-times. + // So, we initialize x[0] with a different value than the rest + // of the elements. 
+ x[0] = T{2.0}*x[0]; + RT blis_norm = nrm2(n, x.data(), incx); + RT ref_norm = testinghelpers::ref_nrm2(n, x.data(), incx); + computediff(blis_norm, ref_norm); +} diff --git a/gtestsuite/testsuite/util/nrm2/nrm2_extreme_vals.cpp b/gtestsuite/testsuite/util/nrm2/nrm2_invalid_inputs.cpp similarity index 67% rename from gtestsuite/testsuite/util/nrm2/nrm2_extreme_vals.cpp rename to gtestsuite/testsuite/util/nrm2/nrm2_invalid_inputs.cpp index 5bd2bb46e6..3a702de62b 100644 --- a/gtestsuite/testsuite/util/nrm2/nrm2_extreme_vals.cpp +++ b/gtestsuite/testsuite/util/nrm2/nrm2_invalid_inputs.cpp @@ -34,44 +34,28 @@ #include #include "test_nrm2.h" +#include "common/wrong_inputs_helpers.h" +/** + * Testing invalid/incorrect input parameters. + * + * That is only negative n for this API. Zero incx and zero n is allowed. +*/ template -class xnrm2 : public ::testing::Test {}; -typedef ::testing::Types TypeParam; -TYPED_TEST_SUITE(xnrm2, TypeParam); - -TYPED_TEST(xnrm2, zeroFP) { - using T = TypeParam; - T x = T(0); - - T norm = nrm2(1, &x, 1); - EXPECT_EQ(0, norm); -} - -TYPED_TEST(xnrm2, minFP) { - using T = TypeParam; - T x = std::numeric_limits::min(); - - T norm = nrm2(1, &x, 1); - EXPECT_EQ(x, norm); -} - -TYPED_TEST(xnrm2, maxFP) { - using T = TypeParam; - T x = std::numeric_limits::max(); - - T norm = nrm2(1, &x, 1); - EXPECT_EQ(x, norm); -} - -TEST(dnrm2, largeDouble) { - using T = double; - gtint_t n = 2; - std::vector x{3e300, 4e300}, y{-4e300, -3e300}; - - T norm = nrm2(n, x.data(), 1); - EXPECT_EQ(5e300, norm); - - norm = nrm2(n, y.data(), 1); - EXPECT_EQ(5e300, norm); +class nrm2_IIT : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(nrm2_IIT, TypeParam); + +// Adding namespace to get default parameters from testinghelpers/common/wrong_input_helpers.h. +using namespace testinghelpers::IIT; + +TYPED_TEST(nrm2_IIT, negative_n) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + T x = T{-3.7}; + // initialize blis norm with garbage. + RT blis_norm = -4.2; + blis_norm = nrm2(-2, &x, INC); + + computediff(blis_norm, 0.0); } diff --git a/gtestsuite/testsuite/util/nrm2/nrm2_underflow_overflow.cpp b/gtestsuite/testsuite/util/nrm2/nrm2_underflow_overflow.cpp new file mode 100644 index 0000000000..22e0141292 --- /dev/null +++ b/gtestsuite/testsuite/util/nrm2/nrm2_underflow_overflow.cpp @@ -0,0 +1,177 @@ +#include +#include "test_nrm2.h" + +template +class OUT_nrm2 : public ::testing::Test {}; +typedef ::testing::Types TypeParam; +TYPED_TEST_SUITE(OUT_nrm2, TypeParam); + +// Testing for max representable number to see if overflow is handled correctly. +TYPED_TEST(OUT_nrm2, maxFP_scalar) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + + RT maxval = (std::numeric_limits::max)(); + T x = T{maxval}; + + RT norm = nrm2(1, &x, 1); + computediff(maxval, norm); +} +TYPED_TEST(OUT_nrm2, maxFP_vectorized) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t n = 64; + std::vector x(n, T{0}); + RT maxval = (std::numeric_limits::max)(); + x[17] = T{maxval}; + RT norm = nrm2(n, x.data(), 1); + computediff(maxval, norm); +} + +// Testing for min representable number to see if underflow is handled correctly. 
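The max/min tests above and below only pass because the implementation avoids squaring representable extremes directly: in the largeDouble case further on, (3e300)^2 = 9e600 already exceeds DBL_MAX even though the exact result 5e300 is representable. The classic remedy is a scaled sum of squares in the spirit of LAPACK's xLASSQ; the sketch below illustrates that idea only and is not the BLIS kernel.

#include <cmath>

// Keep the running sum as scale^2 * ssq so no element is ever squared at full
// magnitude. Assumes incx > 0 for brevity.
double scaled_nrm2( int n, const double* x, int incx )
{
    double scale = 0.0, ssq = 1.0;
    for ( int k = 0; k < n; ++k )
    {
        double axi = std::fabs( x[ k * incx ] );
        if ( axi == 0.0 ) continue;
        if ( scale < axi )
        {
            double r = scale / axi;
            ssq   = 1.0 + ssq * r * r;
            scale = axi;
        }
        else
        {
            double r = axi / scale;
            ssq += r * r;
        }
    }
    return scale * std::sqrt( ssq );
}

For x = {DBL_MAX} this returns DBL_MAX exactly, and for x = {3e300, 4e300} it returns 4e300 * sqrt(1.5625) = 5e300, the two behaviours the surrounding tests check.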
+TYPED_TEST(OUT_nrm2, minFP_scalar) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + + RT minval = (std::numeric_limits::min)(); + T x = T{minval}; + RT norm = nrm2(1, &x, 1); + computediff(minval, norm); +} +TYPED_TEST(OUT_nrm2, minFP_vectorized) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t n = 64; + std::vector x(n, T{0}); + RT minval = (std::numeric_limits::min)(); + x[17] = T{minval}; + RT norm = nrm2(n, x.data(), 1); + computediff(minval, norm); +} + +// Since there are 2 different paths, vectorized and scalar, +// we break this into 2 tests, once for each case. +TYPED_TEST(OUT_nrm2, zeroFP_scalar) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + T x = T{0}; + + RT norm = nrm2(1, &x, 1); + computediff(0, norm); +} +TYPED_TEST(OUT_nrm2, zeroFP_vectorized) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t n = 64; + std::vector x(n, T{0}); + + RT norm = nrm2(n, x.data(), 1); + computediff(0, norm); +} + +/* + Adding a type-parameterized test to check for + overflow and underflow handling with multiple threads + in case of dnrm2 and dznrm2. Can also be used if snrm2 + and scnrm2 are multithreaded. +*/ + +// Checking only for overflow, based on the threshold +TYPED_TEST( OUT_nrm2, OFlow_MT ) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t n = 2950000; + std::vector x(n, T{1.0}); // A normal value + RT bigval; + if constexpr ( std::is_same::value ) + { + bigval = powf( ( float )FLT_RADIX, floorf( ( FLT_MAX_EXP - 23) * 0.5f ) ) * ( 1.0f + FLT_EPSILON ); + } + else + { + bigval = pow( ( double )FLT_RADIX, floor( ( DBL_MAX_EXP - 52) * 0.5 ) ) * ( 1.0 + DBL_EPSILON ); + } + + // Set the threshold for the errors: + double thresh = 2*testinghelpers::getEpsilon(); + x[1000] = T{ bigval }; + x[50000] = T{ bigval }; + x[151001] = T{ bigval }; + x[2949999] = T{ bigval }; + + RT norm = nrm2( n, x.data(), 1 ); + RT ref_norm = testinghelpers::ref_nrm2( n, x.data(), 1 ); + computediff( norm, ref_norm, thresh ); +} + +// Checking only for underflow, based on the threshold +TYPED_TEST( OUT_nrm2, UFlow_MT ) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t n = 2950000; + std::vector x(n, T{1.0}); // A normal value + RT smlval; + if constexpr ( std::is_same::value ) + { + smlval = powf( ( float )FLT_RADIX, ceilf( ( FLT_MIN_EXP - 1 ) * 0.5f ) ) * ( 1.0f - FLT_EPSILON ); + } + else + { + smlval = pow( ( double )FLT_RADIX, ceil( ( DBL_MIN_EXP - 1 ) * 0.5 ) ) * ( 1.0 - DBL_EPSILON ); + } + + // Set the threshold for the errors: + double thresh = 2*testinghelpers::getEpsilon(); + x[1000] = T{ smlval }; + x[50000] = T{ smlval }; + x[151001] = T{ smlval }; + x[2949999] = T{ smlval }; + + RT norm = nrm2( n, x.data(), 1 ); + RT ref_norm = testinghelpers::ref_nrm2( n, x.data(), 1 ); + computediff( norm, ref_norm, thresh ); +} + +// Checking for both overflow and underflow, based on the thresholds +TYPED_TEST( OUT_nrm2, OUFlow_MT ) { + using T = TypeParam; + using RT = typename testinghelpers::type_info::real_type; + gtint_t n = 2950000; + std::vector x(n, T{1.0}); // A normal value + RT bigval, smlval; + if constexpr ( std::is_same::value ) + { + bigval = powf( ( float )FLT_RADIX, floorf( ( FLT_MAX_EXP - 23) * 0.5f ) ) * ( 1.0f + FLT_EPSILON ); + smlval = powf( ( float )FLT_RADIX, ceilf( ( FLT_MIN_EXP - 1 ) * 0.5f ) ) * ( 1.0f - FLT_EPSILON ); + } + else + { + 
bigval = pow( ( double )FLT_RADIX, floor( ( DBL_MAX_EXP - 52) * 0.5 ) ) * ( 1.0 + DBL_EPSILON ); + smlval = pow( ( double )FLT_RADIX, ceil( ( DBL_MIN_EXP - 1 ) * 0.5 ) ) * ( 1.0 - DBL_EPSILON ); + } + + // Set the threshold for the errors: + double thresh = 2*testinghelpers::getEpsilon(); + x[1000] = T{ smlval }; + x[50000] = T{ bigval }; + x[151001] = T{ bigval }; + x[2949999] = T{ smlval }; + + RT norm = nrm2( n, x.data(), 1 ); + RT ref_norm = testinghelpers::ref_nrm2( n, x.data(), 1 ); + computediff( norm, ref_norm, thresh ); +} + +// Specific test case used by an ISV. +// Checks for overflow. +TEST(dnrm2, largeDouble) { + using T = double; + gtint_t n = 2; + std::vector x{3e300, 4e300}, y{-4e300, -3e300}; + + T norm = nrm2(n, x.data(), 1); + computediff(5e300, norm); + + norm = nrm2(n, y.data(), 1); + computediff(5e300, norm); +} diff --git a/gtestsuite/testsuite/util/nrm2/scnrm2_extreme_values.cpp b/gtestsuite/testsuite/util/nrm2/scnrm2_extreme_values.cpp new file mode 100644 index 0000000000..52ba4f8647 --- /dev/null +++ b/gtestsuite/testsuite/util/nrm2/scnrm2_extreme_values.cpp @@ -0,0 +1,211 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_nrm2.h" + +class scnrm2_EVT : + public ::testing::TestWithParam>{}; + +TEST_P( scnrm2_EVT, EVT ) +{ + using T = scomplex; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length: + gtint_t n = std::get<0>(GetParam()); + // stride size for x: + gtint_t incx = std::get<1>(GetParam()); + // index with extreme value iexval. + gtint_t i = std::get<2>(GetParam()); + T iexval = std::get<3>(GetParam()); + // index with extreme value jexval. 
+ gtint_t j = std::get<4>(GetParam()); + T jexval = std::get<5>(GetParam()); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_nrm2(n, incx, i, iexval, j, jexval); +} + +// Prints the test case combination +class scnrm2_TestPrint{ +public: + std::string operator()( + testing::TestParamInfo> str) const { + // vector length: + gtint_t n = std::get<0>(str.param); + // stride size for x: + gtint_t incx = std::get<1>(str.param); + // index with extreme value iexval. + gtint_t i = std::get<2>(str.param); + scomplex iexval = std::get<3>(str.param); + // index with extreme value jexval. + gtint_t j = std::get<4>(str.param); + scomplex jexval = std::get<5>(str.param); +#ifdef TEST_BLAS + std::string str_name = "scnrm2_"; +#elif TEST_CBLAS + std::string str_name = "cblas_scnrm2"; +#else //#elif TEST_BLIS_TYPED + std::string str_name = "bli_cnormfv"; +#endif + str_name = str_name + "_" + std::to_string(n); + std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); + str_name = str_name + "_" + incx_str; + str_name = str_name + "_i" + std::to_string(i); + std::string iexval_str = "_Re_" + testinghelpers::get_value_string(iexval.real) + "_Im_" + testinghelpers::get_value_string(iexval.imag); + str_name = str_name + iexval_str; + str_name = str_name + "_j" + std::to_string(j); + std::string jexval_str = "_Re_" + testinghelpers::get_value_string(jexval.real) + "_Im_" + testinghelpers::get_value_string(jexval.imag); + str_name = str_name + jexval_str; + return str_name; + } +}; + +static float NaN = std::numeric_limits::quiet_NaN(); +static float Inf = std::numeric_limits::infinity(); +/** + * scnrm2 implementation is composed by two parts: + * - vectorized path for n>=64 + * - for-loop for multiples of 16 (F16) + * - for-loop for multiples of 12 (F12) + * - for-loop for multiples of 8 (F8) + * - scalar path for n<64 (S) +*/ + +// Test for scalar path. +// Testing for jexval=(1.0, 2.0), means that we test only one NaN/Inf value. +// for jexval also being an extreme value, we test all combinations +// of having first a NaN and then an Inf and so on. +INSTANTIATE_TEST_SUITE_P( + scalar, + scnrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(2)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(0), + // iexval + ::testing::Values(scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}, scomplex{NaN, Inf}, scomplex{Inf, NaN}), + ::testing::Values(1), + ::testing::Values(scomplex{1.0, 2.0}, scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}) + ), + ::scnrm2_TestPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + vector_F16, + scnrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(64)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(10), + // iexval + ::testing::Values(scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}, scomplex{NaN, Inf}, scomplex{Inf, NaN}), + ::testing::Values(30), + ::testing::Values(scomplex{1.0, 2.0}, scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}) + ), + ::scnrm2_TestPrint() + ); + +// To test the second for-loop (F12), we use n = 76 = 4*16+12 +// and ensure that the extreme values are on or after index 64. 
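To see how the sizes used in this file map onto the loop structure described above, here is a small helper that mirrors that description. It assumes the kernel peels the largest blocks first and prefers the 12-wide pass over the 8-wide one; the real dispatch may differ, so treat this as an illustration of the comments rather than of the kernel.

// Mirrors the commented path description: F16 passes, then one optional F12 or
// F8 pass, then a scalar tail; n < 64 stays on the scalar path.
struct scnrm2_path_plan { int f16; int f12; int f8; int scalar_tail; };

scnrm2_path_plan plan_scnrm2_paths( int n )
{
    if ( n < 64 ) return { 0, 0, 0, n };
    scnrm2_path_plan p{ n / 16, 0, 0, n % 16 };
    if      ( p.scalar_tail >= 12 ) { p.f12 = 1; p.scalar_tail -= 12; }
    else if ( p.scalar_tail >= 8 )  { p.f8  = 1; p.scalar_tail -= 8;  }
    return p;
}

For the sizes in this file, plan_scnrm2_paths(76) gives four F16 passes plus one F12 pass, and plan_scnrm2_paths(72) gives four F16 passes plus one F8 pass, matching the comments on the instantiations.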
+INSTANTIATE_TEST_SUITE_P( + vector_F12, + scnrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(76)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(68), + // iexval + ::testing::Values(scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}, scomplex{NaN, Inf}, scomplex{Inf, NaN}), + ::testing::Values(70), + ::testing::Values(scomplex{1.0, 2.0}, scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}) + ), + ::scnrm2_TestPrint() + ); + +// To test the second for-loop (F8), we use n = 72 = 4*16+8 +// and ensure that the extreme values are on or after index 64. +INSTANTIATE_TEST_SUITE_P( + vector_F8, + scnrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(72)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(66), + // iexval + ::testing::Values(scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}, scomplex{NaN, Inf}, scomplex{Inf, NaN}), + ::testing::Values(70), + ::testing::Values(scomplex{1.0, 2.0}, scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}) + ), + ::scnrm2_TestPrint() + ); + +// Now let's check the combination of a vectorized path and +// the scalar path, by putting an extreme value in each +// to check that the checks are integrated correctly. +INSTANTIATE_TEST_SUITE_P( + vector_scalar, + scnrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(79)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(25), + // iexval + ::testing::Values(scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}, scomplex{NaN, Inf}, scomplex{Inf, NaN}), + ::testing::Values(68), + ::testing::Values(scomplex{NaN, 1.0}, scomplex{Inf, 9.0}, scomplex{-1.0, -Inf}, scomplex{2.0, NaN}) + ), + ::scnrm2_TestPrint() + ); + diff --git a/gtestsuite/testsuite/util/nrm2/cnrm2_generic.cpp b/gtestsuite/testsuite/util/nrm2/scnrm2_generic.cpp similarity index 68% rename from gtestsuite/testsuite/util/nrm2/cnrm2_generic.cpp rename to gtestsuite/testsuite/util/nrm2/scnrm2_generic.cpp index a020075f2c..d27f5c50b5 100644 --- a/gtestsuite/testsuite/util/nrm2/cnrm2_generic.cpp +++ b/gtestsuite/testsuite/util/nrm2/scnrm2_generic.cpp @@ -35,10 +35,10 @@ #include #include "test_nrm2.h" -class CNrm2Test : - public ::testing::TestWithParam> {}; +class scnrm2Test : + public ::testing::TestWithParam> {}; -TEST_P( CNrm2Test, RandomData ) +TEST_P( scnrm2Test, RandomData ) { using T = scomplex; //---------------------------------------------------------- @@ -49,8 +49,6 @@ TEST_P( CNrm2Test, RandomData ) gtint_t n = std::get<0>(GetParam()); // stride size for x: gtint_t incx = std::get<1>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<2>(GetParam()); // Set the threshold for the errors: double thresh = std::sqrt(n)*testinghelpers::getEpsilon(); @@ -58,17 +56,16 @@ TEST_P( CNrm2Test, RandomData ) //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_nrm2(n, incx, thresh, datatype); + test_nrm2(n, incx, thresh); } // Prints the test case combination -class CNrm2TestPrint { +class scnrm2TestPrint { public: std::string operator()( - testing::TestParamInfo> str) 
const { + testing::TestParamInfo> str) const { gtint_t n = std::get<0>(str.param); gtint_t incx = std::get<1>(str.param); - char datatype = std::get<2>(str.param); #ifdef TEST_BLAS std::string str_name = "scnrm2_"; #elif TEST_CBLAS @@ -79,23 +76,41 @@ class CNrm2TestPrint { str_name = str_name + "_" + std::to_string(n); std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + datatype; return str_name; } }; -// Black box testing. +/** + * scnrm2 implementation is composed by two parts: + * - vectorized path for n>=64 + * - for-loop for multiples of 16 (F16) + * - for-loop for multiples of 12 (F12) + * - for-loop for multiples of 8 (F8) + * - scalar path for n<64 (S) +*/ INSTANTIATE_TEST_SUITE_P( - Blackbox, - CNrm2Test, + AT, + scnrm2Test, ::testing::Combine( - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1), gtint_t(2) + // m size of vector + ::testing::Values(gtint_t(1), // trivial case n=1 + gtint_t(35), // will only go through S + gtint_t(64), // 4*16 - will only go through F16 + gtint_t(67), // 4*16 + 3 - will go through F16 & S + gtint_t(72), // 4*16 + 8 - will go through F16 & F8 + gtint_t(75), // 4*16 + 8 + 3 - will go through F16 & F8 & S + gtint_t(76), // 4*16 + 12 - will go through F16 & F12 + gtint_t(78), // 4*16 + 12 + 2 - will go through F16 & F12 & S + gtint_t(112), // a few bigger numbers + gtint_t(187), + gtint_t(213) + ), + // stride size for x + ::testing::Values(gtint_t(1), gtint_t(3) #ifndef TEST_BLIS_TYPED - , gtint_t(-1), gtint_t(-2) + , gtint_t(-1), gtint_t(-7) #endif - ), // stride size for x - ::testing::Values('i') // i : integer, f : float datatype type tested + ) ), - ::CNrm2TestPrint() + ::scnrm2TestPrint() ); diff --git a/gtestsuite/testsuite/util/nrm2/snrm2_extreme_values.cpp b/gtestsuite/testsuite/util/nrm2/snrm2_extreme_values.cpp new file mode 100644 index 0000000000..5bfa83a346 --- /dev/null +++ b/gtestsuite/testsuite/util/nrm2/snrm2_extreme_values.cpp @@ -0,0 +1,215 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "test_nrm2.h" + +class snrm2_EVT : + public ::testing::TestWithParam> {}; + +TEST_P( snrm2_EVT, EVT ) +{ + using T = float; + //---------------------------------------------------------- + // Initialize values from the parameters passed through + // test suite instantiation (INSTANTIATE_TEST_SUITE_P). + //---------------------------------------------------------- + // vector length: + gtint_t n = std::get<0>(GetParam()); + // stride size for x: + gtint_t incx = std::get<1>(GetParam()); + // index with extreme value iexval. + gtint_t i = std::get<2>(GetParam()); + T iexval = std::get<3>(GetParam()); + // index with extreme value jexval. + gtint_t j = std::get<4>(GetParam()); + T jexval = std::get<5>(GetParam()); + + //---------------------------------------------------------- + // Call test body using these parameters + //---------------------------------------------------------- + test_nrm2(n, incx, i, iexval, j, jexval); +} + +// Prints the test case combination +class snrm2_TestPrint { +public: + std::string operator()( + testing::TestParamInfo> str) const { + // vector length: + gtint_t n = std::get<0>(str.param); + // stride size for x: + gtint_t incx = std::get<1>(str.param); + // index with extreme value iexval. + gtint_t i = std::get<2>(str.param); + float iexval = std::get<3>(str.param); + // index with extreme value jexval. + gtint_t j = std::get<4>(str.param); + float jexval = std::get<5>(str.param); +#ifdef TEST_BLAS + std::string str_name = "snrm2_"; +#elif TEST_CBLAS + std::string str_name = "cblas_snrm2"; +#else //#elif TEST_BLIS_TYPED + std::string str_name = "bli_snormfv"; +#endif + str_name = str_name + "_" + std::to_string(n); + std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); + str_name = str_name + "_" + incx_str; + str_name = str_name + "_i" + std::to_string(i); + std::string iexval_str = testinghelpers::get_value_string(iexval); + str_name = str_name + "_" + iexval_str; + str_name = str_name + "_j" + std::to_string(j); + std::string jexval_str = testinghelpers::get_value_string(jexval); + str_name = str_name + "_" + jexval_str; + return str_name; + } +}; + +static float NaN = std::numeric_limits::quiet_NaN(); +static float Inf = std::numeric_limits::infinity(); + +/** + * Note: snrm2 scalar ONLY implementation is used, but we write the test + * using values that worked for the vectorized path for the future. + * + * scnrm2 implementation is composed by two parts: + * - vectorized path for n>=64 + * - for-loop for multiples of 32 (F32) + * - for-loop for multiples of 24 (F24) + * - for-loop for multiples of 16 (F16) + * - scalar path for n<64 (S) +*/ + +// Test for scalar path. +// Testing for jexval=1.0, means that we test only one NaN/Inf value. +// for jexval also being an extreme value, we test all combinations +// of having first a NaN and then an Inf and so on. 
+INSTANTIATE_TEST_SUITE_P( + scalar, + snrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(3)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(0), + // iexval + ::testing::Values(NaN, Inf, -Inf), + ::testing::Values(2), + ::testing::Values(1.0, NaN, Inf, -Inf) + ), + ::snrm2_TestPrint() + ); + +INSTANTIATE_TEST_SUITE_P( + vector_F32, + snrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(64)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(13), + // iexval + ::testing::Values(NaN, Inf, -Inf), + ::testing::Values(26), + ::testing::Values(1.0, NaN, Inf, -Inf) + ), + ::snrm2_TestPrint() + ); + +// To test the second for-loop (F24), we use n = 88 = 2*32+24 +// and ensure that the extreme values are on or after index 64. +INSTANTIATE_TEST_SUITE_P( + vector_F24, + snrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(88)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(70), + // iexval + ::testing::Values(NaN, Inf, -Inf), + ::testing::Values(80), + ::testing::Values(1.0, NaN, Inf, -Inf) + ), + ::snrm2_TestPrint() + ); + +// To test the second for-loop (F16), we use n = 80 = 2*32+16 +// and ensure that the extreme values are on or after index 64. +INSTANTIATE_TEST_SUITE_P( + vector_F16, + snrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(80)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(70), + // iexval + ::testing::Values(NaN, Inf, -Inf), + ::testing::Values(75), + ::testing::Values(1.0, NaN, Inf, -Inf) + ), + ::snrm2_TestPrint() + ); + +// Now let's check the combination of a vectorized path and +// the scalar path, by putting an extreme value in each +// to check that the checks are integrated correctly. 
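Because every group in this file is registered through INSTANTIATE_TEST_SUITE_P, each one can be run in isolation with GoogleTest's standard name filter; full test names take the form instantiation/snrm2_EVT.EVT/printed-parameters. The binary name below is illustrative, only the flags are standard GoogleTest.

// List everything this file registered, then run just one instantiation:
//   ./gtestsuite --gtest_list_tests --gtest_filter='*snrm2_EVT*'
//   ./gtestsuite --gtest_filter='vector_scalar/snrm2_EVT.*'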
+INSTANTIATE_TEST_SUITE_P( + vector_scalar, + snrm2_EVT, + ::testing::Combine( + // m size of vector + ::testing::Values(gtint_t(68)), + // stride size for x + ::testing::Values(gtint_t(1)), + // i : index of x that has value iexval + ::testing::Values(5), + // iexval + ::testing::Values(NaN, Inf, -Inf), + ::testing::Values(65), + ::testing::Values(NaN, Inf, -Inf) + ), + ::snrm2_TestPrint() + ); + diff --git a/gtestsuite/testsuite/util/nrm2/snrm2_generic.cpp b/gtestsuite/testsuite/util/nrm2/snrm2_generic.cpp index e23bc0d90c..eac411d12d 100644 --- a/gtestsuite/testsuite/util/nrm2/snrm2_generic.cpp +++ b/gtestsuite/testsuite/util/nrm2/snrm2_generic.cpp @@ -36,7 +36,7 @@ #include "test_nrm2.h" class snrm2Test : - public ::testing::TestWithParam> {}; + public ::testing::TestWithParam> {}; TEST_P( snrm2Test, RandomData ) { @@ -49,8 +49,6 @@ TEST_P( snrm2Test, RandomData ) gtint_t n = std::get<0>(GetParam()); // stride size for x: gtint_t incx = std::get<1>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<2>(GetParam()); // Set the threshold for the errors: double thresh = 2*n*testinghelpers::getEpsilon(); @@ -58,17 +56,16 @@ TEST_P( snrm2Test, RandomData ) //---------------------------------------------------------- // Call test body using these parameters //---------------------------------------------------------- - test_nrm2(n, incx, thresh, datatype); + test_nrm2( n, incx, thresh ); } // Prints the test case combination class snrm2TestPrint { public: std::string operator()( - testing::TestParamInfo> str) const { + testing::TestParamInfo> str) const { gtint_t n = std::get<0>(str.param); gtint_t incx = std::get<1>(str.param); - char datatype = std::get<2>(str.param); #ifdef TEST_BLAS std::string str_name = "snrm2_"; #elif TEST_CBLAS @@ -79,23 +76,44 @@ class snrm2TestPrint { str_name = str_name + "_" + std::to_string(n); std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + datatype; return str_name; } }; -// Black box testing. +/** + * Note: snrm2 scalar ONLY implementation is used, but we write the test + * using values that worked for the vectorized path for the future. + * + * scnrm2 implementation is composed by two parts: + * - vectorized path for n>=64 + * - for-loop for multiples of 32 (F32) + * - for-loop for multiples of 24 (F24) + * - for-loop for multiples of 16 (F16) + * - scalar path for n<64 (S) +*/ INSTANTIATE_TEST_SUITE_P( - Blackbox, + AT, snrm2Test, ::testing::Combine( - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. 
- ::testing::Values(gtint_t(1), gtint_t(2) + // m size of vector + ::testing::Values(gtint_t(1), // trivial case n=1 + gtint_t(35), // will only go through S + gtint_t(64), // 2*32 - will only go through F32 + gtint_t(76), // 2*32 + 12 - will go through F32 & S + gtint_t(80), // 2*32 + 16 - will go through F32 & F16 + gtint_t(85), // 2*32 + 16 + 5 - will go through F32 & F16 & S + gtint_t(88), // 2*32 + 24 - will go through F32 & F24 + gtint_t(91), // 2*32 + 24 + 3 - will go through F32 & F24 & S + gtint_t(124), // a few bigger numbers + gtint_t(167), + gtint_t(259) + ), + // stride size for x + ::testing::Values(gtint_t(1), gtint_t(3) #ifndef TEST_BLIS_TYPED - ,gtint_t(-1), gtint_t(-2) + , gtint_t(-1), gtint_t(-5) #endif - ), // stride size for x - ::testing::Values('i') // i : integer, f : float datatype type tested + ) // stride size for x ), ::snrm2TestPrint() ); diff --git a/gtestsuite/testsuite/util/nrm2/test_nrm2.h b/gtestsuite/testsuite/util/nrm2/test_nrm2.h index 2c9de86dc4..def4551929 100644 --- a/gtestsuite/testsuite/util/nrm2/test_nrm2.h +++ b/gtestsuite/testsuite/util/nrm2/test_nrm2.h @@ -35,33 +35,67 @@ #pragma once #include "nrm2.h" +#include #include "util/ref_nrm2.h" #include "inc/check_error.h" +// Used for generic tests with random values in x. template<typename T> -void test_nrm2( gtint_t n, gtint_t incx, double thresh, char datatype ) +void test_nrm2( gtint_t n, gtint_t incx, double thresh ) { + // Get real type from T. + using RT = typename testinghelpers::type_info<T>::real_type; //---------------------------------------------------------- // Initialize vectors with random numbers. //---------------------------------------------------------- - std::vector<T> x( testinghelpers::buff_dim(n, incx) ); - testinghelpers::datagenerators::randomgenerators( -10, 10, n, incx, x.data(), datatype ); - + std::vector<T> x = testinghelpers::get_random_vector<T>( -10, 10, n, incx ); + //---------------------------------------------------------- // Call reference implementation to get ref results. //---------------------------------------------------------- - // Create a copy of y so that we can check reference results. - using real = typename testinghelpers::type_info<T>::real_type; - real norm_ref = testinghelpers::ref_nrm2( n, x.data(), incx ); + RT norm_ref = testinghelpers::ref_nrm2( n, x.data(), incx ); //---------------------------------------------------------- // Call BLIS function. //---------------------------------------------------------- - real norm = nrm2(n, x.data(), incx); + RT norm = nrm2(n, x.data(), incx); //---------------------------------------------------------- // Compute error. //---------------------------------------------------------- - computediff( norm, norm_ref, thresh ); + computediff( norm, norm_ref, thresh ); } +// Test body used for extreme value testing, where we want to test +// cases where two extreme values are present. +// i is the index with corresponding extreme value iexval. +// j is the index with corresponding extreme value jexval. +template<typename T> +void test_nrm2( gtint_t n, gtint_t incx, gtint_t i, T iexval, gtint_t j = 0, T jexval = T{1.0}) +{ + // Get real type from T. + using RT = typename testinghelpers::type_info<T>::real_type; + //---------------------------------------------------------- + // Initialize vectors with random numbers. + //---------------------------------------------------------- + std::vector<T> x = testinghelpers::get_random_vector<T>(-10, 10, n, incx); + // Initialize ith element of vector x to iexval.
+ x[i*std::abs(incx)] = iexval; + // Initialize jth element of vector x to jexval. + x[j*std::abs(incx)] = jexval; + //---------------------------------------------------------- + // Call reference implementation to get ref results. + //---------------------------------------------------------- + RT norm_ref = testinghelpers::ref_nrm2( n, x.data(), incx ); + + //---------------------------------------------------------- + // Call BLIS function. + //---------------------------------------------------------- + RT norm = nrm2(n, x.data(), incx); + + //---------------------------------------------------------- + // Compute error. + //---------------------------------------------------------- + // Compare using NaN/Inf checks. + computediff( norm, norm_ref, true ); +} diff --git a/gtestsuite/testsuite/util/nrm2/znrm2_generic.cpp b/gtestsuite/testsuite/util/nrm2/znrm2_generic.cpp deleted file mode 100644 index 55c1b9be07..0000000000 --- a/gtestsuite/testsuite/util/nrm2/znrm2_generic.cpp +++ /dev/null @@ -1,101 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include "test_nrm2.h" - -class znrm2Test : - public ::testing::TestWithParam> {}; - -TEST_P( znrm2Test, RandomData ) -{ - using T = dcomplex; - //---------------------------------------------------------- - // Initialize values from the parameters passed through - // test suite instantiation (INSTANTIATE_TEST_SUITE_P). 
- //---------------------------------------------------------- - // vector length: - gtint_t n = std::get<0>(GetParam()); - // stride size for x: - gtint_t incx = std::get<1>(GetParam()); - // specifies the datatype for randomgenerators - char datatype = std::get<2>(GetParam()); - - // Set the threshold for the errors: - double thresh = testinghelpers::getEpsilon(); - - //---------------------------------------------------------- - // Call test body using these parameters - //---------------------------------------------------------- - test_nrm2(n, incx, thresh, datatype); -} - -// Prints the test case combination -class znrm2TestPrint { -public: - std::string operator()( - testing::TestParamInfo> str) const { - gtint_t n = std::get<0>(str.param); - gtint_t incx = std::get<1>(str.param); - char datatype = std::get<2>(str.param); -#ifdef TEST_BLAS - std::string str_name = "dznrm2_"; -#elif TEST_CBLAS - std::string str_name = "cblas_dznrm2"; -#else //#elif TEST_BLIS_TYPED - std::string str_name = "bli_znormfv"; -#endif - str_name = str_name + "_" + std::to_string(n); - std::string incx_str = ( incx > 0) ? std::to_string(incx) : "m" + std::to_string(std::abs(incx)); - str_name = str_name + "_" + incx_str; - str_name = str_name + "_" + datatype; - return str_name; - } -}; - -// Black box testing. -INSTANTIATE_TEST_SUITE_P( - Blackbox, - znrm2Test, - ::testing::Combine( - ::testing::Range(gtint_t(10), gtint_t(101), 10), // m size of vector takes values from 10 to 100 with step size of 10. - ::testing::Values(gtint_t(1), gtint_t(2) -#ifndef TEST_BLIS_TYPED - ,gtint_t(-1), gtint_t(-2) -#endif - ), // stride size for x - ::testing::Values('i') // i : integer, f : float datatype type tested - ), - ::znrm2TestPrint() - ); diff --git a/kernels/CMakeLists.txt b/kernels/CMakeLists.txt index 47501d920c..fa15654125 100644 --- a/kernels/CMakeLists.txt +++ b/kernels/CMakeLists.txt @@ -1,10 +1,79 @@ -##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.## -add_subdirectory(haswell) -add_subdirectory(zen) +# Writing a function that will be used to generate the required object +# libraries for the required kernels. +function(generate_kernel_targets kernel_target) + # Collect all subdirectory paths that have at least one file with suffix in KERNELS_SRC_SUFS list. + get_filepaths_with_suffixes(LOCAL_SOURCE_FILES "${CMAKE_CURRENT_SOURCE_DIR}/${kernel_target}" "${KERNELS_SRC_SUFS}") -if(${TARGET_ARCH} STREQUAL zen4 OR - ${TARGET_ARCH} STREQUAL amdzen) - add_subdirectory(skx) - add_subdirectory(zen4) -endif() + # Choose correct sub-configurarion name for the given kernel set. + get_config_for_kernel_from_kconfig_map(LOCAL_CONFIG ${kernel_target} "${KCONFIG_MAP}") + + # Only generate the object library if there is at least one source file. + list(LENGTH LOCAL_SOURCE_FILES size) + if(size GREATER 0) + # Create an object library using the source file list above. + add_library(${kernel_target}_KERNELS + OBJECT + ${LOCAL_SOURCE_FILES} + ) + # Include the corresponding make_defs.cmake that holds the required compiler options. + include(${CMAKE_SOURCE_DIR}/config/${LOCAL_CONFIG}/make_defs.cmake) + # Use PRIVATE keyword for option setting since we do not want the properties to propagate in other targets. 
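+        # For example, a kernel set that KCONFIG_MAP pairs with the zen4 sub-configuration is compiled with the CKOPTFLAGS/CKVECFLAGS that config/zen4/make_defs.cmake (included above) defines; the pairing here is only illustrative, the real mapping comes from KCONFIG_MAP.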
+ # mimicing get-kernel-cflags-for + target_compile_options(${kernel_target}_KERNELS + PRIVATE + # load-var-for,CKOPTFLAGS + ${CKOPTFLAGS} + # load-var-for,CKVECFLAGS + ${CKVECFLAGS} + # get-noopt-cflags-for + ${CDBGFLAGS} + # get-noopt-cflags-for + ${CWARNFLAGS} + # get-noopt-cflags-for + ${CMISCFLAGS} + # get-noopt-cflags-for + ${CLANGFLAGS} + # in get-kernel-cflags-for + ${COMPSIMDFLAGS} + # in get-kernel-cflags-for + ${BUILD_SYMFLAGS} + ) + target_compile_definitions(${kernel_target}_KERNELS + PRIVATE + # in get-noopt-cflags-for + ${CPPROCFLAGS} + # in get-noopt-cflags-for + ${VERS_DEF} + # in get-kernel-cflags-for + ${BUILD_CPPFLAGS} + ) + target_include_directories(${kernel_target}_KERNELS + BEFORE + PRIVATE + # in get-noopt-cflags-for + ${CINFLAGS} + ) + if(THREADING_MODEL STREQUAL "openmp") + # Equivalent to CTHREADFLAGS in get-noopt-cflags-for + target_link_libraries(${kernel_target}_KERNELS PRIVATE OpenMP::OpenMP_C) + elseif(THREADING_MODEL STREQUAL "pthreads") + # in get-noopt-cflags-for + target_compile_options(${kernel_target}_KERNELS PRIVATE ${CTHREADFLAGS}) + endif() + if(BUILD_SHARED_LIBS) + # Equivalent to CPICFLAGS in get-noopt-cflags-for + set_target_properties(${kernel_target}_KERNELS PROPERTIES POSITION_INDEPENDENT_CODE ON) + endif() + add_dependencies(${kernel_target}_KERNELS flat-header) + # Put all those targets under object-libs-targets folder name so that they appear all together in IDE. + set_target_properties(${kernel_target}_KERNELS PROPERTIES FOLDER object-libs-targets) + endif() +endfunction() + +# Generate targets for each of the kernels present +# in the kernel list. +foreach(KERN ${KERNEL_LIST}) + generate_kernel_targets(${KERN}) +endforeach() diff --git a/kernels/armsve/1m/armsve512_asm_transpose_d8x2.h b/kernels/armsve/1m/armsve512_asm_transpose_d8x2.h new file mode 100644 index 0000000000..31dd5704ab --- /dev/null +++ b/kernels/armsve/1m/armsve512_asm_transpose_d8x2.h @@ -0,0 +1,45 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#define SVE512_IN_REG_TRANSPOSE_d8x2(DST0,DST1,DST2,DST3,DST4,DST5,DST6SRC0,DST7SRC1,PT,P2C,P4C,P6C) \ + "trn1 " #DST0".d, " #DST6SRC0".d, " #DST7SRC1".d \n\t" \ + "trn2 " #DST1".d, " #DST6SRC0".d, " #DST7SRC1".d \n\t" \ + "compact " #DST2".d, " #P2C", " #DST0".d \n\t" \ + "compact " #DST3".d, " #P2C", " #DST1".d \n\t" \ + "compact " #DST4".d, " #P4C", " #DST0".d \n\t" \ + "compact " #DST5".d, " #P4C", " #DST1".d \n\t" \ + "compact " #DST6SRC0".d, " #P6C", " #DST0".d \n\t" \ + "compact " #DST7SRC1".d, " #P6C", " #DST1".d \n\t" + diff --git a/kernels/armsve/1m/armsve512_asm_transpose_d8x8.h b/kernels/armsve/1m/armsve512_asm_transpose_d8x8.h new file mode 100644 index 0000000000..98426c9476 --- /dev/null +++ b/kernels/armsve/1m/armsve512_asm_transpose_d8x8.h @@ -0,0 +1,97 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#define SVE512_IN_REG_TRANSPOSE_d8x8_PREPARE(XTMP,PT,P2C,P4C,P6C,PTFTF,P4,P6) \ + "ptrue " #PT".d \n\t" \ + "mov " #XTMP", #2 \n\t" \ + "whilelo " #P2C".d, xzr, " #XTMP" \n\t" \ + "mov " #XTMP", #4 \n\t" \ + "whilelo " #P4".d, xzr, " #XTMP" \n\t" \ + "mov " #XTMP", #6 \n\t" \ + "whilelo " #P6".d, xzr, " #XTMP" \n\t" \ + \ + "eor " #PTFTF".b, " #PT"/z, " #P6".b, " #P4".b \n\t" /***** o o | o */ \ + "orr " #PTFTF".b, " #PT"/z, " #PTFTF".b, " #P2C".b \n\t" /* | o | o */ \ + \ + "not " #P2C".b, " #PT"/z, " #P2C".b \n\t" \ + "not " #P4C".b, " #PT"/z, " #P4".b \n\t" \ + "not " #P6C".b, " #PT"/z, " #P6".b \n\t" \ + +#define SVE512_IN_REG_TRANSPOSE_d8x8(DST0,DST1,DST2,DST3,DST4,DST5,DST6,DST7,SRC0,SRC1,SRC2,SRC3,SRC4,SRC5,SRC6,SRC7,PT,P2C,P4C,P6C,PTFTF,P4,P6) \ + "trn1 " #DST0".d, " #SRC0".d, " #SRC1".d \n\t" \ + "trn2 " #DST1".d, " #SRC0".d, " #SRC1".d \n\t" \ + "trn1 " #DST2".d, " #SRC2".d, " #SRC3".d \n\t" \ + "trn2 " #DST3".d, " #SRC2".d, " #SRC3".d \n\t" \ + "trn1 " #DST4".d, " #SRC4".d, " #SRC5".d \n\t" \ + "trn2 " #DST5".d, " #SRC4".d, " #SRC5".d \n\t" \ + "trn1 " #DST6".d, " #SRC6".d, " #SRC7".d \n\t" \ + "trn2 " #DST7".d, " #SRC6".d, " #SRC7".d \n\t" \ + \ + "compact " #SRC0".d, " #P2C", " #DST0".d \n\t" \ + "compact " #SRC2".d, " #P2C", " #DST1".d \n\t" \ + "ext " #SRC1".b, " #SRC1".b, " #DST2".b, #48 \n\t" \ + "ext " #SRC3".b, " #SRC3".b, " #DST3".b, #48 \n\t" \ + "compact " #SRC4".d, " #P2C", " #DST4".d \n\t" \ + "compact " #SRC6".d, " #P2C", " #DST5".d \n\t" \ + "ext " #SRC5".b, " #SRC5".b, " #DST6".b, #48 \n\t" \ + "ext " #SRC7".b, " #SRC7".b, " #DST7".b, #48 \n\t" \ + \ + "sel " #DST0".d, " #PTFTF", " #DST0".d, " #SRC1".d \n\t" \ + "sel " #DST2".d, " #PTFTF", " #SRC0".d, " #DST2".d \n\t" \ + "sel " #DST1".d, " #PTFTF", " #DST1".d, " #SRC3".d \n\t" \ + "sel " #DST3".d, " #PTFTF", " #SRC2".d, " #DST3".d \n\t" \ + "sel " #DST4".d, " #PTFTF", " #DST4".d, " #SRC5".d \n\t" \ + "sel " #DST6".d, " #PTFTF", " #SRC4".d, " #DST6".d \n\t" \ + "sel " #DST5".d, " #PTFTF", " #DST5".d, " #SRC7".d \n\t" \ + "sel " #DST7".d, " #PTFTF", " #SRC6".d, " #DST7".d \n\t" \ + \ + "compact " #SRC0".d, " #P4C", " #DST0".d \n\t" \ + "compact " #SRC1".d, " #P4C", " #DST1".d \n\t" \ + "compact " #SRC2".d, " #P4C", " #DST2".d \n\t" \ + "compact " #SRC3".d, " #P4C", " #DST3".d \n\t" \ + "ext " #SRC4".b, " #SRC4".b, " #DST4".b, #32 \n\t" \ + "ext " #SRC5".b, " #SRC5".b, " #DST5".b, #32 \n\t" \ + "ext " #SRC6".b, " #SRC6".b, " #DST6".b, #32 \n\t" \ + "ext " #SRC7".b, " #SRC7".b, " #DST7".b, #32 \n\t" \ + \ + "sel " #DST0".d, " #P4", " #DST0".d, " #SRC4".d \n\t" \ + "sel " #DST1".d, " #P4", " #DST1".d, " #SRC5".d \n\t" \ + "sel " #DST2".d, " #P4", " #DST2".d, " #SRC6".d \n\t" \ + "sel " #DST3".d, " #P4", " #DST3".d, " #SRC7".d \n\t" \ + "sel " #DST4".d, " #P4", " #SRC0".d, " #DST4".d \n\t" \ + "sel " #DST5".d, " #P4", " #SRC1".d, " #DST5".d \n\t" \ + "sel " #DST6".d, " #P4", " #SRC2".d, " #DST6".d \n\t" \ + "sel " #DST7".d, " #P4", " #SRC3".d, " #DST7".d \n\t" + diff --git a/kernels/armsve/1m/bli_dpackm_armsve256_asm_8xk.c b/kernels/armsve/1m/bli_dpackm_armsve256_asm_8xk.c index 82def6df7b..a9b3d0af8a 100644 --- a/kernels/armsve/1m/bli_dpackm_armsve256_asm_8xk.c +++ b/kernels/armsve/1m/bli_dpackm_armsve256_asm_8xk.c @@ -52,15 +52,12 @@ void bli_dpackm_armsve256_asm_8xk dim_t cdim_, dim_t n_, dim_t n_max_, - void* restrict kappa_, - void* restrict a_, inc_t inca_, inc_t lda_, - void* restrict p_, inc_t ldp_, + double* restrict kappa, + double* restrict a, inc_t inca_, inc_t lda_, + double* restrict p, inc_t 
ldp_, cntx_t* restrict cntx ) { - double* a = ( double* )a_; - double* p = ( double* )p_; - double* kappa = ( double* )kappa_; const int64_t cdim = cdim_; const int64_t mnr = 8; const int64_t n = n_; diff --git a/kernels/armsve/1m/bli_dpackm_armsve512_asm_10xk.c b/kernels/armsve/1m/bli_dpackm_armsve512_asm_10xk.c new file mode 100644 index 0000000000..851363a9e0 --- /dev/null +++ b/kernels/armsve/1m/bli_dpackm_armsve512_asm_10xk.c @@ -0,0 +1,365 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "armsve512_asm_transpose_d8x8.h" +#include "armsve512_asm_transpose_d8x2.h" + +// assumption: +// SVE vector length = 512 bits. + +void bli_dpackm_armsve512_asm_10xk + ( + conj_t conja, + pack_t schema, + dim_t cdim_, + dim_t n_, + dim_t n_max_, + double* restrict kappa, + double* restrict a, inc_t inca_, inc_t lda_, + double* restrict p, inc_t ldp_, + cntx_t* restrict cntx + ) +{ + const int64_t cdim = cdim_; + const int64_t mnr = 10; + const int64_t n = n_; + const int64_t n_max = n_max_; + const int64_t inca = inca_; + const int64_t lda = lda_; + const int64_t ldp = ldp_; + const bool gs = inca != 1 && lda != 1; + const bool unitk = bli_deq1( *kappa ); + +#ifdef _A64FX + if ( bli_cntx_schema_a_block(cntx) != bli_cntx_schema_b_panel(cntx) ) + { + // A twisted way to infer whether A or B is being packed. + if ( schema == bli_cntx_schema_a_block(cntx) ) + p = ( (uint64_t)0x1 << 56 ) | (uint64_t)p; + if ( schema == bli_cntx_schema_b_panel(cntx) ) + p = ( (uint64_t)0x2 << 56 ) | (uint64_t)p; + } +#endif + + if ( cdim == mnr && !gs && unitk ) + { + uint64_t n_mker = n / 8; + uint64_t n_left = n % 8; + __asm__ volatile ( + "mov x0, %[a] \n\t" + "mov x1, %[p] \n\t" + "mov x2, %[ldp] \n\t" + "mov x3, %[lda] \n\t" + "mov x4, %[inca] \n\t" + "cmp x4, #1 \n\t" + // Skips by sizeof(double). 
+ "mov x8, #8 \n\t" + "madd x2, x2, x8, xzr \n\t" + "madd x3, x3, x8, xzr \n\t" + "madd x4, x4, x8, xzr \n\t" + // Loop constants. + "mov x8, %[n_mker] \n\t" + "mov x9, %[n_left] \n\t" + "ptrue p0.d \n\t" + "b.ne .AROWSTOR \n\t" + // A stored in columns. + " .ACOLSTOR: \n\t" + // Prefetch distance. + "mov x17, #8 \n\t" + "madd x17, x17, x3, xzr \n\t" +#ifdef _A64FX + // Disable hardware prefetch for A. + "mov x16, 0x6 \n\t" + "lsl x16, x16, #60 \n\t" + "orr x0, x0, x16 \n\t" +#endif + " .ACOLSTORMKER: \n\t" + "cmp x8, xzr \n\t" + "b.eq .ACOLSTORMKEREND \n\t" + "add x5, x0, x3 \n\t" + "add x6, x5, x3 \n\t" + "add x7, x6, x3 \n\t" + "ld1d z0.d, p0/z, [x0] \n\t" + "ldr q1, [x0, #64] \n\t" + "ld1d z2.d, p0/z, [x5] \n\t" + "ldr q3, [x5, #64] \n\t" + "ld1d z4.d, p0/z, [x6] \n\t" + "ldr q5, [x6, #64] \n\t" + "ld1d z6.d, p0/z, [x7] \n\t" + "ldr q7, [x7, #64] \n\t" + "add x18, x17, x0 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x5 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x6 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x7 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x0, x7, x3 \n\t" + "add x5, x0, x3 \n\t" + "add x6, x5, x3 \n\t" + "add x7, x6, x3 \n\t" + "ld1d z8.d, p0/z, [x0] \n\t" + "ldr q9, [x0, #64] \n\t" + "ld1d z10.d, p0/z, [x5] \n\t" + "ldr q11, [x5, #64] \n\t" + "ld1d z12.d, p0/z, [x6] \n\t" + "ldr q13, [x6, #64] \n\t" + "ld1d z14.d, p0/z, [x7] \n\t" + "ldr q15, [x7, #64] \n\t" + "add x18, x17, x0 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x5 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x6 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x7 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + // Plain storage + "add x10, x1, x2 \n\t" + "add x11, x10, x2 \n\t" + "add x12, x11, x2 \n\t" + "add x13, x12, x2 \n\t" + "add x14, x13, x2 \n\t" + "add x15, x14, x2 \n\t" + "add x16, x15, x2 \n\t" + "st1d z0.d, p0, [x1] \n\t" + "str q1, [x1, #64] \n\t" + "st1d z2.d, p0, [x10] \n\t" + "str q3, [x10, #64] \n\t" + "st1d z4.d, p0, [x11] \n\t" + "str q5, [x11, #64] \n\t" + "st1d z6.d, p0, [x12] \n\t" + "str q7, [x12, #64] \n\t" + "st1d z8.d, p0, [x13] \n\t" + "str q9, [x13, #64] \n\t" + "st1d z10.d, p0, [x14] \n\t" + "str q11, [x14, #64] \n\t" + "st1d z12.d, p0, [x15] \n\t" + "str q13, [x15, #64] \n\t" + "st1d z14.d, p0, [x16] \n\t" + "str q15, [x16, #64] \n\t" + "add x1, x16, x2 \n\t" + // Realign and store. 
+ // "ext z1.b, z1.b, z1.b, #16 \n\t" + // "ext z1.b, z1.b, z2.b, #48 \n\t" + // "ext z2.b, z2.b, z3.b, #16 \n\t" + // "ext z2.b, z2.b, z4.b, #32 \n\t" + // "ext z4.b, z4.b, z5.b, #16 \n\t" + // "ext z4.b, z4.b, z6.b, #16 \n\t" + // "ext z6.b, z6.b, z7.b, #16 \n\t" + // "ext z9.b, z9.b, z9.b, #16 \n\t" + // "ext z9.b, z9.b, z10.b, #48 \n\t" + // "ext z10.b, z10.b, z11.b, #16 \n\t" + // "ext z10.b, z10.b, z12.b, #32 \n\t" + // "ext z12.b, z12.b, z13.b, #16 \n\t" + // "ext z12.b, z12.b, z14.b, #16 \n\t" + // "ext z14.b, z14.b, z15.b, #16 \n\t" + // "st1d z0.d, p0, [x1] \n\t" + // "st1d z1.d, p0, [x1, #1, mul vl] \n\t" + // "st1d z2.d, p0, [x1, #2, mul vl] \n\t" + // "st1d z4.d, p0, [x1, #3, mul vl] \n\t" + // "st1d z6.d, p0, [x1, #4, mul vl] \n\t" + // "add x1, x1, #320 \n\t" + // "st1d z8.d, p0, [x1] \n\t" + // "st1d z9.d, p0, [x1, #1, mul vl] \n\t" + // "st1d z10.d, p0, [x1, #2, mul vl] \n\t" + // "st1d z12.d, p0, [x1, #3, mul vl] \n\t" + // "st1d z14.d, p0, [x1, #4, mul vl] \n\t" + // "add x1, x1, #320 \n\t" + "add x0, x7, x3 \n\t" + "sub x8, x8, #1 \n\t" + "b .ACOLSTORMKER \n\t" + " .ACOLSTORMKEREND: \n\t" + " .ACOLSTORLEFT: \n\t" + "cmp x9, xzr \n\t" + "b.eq .UNITKDONE \n\t" + "ld1d z0.d, p0/z, [x0] \n\t" + "ldr q1, [x0, #64] \n\t" + "st1d z0.d, p0, [x1] \n\t" + "str q1, [x1, #64] \n\t" + "add x0, x0, x3 \n\t" + "add x1, x1, x2 \n\t" + "sub x9, x9, #1 \n\t" + "b .ACOLSTORLEFT \n\t" + // A stored in rows. + " .AROWSTOR: \n\t" + // Prepare predicates for in-reg transpose. + SVE512_IN_REG_TRANSPOSE_d8x8_PREPARE(x16,p0,p1,p2,p3,p8,p4,p6) + " .AROWSTORMKER: \n\t" // X[10-16] for A here not P. Be careful. + "cmp x8, xzr \n\t" + "b.eq .AROWSTORMKEREND \n\t" + "add x10, x0, x4 \n\t" + "add x11, x10, x4 \n\t" + "add x12, x11, x4 \n\t" + "add x13, x12, x4 \n\t" + "add x14, x13, x4 \n\t" + "add x15, x14, x4 \n\t" + "add x16, x15, x4 \n\t" + "add x17, x16, x4 \n\t" + "add x18, x17, x4 \n\t" + "ld1d z0.d, p0/z, [x0] \n\t" + "ld1d z1.d, p0/z, [x10] \n\t" + "ld1d z2.d, p0/z, [x11] \n\t" + "ld1d z3.d, p0/z, [x12] \n\t" + "ld1d z4.d, p0/z, [x13] \n\t" + "ld1d z5.d, p0/z, [x14] \n\t" + "ld1d z6.d, p0/z, [x15] \n\t" + "ld1d z7.d, p0/z, [x16] \n\t" + "ld1d z22.d, p0/z, [x17] \n\t" + "ld1d z23.d, p0/z, [x18] \n\t" + // Transpose first 8 rows. + SVE512_IN_REG_TRANSPOSE_d8x8(z8,z9,z10,z11,z12,z13,z14,z15,z0,z1,z2,z3,z4,z5,z6,z7,p0,p1,p2,p3,p8,p4,p6) + // Transpose last 2 rows. + SVE512_IN_REG_TRANSPOSE_d8x2(z16,z17,z18,z19,z20,z21,z22,z23,p0,p1,p2,p3) + // Plain storage. + "add x10, x1, x2 \n\t" + "add x11, x10, x2 \n\t" + "add x12, x11, x2 \n\t" + "add x13, x12, x2 \n\t" + "add x14, x13, x2 \n\t" + "add x15, x14, x2 \n\t" + "add x16, x15, x2 \n\t" + "st1d z8.d, p0, [x1] \n\t" + "str q16, [x1, #64] \n\t" + "st1d z9.d, p0, [x10] \n\t" + "str q17, [x10, #64] \n\t" + "st1d z10.d, p0, [x11] \n\t" + "str q18, [x11, #64] \n\t" + "st1d z11.d, p0, [x12] \n\t" + "str q19, [x12, #64] \n\t" + "st1d z12.d, p0, [x13] \n\t" + "str q20, [x13, #64] \n\t" + "st1d z13.d, p0, [x14] \n\t" + "str q21, [x14, #64] \n\t" + "st1d z14.d, p0, [x15] \n\t" + "str q22, [x15, #64] \n\t" + "st1d z15.d, p0, [x16] \n\t" + "str q23, [x16, #64] \n\t" + "add x1, x16, x2 \n\t" + "add x0, x0, #64 \n\t" + "sub x8, x8, #1 \n\t" + "b .AROWSTORMKER \n\t" + " .AROWSTORMKEREND: \n\t" + "mov x4, %[inca] \n\t" // Restore unshifted inca. + "index z30.d, xzr, x4 \n\t" // Generate index. + "lsl x4, x4, #3 \n\t" // Shift again. + "lsl x5, x4, #3 \n\t" // Virtual column vl. 
+ " .AROWSTORLEFT: \n\t" + "cmp x9, xzr \n\t" + "b.eq .UNITKDONE \n\t" + "add x6, x0, x5 \n\t" + "add x7, x6, x4 \n\t" + "ld1d z0.d, p0/z, [x0, z30.d, lsl #3] \n\t" + "ldr d1, [x6] \n\t" + "ldr d2, [x7] \n\t" + "trn1 v1.2d, v1.2d, v2.2d \n\t" + "st1d z0.d, p0, [x1] \n\t" + "str q1, [x1, #64] \n\t" + "add x1, x1, x2 \n\t" + "add x0, x0, #8 \n\t" + "sub x9, x9, #1 \n\t" + "b .AROWSTORLEFT \n\t" + " .UNITKDONE: \n\t" + "mov x0, #0 \n\t" + : + : [a] "r" (a), + [p] "r" (p), + [lda] "r" (lda), + [ldp] "r" (ldp), + [inca] "r" (inca), + [n_mker] "r" (n_mker), + [n_left] "r" (n_left) + : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14","x15", + "x16","x17","x18", + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", + "z8", "z9", "z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19","z20","z21","z22","z23", + // "z24","z25","z26","z27","z28","z29", + "z30","z31", + "p0", "p1", "p2", "p3", "p4", // "p5", + "p6", "p7", "p8" + ); + } + else // if ( cdim < mnr ) + { + bli_dscal2m_ex + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim, + n, + kappa, + a, inca, lda, + p, 1, ldp, + cntx, + NULL + ); + + // if ( cdim < mnr ) + { + const dim_t i = cdim; + const dim_t m_edge = mnr - i; + const dim_t n_edge = n_max; + double* restrict p_edge = p + (i )*1; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( n < n_max ) + { + const dim_t j = n; + const dim_t m_edge = mnr; + const dim_t n_edge = n_max - j; + double* restrict p_edge = p + (j )*ldp; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} diff --git a/kernels/armsve/1m/bli_dpackm_armsve512_asm_12xk.c b/kernels/armsve/1m/bli_dpackm_armsve512_asm_12xk.c new file mode 100644 index 0000000000..9f943fcd66 --- /dev/null +++ b/kernels/armsve/1m/bli_dpackm_armsve512_asm_12xk.c @@ -0,0 +1,359 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Linaro Limited + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" +#include + +#ifdef __ARM_FEATURE_SVE +#include +#else +#error "No Arm SVE intrinsics support in compiler" +#endif // __ARM_FEATURE_SVE + +// assumption: +// SVE vector length = 512 bits. +// TODO: +// 2-rows -> 3 vectors packing and use predicator only in odd num of rows to be packed. +// prefetching is needed. + +void bli_dpackm_armsve512_asm_12xk + ( + conj_t conja, + pack_t schema, + dim_t cdim_, + dim_t n_, + dim_t n_max_, + double* restrict kappa, + double* restrict a, inc_t inca_, inc_t lda_, + double* restrict p, inc_t ldp_, + cntx_t* restrict cntx + ) +{ + const int64_t cdim = cdim_; + const int64_t mnr = 12; + const int64_t n = n_; + const int64_t n_max = n_max_; + const int64_t inca = inca_; + const int64_t lda = lda_; + const int64_t ldp = ldp_; + + double* restrict alpha1 = a; + double* restrict alpha1_8 = alpha1 + 8 * inca; + double* restrict alpha1_p4 = alpha1 + 4 * inca; + double* restrict alpha1_m4 = alpha1 - 4 * inca; + double* restrict pi1 = p; + const svbool_t all_active = svptrue_b64(); + const svbool_t first_half_active = svwhilelt_b64(0, 4); + const svbool_t last_half_active = svnot_z(all_active, first_half_active); + svfloat64_t z_a0; + svfloat64_t z_a8; + svfloat64_t z_a8_lh; + svfloat64_t z_a16; + svuint64_t z_index; + + // creating index for gather/scatter + // with each element as: 0, 1*inca, 2*inca, 3*inca + z_index = svindex_u64( 0, inca * sizeof( double ) ); + + if ( cdim == mnr ) + { + if ( bli_deq1( *kappa ) ) + { + if ( inca == 1 ) // continous memory. packA style + { + dim_t k = n; + // 2 pack into 3 case. + if ( ldp == mnr ) + { + for ( ; k > 1; k -= 2 ) + { + // load 12 continuous elments from *a + z_a0 = svld1_f64( all_active, alpha1 ); + z_a8 = svld1_vnum_f64( first_half_active, alpha1, 1 ); + + // forward address - 0 to 1 + alpha1 += lda; + alpha1_p4 = alpha1 + 4 * inca; + alpha1_m4 = alpha1 - 4 * inca; + + // load 12 continuous elments from *a, filling last half of z8. + z_a8_lh = svld1_f64( last_half_active, alpha1_m4 ); + z_a8 = svadd_f64_z( all_active, z_a8, z_a8_lh ); + z_a16 = svld1_f64( all_active, alpha1_p4 ); + + // stored packed data into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( all_active, pi1, 1, z_a8 ); + svst1_vnum_f64( all_active, pi1, 2, z_a16 ); + + // forward address - 1 to 0 + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += 2 * ldp; + } + } + // line-by-line packing case. + for ( ; k != 0; --k ) + { + // load 12 continuous elments from *a + z_a0 = svld1_f64( all_active, alpha1 ); + z_a8 = svld1_vnum_f64( first_half_active, alpha1, 1 ); + + // store them into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( first_half_active, pi1, 1, z_a8 ); + + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += ldp; + } + } + else // gather/scatter load/store. packB style + { + dim_t k = n; + if ( ldp == mnr ) + { + for ( ; k > 1; k -= 2 ) + { + // gather load from *a + z_a0 = svld1_gather_u64offset_f64( all_active, alpha1, z_index ); + z_a8 = svld1_gather_u64offset_f64( first_half_active, alpha1_8, z_index ); + + // forward address - 0 to 1 + alpha1 += lda; + alpha1_p4 = alpha1 + 4 * inca; + alpha1_m4 = alpha1 - 4 * inca; + + // gather load from *a, filling last half of z8. 
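+                    // (The two gathers below complete the 2-into-3 packing: z_a0 already holds elements 0-7 of + // column k and the low half of z_a8 holds its elements 8-11; z_a8_lh adds elements 0-3 of + // column k+1 to the high half of z_a8, and z_a16 then covers elements 4-11 of column k+1, + // so three 8-wide vectors carry two packed 12-element columns.)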
+ z_a8_lh = svld1_gather_u64offset_f64( last_half_active, alpha1_m4, z_index ); + z_a8 = svadd_f64_z( all_active, z_a8, z_a8_lh ); + z_a16 = svld1_gather_u64offset_f64( all_active, alpha1_p4, z_index ); + + // stored packed data into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( all_active, pi1, 1, z_a8 ); + svst1_vnum_f64( all_active, pi1, 2, z_a16 ); + + // forward address - 1 to 0 + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += 2 * ldp; + } + } + for ( ; k != 0; --k ) + { + // gather load from *a + z_a0 = svld1_gather_u64offset_f64( all_active, alpha1, z_index ); + z_a8 = svld1_gather_u64offset_f64( first_half_active, alpha1_8, z_index ); + + // scatter store into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( first_half_active, pi1, 1, z_a8 ); + + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += ldp; + } + } + } + else // *kappa != 1.0 + { + // load kappa into vector + svfloat64_t z_kappa; + + z_kappa = svdup_f64( *kappa ); + + if ( inca == 1 ) // continous memory. packA style + { + dim_t k = n; + if ( ldp == mnr ) + { + for ( ; k > 1; k -= 2 ) + { + // load 12 continuous elments from *a + z_a0 = svld1_f64( all_active, alpha1 ); + z_a8 = svld1_vnum_f64( first_half_active, alpha1, 1 ); + + // forward address - 0 to 1 + alpha1 += lda; + alpha1_p4 = alpha1 + 4 * inca; + alpha1_m4 = alpha1 - 4 * inca; + + // load 12 continuous elments from *a, filling last half of z8. + z_a8_lh = svld1_f64( last_half_active, alpha1_m4 ); + z_a8 = svadd_f64_z( all_active, z_a8, z_a8_lh ); + z_a16 = svld1_f64( all_active, alpha1_p4 ); + + // multiply by *kappa + z_a0 = svmul_lane_f64( z_a0, z_kappa, 0 ); + z_a8 = svmul_lane_f64( z_a8, z_kappa, 0 ); + z_a16 = svmul_lane_f64( z_a16, z_kappa, 0 ); + + // stored packed data into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( all_active, pi1, 1, z_a8 ); + svst1_vnum_f64( all_active, pi1, 2, z_a16 ); + + // forward address - 1 to 0 + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += 2 * ldp; + } + } + for ( ; k != 0; --k ) + { + // load 12 continuous elments from *a + z_a0 = svld1_f64( all_active, alpha1 ); + z_a8 = svld1_vnum_f64( first_half_active, alpha1, 1 ); + + // multiply by *kappa + z_a0 = svmul_lane_f64( z_a0, z_kappa, 0 ); + z_a8 = svmul_lane_f64( z_a8, z_kappa, 0 ); + + // store them into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( first_half_active, pi1, 1, z_a8 ); + + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += ldp; + } + } + else // gather/scatter load/store. packB style + { + dim_t k = n; + if ( ldp == mnr ) + { + for ( ; k > 1; k -= 2 ) + { + // gather load from *a + z_a0 = svld1_gather_u64offset_f64( all_active, alpha1, z_index ); + z_a8 = svld1_gather_u64offset_f64( first_half_active, alpha1_8, z_index ); + + // forward address - 0 to 1 + alpha1 += lda; + alpha1_p4 = alpha1 + 4 * inca; + alpha1_m4 = alpha1 - 4 * inca; + + // gather load from *a, filling last half of z8. 
+ z_a8_lh = svld1_gather_u64offset_f64( last_half_active, alpha1_m4, z_index ); + z_a8 = svadd_f64_z( all_active, z_a8, z_a8_lh ); + z_a16 = svld1_gather_u64offset_f64( all_active, alpha1_p4, z_index ); + + // multiply by *kappa + z_a0 = svmul_lane_f64( z_a0, z_kappa, 0 ); + z_a8 = svmul_lane_f64( z_a8, z_kappa, 0 ); + z_a16 = svmul_lane_f64( z_a16, z_kappa, 0 ); + + // stored packed data into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( all_active, pi1, 1, z_a8 ); + svst1_vnum_f64( all_active, pi1, 2, z_a16 ); + + // forward address - 1 to 0 + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += 2 * ldp; + } + } + for ( ; k != 0; --k ) + { + // gather load from *a + z_a0 = svld1_gather_u64offset_f64( all_active, alpha1, z_index ); + z_a8 = svld1_gather_u64offset_f64( first_half_active, alpha1_8, z_index ); + + // multiply by *kappa + z_a0 = svmul_lane_f64( z_a0, z_kappa, 0 ); + z_a8 = svmul_lane_f64( z_a8, z_kappa, 0 ); + + // scatter store into *p + svst1_f64( all_active, pi1, z_a0 ); + svst1_vnum_f64( first_half_active, pi1, 1, z_a8 ); + + alpha1 += lda; + alpha1_8 = alpha1 + 8 * inca; + pi1 += ldp; + } + } + } // end of if ( *kappa == 1.0 ) + } + else // if ( cdim < mnr ) + { + bli_dscal2m_ex + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim, + n, + kappa, + a, inca, lda, + p, 1, ldp, + cntx, + NULL + ); + + // if ( cdim < mnr ) + { + const dim_t i = cdim; + const dim_t m_edge = mnr - i; + const dim_t n_edge = n_max; + double* restrict p_edge = p + (i )*1; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( n < n_max ) + { + const dim_t j = n; + const dim_t m_edge = mnr; + const dim_t n_edge = n_max - j; + double* restrict p_edge = p + (j )*ldp; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} diff --git a/kernels/armsve/1m/bli_dpackm_armsve512_asm_16xk.c b/kernels/armsve/1m/bli_dpackm_armsve512_asm_16xk.c new file mode 100644 index 0000000000..38fb0b9125 --- /dev/null +++ b/kernels/armsve/1m/bli_dpackm_armsve512_asm_16xk.c @@ -0,0 +1,363 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "armsve512_asm_transpose_d8x8.h" + +// assumption: +// SVE vector length = 512 bits. + +void bli_dpackm_armsve512_asm_16xk + ( + conj_t conja, + pack_t schema, + dim_t cdim_, + dim_t n_, + dim_t n_max_, + double* restrict kappa, + double* restrict a, inc_t inca_, inc_t lda_, + double* restrict p, inc_t ldp_, + cntx_t* restrict cntx + ) +{ + const int64_t cdim = cdim_; + const int64_t mnr = 16; + const int64_t n = n_; + const int64_t n_max = n_max_; + const int64_t inca = inca_; + const int64_t lda = lda_; + const int64_t ldp = ldp_; + const bool gs = inca != 1 && lda != 1; + const bool unitk = bli_deq1( *kappa ); + +#ifdef _A64FX + if ( bli_cntx_schema_a_block(cntx) != bli_cntx_schema_b_panel(cntx) ) + { + // A twisted way to infer whether A or B is being packed. + if ( schema == bli_cntx_schema_a_block(cntx) ) + p = ( (uint64_t)0x1 << 56 ) | (uint64_t)p; + if ( schema == bli_cntx_schema_b_panel(cntx) ) + p = ( (uint64_t)0x2 << 56 ) | (uint64_t)p; + } +#endif + + if ( cdim == mnr && !gs && unitk ) + { + uint64_t n_mker = n / 8; + uint64_t n_left = n % 8; + __asm__ volatile ( + "mov x0, %[a] \n\t" + "mov x1, %[p] \n\t" + "mov x2, %[ldp] \n\t" + "mov x3, %[lda] \n\t" + "mov x4, %[inca] \n\t" + "cmp x4, #1 \n\t" + // Skips by sizeof(double). + "mov x8, #8 \n\t" + "madd x2, x2, x8, xzr \n\t" + "madd x3, x3, x8, xzr \n\t" + "madd x4, x4, x8, xzr \n\t" + + // "mov x8, 0x8 \n\t" // Control#0 for A address. + // "mov x8, 0x24 \n\t" // Higher 6bit for Control#0: + // "lsl x8, x8, #58 \n\t" // Valid|Strong|Strong|Alloc|Load|Strong + // "orr x8, x8, x3 \n\t" // Stride. + // "msr S3_3_C11_C6_0, x8 \n\t" // Write system register. + + // Loop constants. + "mov x8, %[n_mker] \n\t" + "mov x9, %[n_left] \n\t" + "ptrue p0.d \n\t" + "b.ne .AROWSTOR \n\t" + // A stored in columns. + " .ACOLSTOR: \n\t" + // Prefetch distance. + "mov x17, #8 \n\t" + "madd x17, x17, x3, xzr \n\t" +#ifdef _A64FX + "mov x16, 0x6 \n\t" // Disable hardware prefetch for A. 
+ "lsl x16, x16, #60 \n\t" + "orr x0, x0, x16 \n\t" +#endif + // "add x5, x0, x3 \n\t" + // "add x6, x5, x3 \n\t" + // "add x7, x6, x3 \n\t" + // "prfm PLDL1STRM, [x0] \n\t" + // "prfm PLDL1STRM, [x5] \n\t" + // "prfm PLDL1STRM, [x6] \n\t" + // "prfm PLDL1STRM, [x7] \n\t" + // "add x18, x7, x3 \n\t" + // "add x5, x18, x3 \n\t" + // "add x6, x5, x3 \n\t" + // "add x7, x6, x3 \n\t" + // "prfm PLDL1STRM, [x18] \n\t" + // "prfm PLDL1STRM, [x5] \n\t" + // "prfm PLDL1STRM, [x6] \n\t" + // "prfm PLDL1STRM, [x7] \n\t" + " .ACOLSTORMKER: \n\t" + "cmp x8, xzr \n\t" + "b.eq .ACOLSTORMKEREND \n\t" + "add x5, x0, x3 \n\t" + "add x6, x5, x3 \n\t" + "add x7, x6, x3 \n\t" + "add x10, x1, x2 \n\t" + "add x11, x10, x2 \n\t" + "add x12, x11, x2 \n\t" + "add x13, x12, x2 \n\t" + "add x14, x13, x2 \n\t" + "add x15, x14, x2 \n\t" + "add x16, x15, x2 \n\t" + "ld1d z0.d, p0/z, [x0] \n\t" + "ld1d z1.d, p0/z, [x0, #1, mul vl] \n\t" + "ld1d z2.d, p0/z, [x5] \n\t" + "ld1d z3.d, p0/z, [x5, #1, mul vl] \n\t" + "ld1d z4.d, p0/z, [x6] \n\t" + "ld1d z5.d, p0/z, [x6, #1, mul vl] \n\t" + "ld1d z6.d, p0/z, [x7] \n\t" + "ld1d z7.d, p0/z, [x7, #1, mul vl] \n\t" + "add x18, x17, x0 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x5 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x6 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x7 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x0, x7, x3 \n\t" + "add x5, x0, x3 \n\t" + "add x6, x5, x3 \n\t" + "add x7, x6, x3 \n\t" + "ld1d z8.d, p0/z, [x0] \n\t" + "ld1d z9.d, p0/z, [x0, #1, mul vl] \n\t" + "ld1d z10.d, p0/z, [x5] \n\t" + "ld1d z11.d, p0/z, [x5, #1, mul vl] \n\t" + "ld1d z12.d, p0/z, [x6] \n\t" + "ld1d z13.d, p0/z, [x6, #1, mul vl] \n\t" + "ld1d z14.d, p0/z, [x7] \n\t" + "ld1d z15.d, p0/z, [x7, #1, mul vl] \n\t" + "add x18, x17, x0 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x5 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x6 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "add x18, x17, x7 \n\t" + "prfm PLDL1STRM, [x18] \n\t" + "st1d z0.d, p0, [x1] \n\t" + "st1d z1.d, p0, [x1, #1, mul vl] \n\t" + "st1d z2.d, p0, [x10] \n\t" + "st1d z3.d, p0, [x10, #1, mul vl] \n\t" + "st1d z4.d, p0, [x11] \n\t" + "st1d z5.d, p0, [x11, #1, mul vl] \n\t" + "st1d z6.d, p0, [x12] \n\t" + "st1d z7.d, p0, [x12, #1, mul vl] \n\t" + "st1d z8.d, p0, [x13] \n\t" + "st1d z9.d, p0, [x13, #1, mul vl] \n\t" + "st1d z10.d, p0, [x14] \n\t" + "st1d z11.d, p0, [x14, #1, mul vl] \n\t" + "st1d z12.d, p0, [x15] \n\t" + "st1d z13.d, p0, [x15, #1, mul vl] \n\t" + "st1d z14.d, p0, [x16] \n\t" + "st1d z15.d, p0, [x16, #1, mul vl] \n\t" + "add x0, x7, x3 \n\t" + "add x1, x16, x2 \n\t" + "sub x8, x8, #1 \n\t" + "b .ACOLSTORMKER \n\t" + " .ACOLSTORMKEREND: \n\t" + " .ACOLSTORLEFT: \n\t" + "cmp x9, xzr \n\t" + "b.eq .UNITKDONE \n\t" + "ld1d z0.d, p0/z, [x0] \n\t" + "ld1d z1.d, p0/z, [x0, #1, mul vl] \n\t" + "st1d z0.d, p0, [x1] \n\t" + "st1d z1.d, p0, [x1, #1, mul vl] \n\t" + "add x0, x0, x3 \n\t" + "add x1, x1, x2 \n\t" + "sub x9, x9, #1 \n\t" + "b .ACOLSTORLEFT \n\t" + // A stored in rows. + " .AROWSTOR: \n\t" + // Prepare predicates for in-reg transpose. + SVE512_IN_REG_TRANSPOSE_d8x8_PREPARE(x16,p0,p1,p2,p3,p8,p4,p6) + " .AROWSTORMKER: \n\t" // X[10-16] for A here not P. Be careful. 
+ "cmp x8, xzr \n\t" + "b.eq .AROWSTORMKEREND \n\t" + "add x10, x0, x4 \n\t" + "add x11, x10, x4 \n\t" + "add x12, x11, x4 \n\t" + "add x13, x12, x4 \n\t" + "add x14, x13, x4 \n\t" + "add x15, x14, x4 \n\t" + "add x16, x15, x4 \n\t" + "ld1d z0.d, p0/z, [x0] \n\t" + "ld1d z1.d, p0/z, [x10] \n\t" + "ld1d z2.d, p0/z, [x11] \n\t" + "ld1d z3.d, p0/z, [x12] \n\t" + "ld1d z4.d, p0/z, [x13] \n\t" + "ld1d z5.d, p0/z, [x14] \n\t" + "ld1d z6.d, p0/z, [x15] \n\t" + "ld1d z7.d, p0/z, [x16] \n\t" + "add x5, x16, x4 \n\t" + "add x10, x5, x4 \n\t" + "add x11, x10, x4 \n\t" + "add x12, x11, x4 \n\t" + "add x13, x12, x4 \n\t" + "add x14, x13, x4 \n\t" + "add x15, x14, x4 \n\t" + "add x16, x15, x4 \n\t" + "ld1d z16.d, p0/z, [x5] \n\t" + "ld1d z17.d, p0/z, [x10] \n\t" + "ld1d z18.d, p0/z, [x11] \n\t" + "ld1d z19.d, p0/z, [x12] \n\t" + "ld1d z20.d, p0/z, [x13] \n\t" + "ld1d z21.d, p0/z, [x14] \n\t" + "ld1d z22.d, p0/z, [x15] \n\t" + "ld1d z23.d, p0/z, [x16] \n\t" + // Transpose first 8 rows. + SVE512_IN_REG_TRANSPOSE_d8x8(z8,z9,z10,z11,z12,z13,z14,z15,z0,z1,z2,z3,z4,z5,z6,z7,p0,p1,p2,p3,p8,p4,p6) + // Transpose last 8 rows. + SVE512_IN_REG_TRANSPOSE_d8x8(z24,z25,z26,z27,z28,z29,z30,z31,z16,z17,z18,z19,z20,z21,z22,z23,p0,p1,p2,p3,p8,p4,p6) + "add x10, x1, x2 \n\t" + "add x11, x10, x2 \n\t" + "add x12, x11, x2 \n\t" + "add x13, x12, x2 \n\t" + "add x14, x13, x2 \n\t" + "add x15, x14, x2 \n\t" + "add x16, x15, x2 \n\t" + "st1d z8.d, p0, [x1] \n\t" + "st1d z24.d, p0, [x1, #1, mul vl] \n\t" + "st1d z9.d, p0, [x10] \n\t" + "st1d z25.d, p0, [x10, #1, mul vl] \n\t" + "st1d z10.d, p0, [x11] \n\t" + "st1d z26.d, p0, [x11, #1, mul vl] \n\t" + "st1d z11.d, p0, [x12] \n\t" + "st1d z27.d, p0, [x12, #1, mul vl] \n\t" + "st1d z12.d, p0, [x13] \n\t" + "st1d z28.d, p0, [x13, #1, mul vl] \n\t" + "st1d z13.d, p0, [x14] \n\t" + "st1d z29.d, p0, [x14, #1, mul vl] \n\t" + "st1d z14.d, p0, [x15] \n\t" + "st1d z30.d, p0, [x15, #1, mul vl] \n\t" + "st1d z15.d, p0, [x16] \n\t" + "st1d z31.d, p0, [x16, #1, mul vl] \n\t" + "add x0, x0, #64 \n\t" + "add x1, x16, x2 \n\t" + "sub x8, x8, #1 \n\t" + "b .AROWSTORMKER \n\t" + " .AROWSTORMKEREND: \n\t" + "mov x4, %[inca] \n\t" // Restore unshifted inca. + "index z30.d, xzr, x4 \n\t" // Generate index. + "lsl x4, x4, #3 \n\t" // Shift again. + "lsl x5, x4, #3 \n\t" // Virtual column vl. 
+ " .AROWSTORLEFT: \n\t" + "cmp x9, xzr \n\t" + "b.eq .UNITKDONE \n\t" + "add x6, x0, x5 \n\t" + "ld1d z0.d, p0/z, [x0, z30.d, lsl #3] \n\t" + "ld1d z1.d, p0/z, [x6, z30.d, lsl #3] \n\t" + "st1d z0.d, p0, [x1] \n\t" + "st1d z1.d, p0, [x1, #1, mul vl] \n\t" + "add x1, x1, x2 \n\t" + "add x0, x0, #8 \n\t" + "sub x9, x9, #1 \n\t" + "b .AROWSTORLEFT \n\t" + " .UNITKDONE: \n\t" + "mov x0, #0 \n\t" + : + : [a] "r" (a), + [p] "r" (p), + [lda] "r" (lda), + [ldp] "r" (ldp), + [inca] "r" (inca), + [n_mker] "r" (n_mker), + [n_left] "r" (n_left) + : "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10","x11","x12","x13","x14","x15", + "x16","x17","x18", + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", + "z8", "z9", "z10","z11","z12","z13","z14","z15", + // "z16","z17","z18","z19","z20","z21","z22","z23", + // "z24","z25","z26","z27","z28","z29","z30","z31", + "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7" + ); + } + else // if ( cdim < mnr ) + { + bli_dscal2m_ex + ( + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + ( trans_t )conja, + cdim, + n, + kappa, + a, inca, lda, + p, 1, ldp, + cntx, + NULL + ); + + // if ( cdim < mnr ) + { + const dim_t i = cdim; + const dim_t m_edge = mnr - i; + const dim_t n_edge = n_max; + double* restrict p_edge = p + (i )*1; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } + } + + if ( n < n_max ) + { + const dim_t j = n; + const dim_t m_edge = mnr; + const dim_t n_edge = n_max - j; + double* restrict p_edge = p + (j )*ldp; + + bli_dset0s_mxn + ( + m_edge, + n_edge, + p_edge, 1, ldp + ); + } +} diff --git a/kernels/armsve/3/armsve_asm_2vx10.h b/kernels/armsve/3/armsve_asm_2vx10.h new file mode 100644 index 0000000000..8e37585cba --- /dev/null +++ b/kernels/armsve/3/armsve_asm_2vx10.h @@ -0,0 +1,191 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +*/ +#define GEMM_2VX10_MKER_LOOP_PLAIN_C_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BRSBIT) \ + GEMM_FMLA2_LD1R(C0FH,C0LH,PT,ACOLFH,ACOLLH,BV0,BADDR,8) \ + GEMM_FMLA2_LD1R(C1FH,C1LH,PT,ACOLFH,ACOLLH,BV1,BADDR,9) \ +" add "#BADDR", "#BRSBIT", "#BADDR" \n\t" /* B address forward */ \ + GEMM_FMLA2_LD1R(C2FH,C2LH,PT,ACOLFH,ACOLLH,BV2,BADDR,0) \ + GEMM_FMLA2_LD1R(C3FH,C3LH,PT,ACOLFH,ACOLLH,BV3,BADDR,1) \ + GEMM_FMLA2_LD1R(C4FH,C4LH,PT,ACOLFH,ACOLLH,BV4,BADDR,2) \ + GEMM_FMLA2_LD1R(C5FH,C5LH,PT,ACOLFH,ACOLLH,BV5,BADDR,3) \ + GEMM_FMLA2_LD1R(C6FH,C6LH,PT,ACOLFH,ACOLLH,BV6,BADDR,4) \ + GEMM_FMLA2_LD1R(C7FH,C7LH,PT,ACOLFH,ACOLLH,BV7,BADDR,5) \ + \ + GEMM_FMLA2_LD1R(C8FH,C8LH,PT,ACOLFH,ACOLLH,BV0,BADDR,6) \ + GEMM_FMLA2_LD1R(C9FH,C9LH,PT,ACOLFH,ACOLLH,BV1,BADDR,7) + +// Second through forth microkernels are the first one with B vectors rotated. +#define GEMM_2VX10_MKER_LOOP_PLAIN_C_2(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BRSBIT) \ + GEMM_2VX10_MKER_LOOP_PLAIN_C_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV2,BV3,BV4,BV5,BV6,BV7,BV0,BV1,BADDR,BRSBIT) + +#define GEMM_2VX10_MKER_LOOP_PLAIN_C_3(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BRSBIT) \ + GEMM_2VX10_MKER_LOOP_PLAIN_C_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV4,BV5,BV6,BV7,BV0,BV1,BV2,BV3,BADDR,BRSBIT) + +#define GEMM_2VX10_MKER_LOOP_PLAIN_C_4(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BRSBIT) \ + GEMM_2VX10_MKER_LOOP_PLAIN_C_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV6,BV7,BV0,BV1,BV2,BV3,BV4,BV5,BADDR,BRSBIT) +// NOTE: +// The microkernel (PLAIN_1-4 as a whole) satisfies on entry/exit +// (sth. akin to loop-invariant): +// - BV[0-7] holds B[0:7, 4*k_cur] +// - B's address stops at B[0, 4*k_cur+1] + +// Final loop inside K=4 microkernels. +#define GEMM_2VX10_MKER_LOOP_PLAIN_C_4_RESIDUAL(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BRSBIT) \ + GEMM_FMLA2_LD1R(C0FH,C0LH,PT,ACOLFH,ACOLLH,BV6,BADDR,8) \ + GEMM_FMLA2_LD1R(C1FH,C1LH,PT,ACOLFH,ACOLLH,BV7,BADDR,9) \ +" add "#BADDR", "#BRSBIT", "#BADDR" \n\t" /* B address forward */ \ + GEMM_FMLA2(C2FH,C2LH,PT,ACOLFH,ACOLLH,BV0) \ + GEMM_FMLA2(C3FH,C3LH,PT,ACOLFH,ACOLLH,BV1) \ + GEMM_FMLA2(C4FH,C4LH,PT,ACOLFH,ACOLLH,BV2) \ + GEMM_FMLA2(C5FH,C5LH,PT,ACOLFH,ACOLLH,BV3) \ + GEMM_FMLA2(C6FH,C6LH,PT,ACOLFH,ACOLLH,BV4) \ + GEMM_FMLA2(C7FH,C7LH,PT,ACOLFH,ACOLLH,BV5) \ + GEMM_FMLA2(C8FH,C8LH,PT,ACOLFH,ACOLLH,BV6) \ + GEMM_FMLA2(C9FH,C9LH,PT,ACOLFH,ACOLLH,BV7) + +// K=4 MKer loop with B memory scattered. 
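+// Unlike the PLAIN_C_ variants above, which broadcast B[0..9] with ld1r at immediate offsets from +// BADDR, these _G_ variants walk an element pointer (BELMADDR) forward by the column stride +// (BCSBIT) after each broadcast (as the ELMFWD name suggests), so B panels with a non-unit column +// stride are also covered; BADDR still advances by the row stride (BRSBIT) once per k iteration.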
+#define GEMM_2VX10_MKER_LOOP_PLAIN_G_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BELMADDR,BRSBIT,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C0FH,C0LH,PT,ACOLFH,ACOLLH,BV0,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C1FH,C1LH,PT,ACOLFH,ACOLLH,BV1,BELMADDR,BCSBIT) \ +" add "#BADDR", "#BRSBIT", "#BADDR" \n\t" /* B address forward */ \ +" mov "#BELMADDR", "#BADDR" \n\t" \ + GEMM_FMLA2_LD1R_G_ELMFWD(C2FH,C2LH,PT,ACOLFH,ACOLLH,BV2,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C3FH,C3LH,PT,ACOLFH,ACOLLH,BV3,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C4FH,C4LH,PT,ACOLFH,ACOLLH,BV4,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C5FH,C5LH,PT,ACOLFH,ACOLLH,BV5,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C6FH,C6LH,PT,ACOLFH,ACOLLH,BV6,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C7FH,C7LH,PT,ACOLFH,ACOLLH,BV7,BELMADDR,BCSBIT) \ + \ + GEMM_FMLA2_LD1R_G_ELMFWD(C8FH,C8LH,PT,ACOLFH,ACOLLH,BV0,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C9FH,C9LH,PT,ACOLFH,ACOLLH,BV1,BELMADDR,BCSBIT) + +#define GEMM_2VX10_MKER_LOOP_PLAIN_G_2(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BELMADDR,BRSBIT,BCSBIT) \ + GEMM_2VX10_MKER_LOOP_PLAIN_G_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV2,BV3,BV4,BV5,BV6,BV7,BV0,BV1,BADDR,BELMADDR,BRSBIT,BCSBIT) + +#define GEMM_2VX10_MKER_LOOP_PLAIN_G_3(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BELMADDR,BRSBIT,BCSBIT) \ + GEMM_2VX10_MKER_LOOP_PLAIN_G_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV4,BV5,BV6,BV7,BV0,BV1,BV2,BV3,BADDR,BELMADDR,BRSBIT,BCSBIT) + +#define GEMM_2VX10_MKER_LOOP_PLAIN_G_4(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BELMADDR,BRSBIT,BCSBIT) \ + GEMM_2VX10_MKER_LOOP_PLAIN_G_1(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV6,BV7,BV0,BV1,BV2,BV3,BV4,BV5,BADDR,BELMADDR,BRSBIT,BCSBIT) + +#define GEMM_2VX10_MKER_LOOP_PLAIN_G_4_RESIDUAL(C0FH,C1FH,C2FH,C3FH,C4FH,C5FH,C6FH,C7FH,C8FH,C9FH,C0LH,C1LH,C2LH,C3LH,C4LH,C5LH,C6LH,C7LH,C8LH,C9LH,PT,ACOLFH,ACOLLH,BV0,BV1,BV2,BV3,BV4,BV5,BV6,BV7,BADDR,BELMADDR,BRSBIT,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C0FH,C0LH,PT,ACOLFH,ACOLLH,BV6,BELMADDR,BCSBIT) \ + GEMM_FMLA2_LD1R_G_ELMFWD(C1FH,C1LH,PT,ACOLFH,ACOLLH,BV7,BELMADDR,BCSBIT) \ +" add "#BADDR", "#BRSBIT", "#BADDR" \n\t" /* B address forward */ \ +" mov "#BELMADDR", "#BADDR" \n\t" \ + GEMM_FMLA2(C2FH,C2LH,PT,ACOLFH,ACOLLH,BV0) \ + GEMM_FMLA2(C3FH,C3LH,PT,ACOLFH,ACOLLH,BV1) \ + GEMM_FMLA2(C4FH,C4LH,PT,ACOLFH,ACOLLH,BV2) \ + GEMM_FMLA2(C5FH,C5LH,PT,ACOLFH,ACOLLH,BV3) \ + GEMM_FMLA2(C6FH,C6LH,PT,ACOLFH,ACOLLH,BV4) \ + GEMM_FMLA2(C7FH,C7LH,PT,ACOLFH,ACOLLH,BV5) \ + GEMM_FMLA2(C8FH,C8LH,PT,ACOLFH,ACOLLH,BV6) \ + GEMM_FMLA2(C9FH,C9LH,PT,ACOLFH,ACOLLH,BV7) + + +#define CLEAR_COL20(Z00,Z01,Z02,Z03,Z04,Z05,Z06,Z07,Z08,Z09,Z10,Z11,Z12,Z13,Z14,Z15,Z16,Z17,Z18,Z19) \ + CLEAR_COL4(Z00,Z01,Z02,Z03) \ + CLEAR_COL4(Z04,Z05,Z06,Z07) \ + CLEAR_COL4(Z08,Z09,Z10,Z11) \ + CLEAR_COL4(Z12,Z13,Z14,Z15) \ + CLEAR_COL4(Z16,Z17,Z18,Z19) + +#define 
SCALE_COL20(Z00,Z01,Z02,Z03,Z04,Z05,Z06,Z07,Z08,Z09,Z10,Z11,Z12,Z13,Z14,Z15,Z16,Z17,Z18,Z19,ZFACTOR) \ + SCALE_COL4(Z00,Z01,Z02,Z03,ZFACTOR) \ + SCALE_COL4(Z04,Z05,Z06,Z07,ZFACTOR) \ + SCALE_COL4(Z08,Z09,Z10,Z11,ZFACTOR) \ + SCALE_COL4(Z12,Z13,Z14,Z15,ZFACTOR) \ + SCALE_COL4(Z16,Z17,Z18,Z19,ZFACTOR) + +#define GEMM_C_FMAD_UKER(Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,PFH,PLH,C0FH,C1FH,C2FH,C3FH,C4FH,C0LH,C1LH,C2LH,C3LH,C4LH,ZSCALE) \ + GEMM_CCOL_FMAD(Z0FH,Z0LH,PFH,PLH,C0FH,C0LH,ZSCALE) \ + GEMM_CCOL_FMAD(Z1FH,Z1LH,PFH,PLH,C1FH,C1LH,ZSCALE) \ + GEMM_CCOL_FMAD(Z2FH,Z2LH,PFH,PLH,C2FH,C2LH,ZSCALE) \ + GEMM_CCOL_FMAD(Z3FH,Z3LH,PFH,PLH,C3FH,C3LH,ZSCALE) \ + GEMM_CCOL_FMAD(Z4FH,Z4LH,PFH,PLH,C4FH,C4LH,ZSCALE) + +#define GEMM_C_LOAD_UKER_C(Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(Z0FH,Z0LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(Z1FH,Z1LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(Z2FH,Z2LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(Z3FH,Z3LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(Z4FH,Z4LH,PFH,PLH,CADDR,CCS) + +#define GEMM_C_STORE_UKER_C(Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_STORE_FWD(Z0FH,Z0LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_STORE_FWD(Z1FH,Z1LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_STORE_FWD(Z2FH,Z2LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_STORE_FWD(Z3FH,Z3LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_CONTIGUOUS_STORE_FWD(Z4FH,Z4LH,PFH,PLH,CADDR,CCS) + +#define GEMM_C_FMAD_LOAD_UKER_C(Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,PFH,PLH,C0FH,C1FH,C2FH,C3FH,C4FH,C0LH,C1LH,C2LH,C3LH,C4LH,ZSCALE,CADDR,CCS) \ + GEMM_CCOL_FMAD(Z0FH,Z0LH,PFH,PLH,C0FH,C0LH,ZSCALE) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(C0FH,C0LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_FMAD(Z1FH,Z1LH,PFH,PLH,C1FH,C1LH,ZSCALE) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(C1FH,C1LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_FMAD(Z2FH,Z2LH,PFH,PLH,C2FH,C2LH,ZSCALE) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(C2FH,C2LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_FMAD(Z3FH,Z3LH,PFH,PLH,C3FH,C3LH,ZSCALE) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(C3FH,C3LH,PFH,PLH,CADDR,CCS) \ + GEMM_CCOL_FMAD(Z4FH,Z4LH,PFH,PLH,C4FH,C4LH,ZSCALE) \ + GEMM_CCOL_CONTIGUOUS_LOAD_FWD(C4FH,C4LH,PFH,PLH,CADDR,CCS) + +#define GEMM_C_LOAD_UKER_G(Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_GATHER_LOAD_FWD(Z0FH,Z0LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_GATHER_LOAD_FWD(Z1FH,Z1LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_GATHER_LOAD_FWD(Z2FH,Z2LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_GATHER_LOAD_FWD(Z3FH,Z3LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_GATHER_LOAD_FWD(Z4FH,Z4LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) + +#define GEMM_C_STORE_UKER_G(Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_SCATTER_STORE_FWD(Z0FH,Z0LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_SCATTER_STORE_FWD(Z1FH,Z1LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_SCATTER_STORE_FWD(Z2FH,Z2LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_SCATTER_STORE_FWD(Z3FH,Z3LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_SCATTER_STORE_FWD(Z4FH,Z4LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) + +#define GEMM_C_FMAD_LOAD_UKER_G(Z0FH,Z1FH,Z2FH,Z3FH,Z4FH,Z0LH,Z1LH,Z2LH,Z3LH,Z4LH,PFH,PLH,C0FH,C1FH,C2FH,C3FH,C4FH,C0LH,C1LH,C2LH,C3LH,C4LH,ZSCALE,ZIDX,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_FMAD(Z0FH,Z0LH,PFH,PLH,C0FH,C0LH,ZSCALE) 
\ + GEMM_CCOL_GATHER_LOAD_FWD(C0FH,C0LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_FMAD(Z1FH,Z1LH,PFH,PLH,C1FH,C1LH,ZSCALE) \ + GEMM_CCOL_GATHER_LOAD_FWD(C1FH,C1LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_FMAD(Z2FH,Z2LH,PFH,PLH,C2FH,C2LH,ZSCALE) \ + GEMM_CCOL_GATHER_LOAD_FWD(C2FH,C2LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_FMAD(Z3FH,Z3LH,PFH,PLH,C3FH,C3LH,ZSCALE) \ + GEMM_CCOL_GATHER_LOAD_FWD(C3FH,C3LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_CCOL_FMAD(Z4FH,Z4LH,PFH,PLH,C4FH,C4LH,ZSCALE) \ + GEMM_CCOL_GATHER_LOAD_FWD(C4FH,C4LH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) + diff --git a/kernels/armsve/3/armsve_asm_macros.h b/kernels/armsve/3/armsve_asm_macros.h new file mode 100644 index 0000000000..5e8eb3c623 --- /dev/null +++ b/kernels/armsve/3/armsve_asm_macros.h @@ -0,0 +1,123 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#define CLEAR_COL2(Z0,Z1) \ +" dup "#Z0"."DT", #0 \n\t" \ +" dup "#Z1"."DT", #0 \n\t" + +#define CLEAR_COL4(Z0,Z1,Z2,Z3) \ + CLEAR_COL2(Z0,Z1) \ + CLEAR_COL2(Z2,Z3) + +#define SCALE_COL2(Z0,Z1,ZFACTOR) \ +" fmul "#Z0"."DT", "#Z0"."DT", "#ZFACTOR"."DT" \n\t" \ +" fmul "#Z1"."DT", "#Z1"."DT", "#ZFACTOR"."DT" \n\t" \ + +#define SCALE_COL4(Z0,Z1,Z2,Z3,ZFACTOR) \ + SCALE_COL2(Z0,Z1,ZFACTOR) \ + SCALE_COL2(Z2,Z3,ZFACTOR) + +// Prefetch or not. 
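+// (Explanatory note: the _noprfm/_prfm pairs below are selected by
+//  token-pasting a caller-supplied suffix onto the macro name, e.g.
+//  PREFETCH_CONTIGUOUS_ ##PREFMODE(L1,STRM,ATEMP,0) expands either to
+//  "prfm PLDL1STRM, [ATEMP, 0]" or to nothing, so prefetching is chosen
+//  per call site at compile time without any runtime branch.)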
+#define PREFETCH_CONTIGUOUS_noprfm(LV,PROP,ADDR,SHIFT) +#define PREFETCH_CONTIGUOUS_prfm(LV,PROP,ADDR,SHIFT) \ +" prfm PLD"#LV""#PROP", ["#ADDR", "#SHIFT"] \n\t" + +#define GEMM_FMLA2(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV) \ +" fmla "#CCOLFH"."DT", "#PT"/m, "#ACOLFH"."DT", "#BV"."DT" \n\t" /* A Row 0 :VL */ \ +" fmla "#CCOLLH"."DT", "#PT"/m, "#ACOLLH"."DT", "#BV"."DT" \n\t" /* A Row VL:2VL */ + +#define GEMM_FMLA2_LD1R(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV,BADDR,NSHIFT) \ + GEMM_FMLA2(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV) \ +" "LD1R" "#BV"."DT", "#PT"/z, ["#BADDR", #"#NSHIFT"*"SZ"]\n\t" + +#define GEMM_FMLA2_LD1R_G_ELMFWD(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV,BELMADDR,BCSBIT) \ + GEMM_FMLA2(CCOLFH,CCOLLH,PT,ACOLFH,ACOLLH,BV) \ +" "LD1R" "#BV"."DT", "#PT"/z, ["#BELMADDR"] \n\t" /* Load B */ \ +" add "#BELMADDR", "#BELMADDR", "#BCSBIT" \n\t" /* Forward B element */ + +#define GEMM_ACOL_CONTIGUOUS_LOAD(ZFH,ZLH,PFH,PLH,AADDR) \ +" "LD1" "#ZFH"."DT", "#PFH"/z, ["#AADDR"] \n\t" \ +" "LD1" "#ZLH"."DT", "#PLH"/z, ["#AADDR", #1, mul vl]\n\t" + +#define GEMM_ACOL_GATHER_LOAD(ZFH,ZLH,ZIDX,PFH,PLH,AADDR,AVSKIP,ATEMP) \ +" "LD1" "#ZFH"."DT", "#PFH"/z, ["#AADDR", "#ZIDX"."DT", "OFFS"]\n\t" \ +" add "#ATEMP", "#AADDR", "#AVSKIP" \n\t" \ +" "LD1" "#ZLH"."DT", "#PLH"/z, ["#ATEMP", "#ZIDX"."DT", "OFFS"]\n\t" + +// Prefetch or not. +#define GEMM_ACOL_GATHER_noprfm(LV,PROP,ZIDX,PFH,PLH,AADDR,AVSKIP,ATEMP) +#define GEMM_ACOL_GATHER_prfm(LV,PROP,ZIDX,PFH,PLH,AADDR,AVSKIP,ATEMP) \ +" "PRFG" PLD"#LV""#PROP", "#PFH", ["#AADDR", "#ZIDX"."DT", "OFFS"] \n\t" \ +" add "#ATEMP", "#AADDR", "#AVSKIP" \n\t" \ +" "PRFG" PLD"#LV""#PROP", "#PLH", ["#ATEMP", "#ZIDX"."DT", "OFFS"] \n\t" + +#define GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_C(ZFH,ZLH,PFH,PLH,AADDR,A4KS,ACS,ATEMP,PREFMODE) \ +" add "#ATEMP", "#AADDR", "#A4KS" \n\t" \ +" add "#AADDR", "#AADDR", "#ACS" \n\t" /* Forward A's address to the next column. */ \ + GEMM_ACOL_CONTIGUOUS_LOAD(ZFH,ZLH,PFH,PLH,AADDR) \ + PREFETCH_CONTIGUOUS_ ##PREFMODE(L1,STRM,ATEMP,0) + +#define GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_G(ZFH,ZLH,ZIDX,PFH,PLH,AADDR,A4KS,APS,ACS,AVSKIP,ATEMP,PREFMODEL1,PREFMODEL2) \ +" add "#ATEMP", "#AADDR", "#A4KS" \n\t" \ + GEMM_ACOL_GATHER_ ##PREFMODEL1(L1,STRM,ZIDX,PFH,PLH,ATEMP,AVSKIP,ATEMP) \ +" add "#ATEMP", "#AADDR", "#APS" \n\t" \ + GEMM_ACOL_GATHER_ ##PREFMODEL2(L2,STRM,ZIDX,PFH,PLH,ATEMP,AVSKIP,ATEMP) \ +" add "#AADDR", "#AADDR", "#ACS" \n\t" /* Forward A's address to the next column. */ \ + GEMM_ACOL_GATHER_LOAD(ZFH,ZLH,ZIDX,PFH,PLH,AADDR,AVSKIP,ATEMP) + +#define GEMM_CCOL_CONTIGUOUS_LOAD_FWD(ZFH,ZLH,PFH,PLH,CADDR,CCS) \ + GEMM_ACOL_CONTIGUOUS_LOAD(ZFH,ZLH,PFH,PLH,CADDR) \ +" add "#CADDR", "#CADDR", "#CCS" \n\t" /* Forward C address (load) to next column. */ + +#define GEMM_CCOL_CONTIGUOUS_STORE_FWD(ZFH,ZLH,PFH,PLH,CADDR,CCS) \ +" "ST1" "#ZFH"."DT", "#PFH", ["#CADDR"] \n\t" \ +" "ST1" "#ZLH"."DT", "#PLH", ["#CADDR", #1, mul vl] \n\t" \ +" add "#CADDR", "#CADDR", "#CCS" \n\t" /* Forward C address (store) to next column. 
*/ + +#define GEMM_CCOL_FMAD(ZFH,ZLH,PFH,PLH,CFH,CLH,ZSCALE) \ +" fmad "#ZFH"."DT", "#PFH"/m, "#ZSCALE"."DT", "#CFH"."DT" \n\t" \ +" fmad "#ZLH"."DT", "#PLH"/m, "#ZSCALE"."DT", "#CLH"."DT" \n\t" + +#define GEMM_CCOL_GATHER_LOAD_FWD(ZFH,ZLH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ + GEMM_ACOL_GATHER_LOAD(ZFH,ZLH,ZIDX,PFH,PLH,CADDR,CVSKIP,CTEMP) \ +" add "#CADDR", "#CADDR", "#CCS" \n\t" + +#define GEMM_CCOL_SCATTER_STORE_FWD(ZFH,ZLH,ZIDX,PFH,PLH,CADDR,CCS,CVSKIP,CTEMP) \ +" "ST1" "#ZFH"."DT", "#PFH", ["#CADDR", "#ZIDX"."DT", "OFFS"]\n\t" \ +" add "#CTEMP", "#CADDR", "#CVSKIP" \n\t" \ +" "ST1" "#ZLH"."DT", "#PLH", ["#CTEMP", "#ZIDX"."DT", "OFFS"]\n\t" \ +" add "#CADDR", "#CADDR", "#CCS" \n\t" + + diff --git a/kernels/armsve/3/armsve_asm_macros_double.h b/kernels/armsve/3/armsve_asm_macros_double.h new file mode 100644 index 0000000000..f93d3f3821 --- /dev/null +++ b/kernels/armsve/3/armsve_asm_macros_double.h @@ -0,0 +1,46 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +// Specify to use double precision. +#define DT "d" +#define LD1 "ld1d" +#define ST1 "st1d" +#define LD1R "ld1rd" +#define PRFG "prfd" +#define SZ "8" +#define OFFS "lsl #3" +// Include macros. +#include "armsve_asm_macros.h" + diff --git a/kernels/armsve/3/armsve_asm_macros_half.h b/kernels/armsve/3/armsve_asm_macros_half.h new file mode 100644 index 0000000000..9a46763ef2 --- /dev/null +++ b/kernels/armsve/3/armsve_asm_macros_half.h @@ -0,0 +1,46 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +// Specify to use half precision. +#define DT "h" +#define LD1 "ld1h" +#define ST1 "st1h" +#define LD1R "ld1rh" +#define PRFG "prfh" +#define SZ "2" +// #define OFFS UNSUPPORTED +// Include macros. +#include "armsve_asm_macros.h" + diff --git a/kernels/armsve/3/armsve_asm_macros_single.h b/kernels/armsve/3/armsve_asm_macros_single.h new file mode 100644 index 0000000000..2203de3453 --- /dev/null +++ b/kernels/armsve/3/armsve_asm_macros_single.h @@ -0,0 +1,46 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +// Specify to use single precision. +#define DT "s" +#define LD1 "ld1w" +#define ST1 "st1w" +#define LD1R "ld1rw" +#define PRFG "prfw" +#define SZ "4" +#define OFFS "uxtw #2" +// Include macros. +#include "armsve_asm_macros.h" + diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c new file mode 100644 index 0000000000..5824d2d550 --- /dev/null +++ b/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c @@ -0,0 +1,318 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019, Forschunszentrum Juelich + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" + +// Double-precision composite instructions. +#include "armsve_asm_macros_double.h" + +// 2vx10 microkernels. +#include "armsve_asm_2vx10.h" + +void bli_dgemm_armsve_asm_2vx10_unindexed + ( + dim_t k0, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
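+ // The k dimension is split into blocks of four rank-1 updates handled by
+ // the unrolled K_MKER_LOOP below, plus a remainder of up to three updates
+ // handled one at a time in K_LEFT_LOOP (e.g. k0 = 11 gives k_mker = 2 and
+ // k_left = 3).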
+ uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + __asm__ volatile ( +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" mov x2, xzr \n\t" +" incd x2, ALL, MUL #2 \n\t" // Column-skip of A. +" mov x3, #10 \n\t" // Row-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x8, 0x3 \n\t" // Tag C address. +" lsl x8, x8, #56 \n\t" +" orr x5, x5, x8 \n\t" +" mov x8, 0x2 \n\t" // Tag B address. +" lsl x8, x8, #56 \n\t" +" orr x1, x1, x8 \n\t" +" mov x8, 0x1 \n\t" // Tag A address. +" lsl x8, x8, #56 \n\t" +" orr x0, x0, x8 \n\t" +#endif +" \n\t" +" mov x8, #8 \n\t" // Multiply some address skips by sizeof(double). +" madd x2, x8, x2, xzr \n\t" // cs_a +" madd x3, x8, x3, xzr \n\t" // rs_b +" madd x7, x8, x7, xzr \n\t" // cs_c +" ptrue p0.d \n\t" +" \n\t" +" ldr x4, %[k_mker] \n\t" // Number of loops. +" ldr x8, %[k_left] \n\t" +" \n\t" +" LOAD_ABC: \n\t" +" cmp x4, #0 \n\t" // Don't preload if no microkernel there. +" b.eq END_CCOL_PRFM \n\t" + +" ld1rd z20.d, p0/z, [x1] \n\t" // Load 8/10 of first B row. +" ld1rd z21.d, p0/z, [x1, 8] \n\t" +" ld1rd z22.d, p0/z, [x1, 16] \n\t" +" ld1rd z23.d, p0/z, [x1, 24] \n\t" +" ld1rd z24.d, p0/z, [x1, 32] \n\t" +" ld1rd z25.d, p0/z, [x1, 40] \n\t" +" ld1rd z26.d, p0/z, [x1, 48] \n\t" +" ld1rd z27.d, p0/z, [x1, 56] \n\t" +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0) +" \n\t" +" CCOL_PRFM: \n\t" +" cmp x6, #1 \n\t" +" b.ne END_CCOL_PRFM \n\t" // Do not prefetch for generic C storage. +" mov x16, x5 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" END_CCOL_PRFM: \n\t" +" \n\t" +CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19) +" \n\t" +" cmp x4, #0 \n\t" // If no 4-microkernel can be applied +" b.eq K_LEFT_LOOP \n\t" +" \n\t" +" K_MKER_LOOP: \n\t" +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_3(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" subs x4, x4, #1 \n\t" // Decrease counter before final replica. +" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem. +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. 
+GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_4(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" b K_MKER_LOOP \n\t" +" \n\t" +" FIN_MKER_LOOP: \n\t" +GEMM_2VX10_MKER_LOOP_PLAIN_C_4_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" add x0, x0, x2 \n\t" // Forward A to fill the blank. +" \n\t" +" K_LEFT_LOOP: \n\t" +" cmp x8, #0 \n\t" // End of execution. +" b.eq WRITE_MEM_PREP \n\t" +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0) +" ld1rd z20.d, p0/z, [x1] \n\t" // Load 8/10 of first B row. +" ld1rd z21.d, p0/z, [x1, 8] \n\t" +" ld1rd z22.d, p0/z, [x1, 16] \n\t" +" ld1rd z23.d, p0/z, [x1, 24] \n\t" +" ld1rd z24.d, p0/z, [x1, 32] \n\t" +" ld1rd z25.d, p0/z, [x1, 40] \n\t" +" ld1rd z26.d, p0/z, [x1, 48] \n\t" +" ld1rd z27.d, p0/z, [x1, 56] \n\t" +" ld1rd z28.d, p0/z, [x1, 64] \n\t" +" ld1rd z29.d, p0/z, [x1, 72] \n\t" +GEMM_FMLA2(z0,z1,p0,z30,z31,z20) +GEMM_FMLA2(z2,z3,p0,z30,z31,z21) +GEMM_FMLA2(z4,z5,p0,z30,z31,z22) +GEMM_FMLA2(z6,z7,p0,z30,z31,z23) +GEMM_FMLA2(z8,z9,p0,z30,z31,z24) +GEMM_FMLA2(z10,z11,p0,z30,z31,z25) +GEMM_FMLA2(z12,z13,p0,z30,z31,z26) +GEMM_FMLA2(z14,z15,p0,z30,z31,z27) +GEMM_FMLA2(z16,z17,p0,z30,z31,z28) +GEMM_FMLA2(z18,z19,p0,z30,z31,z29) +" add x0, x0, x2 \n\t" // Forward A. +" add x1, x1, x3 \n\t" // Forward B. +" sub x8, x8, #1 \n\t" +" b K_LEFT_LOOP \n\t" // Next column / row. +" \n\t" +" WRITE_MEM_PREP: \n\t" +" \n\t" +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ldr x4, [x4] \n\t" // Load alpha & beta (value). +" ldr x8, [x8] \n\t" +" dup z30.d, x4 \n\t" // Broadcast alpha & beta into vectors. +" dup z31.d, x8 \n\t" +" fmov d28, #1.0 \n\t" // Prepare FP 1.0. +" fmov x16, d28 \n\t" +" \n\t" +" PREFETCH_ABNEXT: \n\t" +" ldr x0, %[a_next] \n\t" +" ldr x1, %[b_next] \n\t" +#ifdef _A64FX +" mov x8, 0x2 \n\t" // Tag B address. +" lsl x8, x8, #56 \n\t" +" orr x1, x1, x8 \n\t" +" mov x8, 0x1 \n\t" // Tag A address. +" lsl x8, x8, #56 \n\t" +" orr x0, x0, x8 \n\t" +#endif +" prfm PLDL1STRM, [x0] \n\t" +" prfm PLDL1STRM, [x0, 256*1] \n\t" +// " prfm PLDL2KEEP, [x0, 256*2] \n\t" +// " prfm PLDL2KEEP, [x0, 256*3] \n\t" +// " prfm PLDL2KEEP, [x0, 256*4] \n\t" +// " prfm PLDL2KEEP, [x0, 256*5] \n\t" +// " prfm PLDL2KEEP, [x0, 256*6] \n\t" +// " prfm PLDL2KEEP, [x0, 256*7] \n\t" +// " prfm PLDL2KEEP, [x0, 256*8] \n\t" +// " prfm PLDL2KEEP, [x0, 256*9] \n\t" +// " prfm PLDL2KEEP, [x0, 256*10] \n\t" +// " prfm PLDL2KEEP, [x0, 256*11] \n\t" +// " prfm PLDL2KEEP, [x0, 256*12] \n\t" +// " prfm PLDL2KEEP, [x0, 256*13] \n\t" +// " prfm PLDL2KEEP, [x0, 256*14] \n\t" +// " prfm PLDL2KEEP, [x0, 256*15] \n\t" +" prfm PLDL1STRM, [x1] \n\t" +" prfm PLDL1STRM, [x1, 256*1] \n\t" +// " prfm PLDL2KEEP, [x1, 256*2] \n\t" +// " prfm PLDL2KEEP, [x1, 256*3] \n\t" +// " prfm PLDL2KEEP, [x1, 256*4] \n\t" +// " prfm PLDL2KEEP, [x1, 256*5] \n\t" +// " prfm PLDL2KEEP, [x1, 256*6] \n\t" +// " prfm PLDL2KEEP, [x1, 256*7] \n\t" +// " prfm PLDL2KEEP, [x1, 256*8] \n\t" +// " prfm PLDL2KEEP, [x1, 256*9] \n\t" +" \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x6, #1 \n\t" // Preload first half of C for contiguous case. 
+" b.ne WRITE_MEM \n\t" +GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x9,x7) +" \n\t" +" WRITE_MEM: \n\t" +" \n\t" +" cmp x16, x4 \n\t" +" b.eq UNIT_ALPHA \n\t" +" \n\t" +SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19,z30) +" \n\t" +" UNIT_ALPHA: \n\t" +" cmp x6, #1 \n\t" +" b.ne WRITE_MEM_G \n\t" +" \n\t" +" WRITE_MEM_C: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-29]. +// First half of C is already loaded in this case. +GEMM_C_FMAD_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31,x9,x7) +" \n\t" +GEMM_C_STORE_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x5,x7) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,x5,x7) +" b END_WRITE_MEM \n\t" +" \n\t" +" WRITE_MEM_G: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-30] - Z30 as index. +" mov x8, xzr \n\t" +" incb x8 \n\t" +" madd x8, x8, x6, xzr \n\t" // C-column's logical 1-vector skip. +" index z30.d, xzr, x6 \n\t" // Skips passed to index is not multiplied by 8. +GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p0,x9,x7,x8,x16) +GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) +GEMM_C_LOAD_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,p0,x9,x7,x8,x16) +" \n\t" +GEMM_C_STORE_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p0,x5,x7,x8,x16) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,p0,x5,x7,x8,x16) +" \n\t" +" END_WRITE_MEM: \n\t" +" b END_EXEC \n\t" +" \n\t" +" END_ERROR: \n\t" +" mov x0, #1 \n\t" // Return error. +" END_EXEC: \n\t" +" mov x0, #0 \n\t" // Return normal. +: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8", + "x9","x16", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); +} + diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c new file mode 100644 index 0000000000..8659e8b7ee --- /dev/null +++ b/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c @@ -0,0 +1,307 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + Copyright (C) 2019, Forschunszentrum Juelich + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" + +// Single-precision composite instructions. +#include "armsve_asm_macros_single.h" + +// 2vx10 microkernels. +#include "armsve_asm_2vx10.h" + +void bli_sgemm_armsve_asm_2vx10_unindexed + ( + dim_t k0, + float* restrict alpha, + float* restrict a, + float* restrict b, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + __asm__ volatile ( +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" mov x2, xzr \n\t" +" incw x2, ALL, MUL #2 \n\t" // Column-skip of A. +" mov x3, #10 \n\t" // Row-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x8, 0x3 \n\t" // Tag C address. +" lsl x8, x8, #56 \n\t" +" orr x5, x5, x8 \n\t" +" mov x8, 0x2 \n\t" // Tag B address. +" lsl x8, x8, #56 \n\t" +" orr x1, x1, x8 \n\t" +" mov x8, 0x1 \n\t" // Tag A address. +" lsl x8, x8, #56 \n\t" +" orr x0, x0, x8 \n\t" +#endif +" \n\t" +" mov x8, #4 \n\t" // Multiply some address skips by sizeof(float). +" madd x2, x8, x2, xzr \n\t" // cs_a +" madd x3, x8, x3, xzr \n\t" // rs_b +" madd x7, x8, x7, xzr \n\t" // cs_c +" ptrue p0.s \n\t" +" \n\t" +" ldr x4, %[k_mker] \n\t" // Number of loops. +" ldr x8, %[k_left] \n\t" +" \n\t" +" LOAD_ABC: \n\t" +" cmp x4, #0 \n\t" // Don't preload if no microkernel there. +" b.eq END_CCOL_PRFM \n\t" + +" ld1rw z20.s, p0/z, [x1] \n\t" // Load 8/10 of first B row. +" ld1rw z21.s, p0/z, [x1, 4] \n\t" +" ld1rw z22.s, p0/z, [x1, 8] \n\t" +" ld1rw z23.s, p0/z, [x1, 12] \n\t" +" ld1rw z24.s, p0/z, [x1, 16] \n\t" +" ld1rw z25.s, p0/z, [x1, 20] \n\t" +" ld1rw z26.s, p0/z, [x1, 24] \n\t" +" ld1rw z27.s, p0/z, [x1, 28] \n\t" +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0) +" \n\t" +" CCOL_PRFM: \n\t" +" cmp x6, #1 \n\t" +" b.ne END_CCOL_PRFM \n\t" // Do not prefetch for generic C storage. 
+" mov x16, x5 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" END_CCOL_PRFM: \n\t" +" \n\t" +CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19) +" \n\t" +" cmp x4, #0 \n\t" // If no 4-microkernel can be applied +" b.eq K_LEFT_LOOP \n\t" +" \n\t" +" K_MKER_LOOP: \n\t" +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_3(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" subs x4, x4, #1 \n\t" // Decrease counter before final replica. +" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem. +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_4(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" b K_MKER_LOOP \n\t" +" \n\t" +" FIN_MKER_LOOP: \n\t" +GEMM_2VX10_MKER_LOOP_PLAIN_C_4_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" add x0, x0, x2 \n\t" // Forward A to fill the blank. +" \n\t" +" K_LEFT_LOOP: \n\t" +" cmp x8, #0 \n\t" // End of execution. +" b.eq WRITE_MEM_PREP \n\t" +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0) +" ld1rw z20.s, p0/z, [x1] \n\t" // Load 8/10 of first B row. +" ld1rw z21.s, p0/z, [x1, 4] \n\t" +" ld1rw z22.s, p0/z, [x1, 8] \n\t" +" ld1rw z23.s, p0/z, [x1, 12] \n\t" +" ld1rw z24.s, p0/z, [x1, 16] \n\t" +" ld1rw z25.s, p0/z, [x1, 20] \n\t" +" ld1rw z26.s, p0/z, [x1, 24] \n\t" +" ld1rw z27.s, p0/z, [x1, 28] \n\t" +" ld1rw z28.s, p0/z, [x1, 32] \n\t" +" ld1rw z29.s, p0/z, [x1, 36] \n\t" +GEMM_FMLA2(z0,z1,p0,z30,z31,z20) +GEMM_FMLA2(z2,z3,p0,z30,z31,z21) +GEMM_FMLA2(z4,z5,p0,z30,z31,z22) +GEMM_FMLA2(z6,z7,p0,z30,z31,z23) +GEMM_FMLA2(z8,z9,p0,z30,z31,z24) +GEMM_FMLA2(z10,z11,p0,z30,z31,z25) +GEMM_FMLA2(z12,z13,p0,z30,z31,z26) +GEMM_FMLA2(z14,z15,p0,z30,z31,z27) +GEMM_FMLA2(z16,z17,p0,z30,z31,z28) +GEMM_FMLA2(z18,z19,p0,z30,z31,z29) +" add x0, x0, x2 \n\t" // Forward A. +" add x1, x1, x3 \n\t" // Forward B. +" sub x8, x8, #1 \n\t" +" b K_LEFT_LOOP \n\t" // Next column / row. +" \n\t" +" WRITE_MEM_PREP: \n\t" +" \n\t" +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ldr w4, [x4] \n\t" // Load alpha & beta (value). 
+" ldr w8, [x8] \n\t" +" dup z30.s, w4 \n\t" // Broadcast alpha & beta into vectors. +" dup z31.s, w8 \n\t" +" \n\t" +" PREFETCH_ABNEXT: \n\t" +" ldr x0, %[a_next] \n\t" +" ldr x1, %[b_next] \n\t" +" prfm PLDL2KEEP, [x0] \n\t" +" prfm PLDL2KEEP, [x0, 256*1] \n\t" +" prfm PLDL2KEEP, [x0, 256*2] \n\t" +" prfm PLDL2KEEP, [x0, 256*3] \n\t" +" prfm PLDL2KEEP, [x0, 256*4] \n\t" +" prfm PLDL2KEEP, [x0, 256*5] \n\t" +" prfm PLDL2KEEP, [x0, 256*6] \n\t" +" prfm PLDL2KEEP, [x0, 256*7] \n\t" +" prfm PLDL2KEEP, [x0, 256*8] \n\t" +" prfm PLDL2KEEP, [x0, 256*9] \n\t" +" prfm PLDL2KEEP, [x0, 256*10] \n\t" +" prfm PLDL2KEEP, [x0, 256*11] \n\t" +" prfm PLDL2KEEP, [x0, 256*12] \n\t" +" prfm PLDL2KEEP, [x0, 256*13] \n\t" +" prfm PLDL2KEEP, [x0, 256*14] \n\t" +" prfm PLDL2KEEP, [x0, 256*15] \n\t" +" prfm PLDL2KEEP, [x1] \n\t" +" prfm PLDL2KEEP, [x1, 256*1] \n\t" +" prfm PLDL2KEEP, [x1, 256*2] \n\t" +" prfm PLDL2KEEP, [x1, 256*3] \n\t" +" prfm PLDL2KEEP, [x1, 256*4] \n\t" +" prfm PLDL2KEEP, [x1, 256*5] \n\t" +" prfm PLDL2KEEP, [x1, 256*6] \n\t" +" prfm PLDL2KEEP, [x1, 256*7] \n\t" +" prfm PLDL2KEEP, [x1, 256*8] \n\t" +" prfm PLDL2KEEP, [x1, 256*9] \n\t" +" \n\t" +" WRITE_MEM: \n\t" +" \n\t" +" fmov s28, #1.0 \n\t" +" fmov w16, s28 \n\t" +" cmp w16, w4 \n\t" +" b.eq UNIT_ALPHA \n\t" +" \n\t" +SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19,z30) +" \n\t" +" UNIT_ALPHA: \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x6, #1 \n\t" +" b.ne WRITE_MEM_G \n\t" +" \n\t" +" WRITE_MEM_C: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-29]. +GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x9,x7) +GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) +GEMM_C_LOAD_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,x9,x7) +" \n\t" +GEMM_C_STORE_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x5,x7) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,x5,x7) +" b END_WRITE_MEM \n\t" +" \n\t" +" WRITE_MEM_G: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-30] - Z30 as index. +" mov x8, xzr \n\t" +" incb x8 \n\t" +" madd x8, x8, x6, xzr \n\t" // C-column's logical 1-vector skip. +" index z30.s, wzr, w6 \n\t" // Skips passed to index is not multiplied by 8. +GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p0,x9,x7,x8,x16) +GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) +GEMM_C_LOAD_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,p0,x9,x7,x8,x16) +" \n\t" +GEMM_C_STORE_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p0,x5,x7,x8,x16) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,p0,x5,x7,x8,x16) +" \n\t" +" END_WRITE_MEM: \n\t" +" b END_EXEC \n\t" +" \n\t" +" END_ERROR: \n\t" +" mov x0, #1 \n\t" // Return error. +" END_EXEC: \n\t" +" mov x0, #0 \n\t" // Return normal. 
+: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8", + "x9","x16", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); +} + diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_sh2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_sh2vx10_unindexed.c new file mode 100644 index 0000000000..817153bfe9 --- /dev/null +++ b/kernels/armsve/3/bli_gemm_armsve_asm_sh2vx10_unindexed.c @@ -0,0 +1,343 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + Copyright (C) 2019, Forschunszentrum Juelich + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" + +// Half-precision composite instructions. +#include "armsve_asm_macros_half.h" + +// 2vx10 microkernels. +#include "armsve_asm_2vx10.h" + +// Gather-load / scatter-store instruction for half-precision +// needs being defined separately. 
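+// SVE gather/scatter addressing is defined only for 32- and 64-bit vector
+// elements, so the half-precision versions below access a C column in two
+// interleaved halves with ld1h/st1h on .s elements (word indices scaled by
+// uxtw #1) and merge them into one .h vector: revh moves each zero-extended
+// halfword into the upper half of its word, and the fadd on .h lanes
+// combines the two loads (the remaining halfwords are zero from the
+// zero-extending ld1h).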
+#undef GEMM_CCOL_GATHER_LOAD_FWD +#undef GEMM_CCOL_SCATTER_STORE_FWD + +#define GEMM_CCOL_GATHER_LOAD_FWD(ZFH,ZLH,ZIDX2,PT,CRS2,CADDR,CCS,CVSKIP,CTEMP) \ +" add x28, "#CADDR", "#CRS2" \n\t" \ +" ld1h z31.s, "#PT"/z, ["#CADDR", "#ZIDX2".s, uxtw #1] \n\t" \ +" ld1h "#ZFH".s, "#PT"/z, [x28, "#ZIDX2".s, uxtw #1] \n\t" \ +" revh "#ZFH".s, "#PT"/m, "#ZFH".s \n\t" \ +" fadd "#ZFH".h, "#ZFH".h, z31.h \n\t" \ +" add "#CTEMP", "#CADDR", "#CVSKIP" \n\t" \ +" add x28, "#CTEMP", "#CRS2" \n\t" \ +" ld1h z31.s, "#PT"/z, ["#CTEMP", "#ZIDX2".s, uxtw #1] \n\t" \ +" ld1h "#ZLH".s, "#PT"/z, [x28, "#ZIDX2".s, uxtw #1] \n\t" \ +" revh "#ZLH".s, "#PT"/m, "#ZLH".s \n\t" \ +" fadd "#ZLH".h, "#ZLH".h, z31.h \n\t" \ +" add "#CADDR", "#CADDR", "#CCS" \n\t" + +#define GEMM_CCOL_SCATTER_STORE_FWD(ZFH,ZLH,ZIDX2,PT,CRS2,CADDR,CCS,CVSKIP,CTEMP) \ +" add x28, "#CADDR", "#CRS2" \n\t" \ +" st1h "#ZFH".s, "#PT", ["#CADDR", "#ZIDX2".s, uxtw #1] \n\t" \ +" revh "#ZFH".s, "#PT"/m, "#ZFH".s \n\t" \ +" st1h "#ZFH".s, "#PT", [x28, "#ZIDX2".s, uxtw #1] \n\t" \ +" add "#CTEMP", "#CADDR", "#CVSKIP" \n\t" \ +" add x28, "#CTEMP", "#CRS2" \n\t" \ +" st1h "#ZLH".s, "#PT", ["#CTEMP", "#ZIDX2".s, uxtw #1] \n\t" \ +" revh "#ZLH".s, "#PT"/m, "#ZLH".s \n\t" \ +" st1h "#ZLH".s, "#PT", [x28, "#ZIDX2".s, uxtw #1] \n\t" \ +" add "#CADDR", "#CADDR", "#CCS" \n\t" + + +void bli_shgemm_armsve_asm_2vx10_unindexed + ( + dim_t k0, + void* restrict alpha, + void* restrict a, + void* restrict b, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + __asm__ volatile ( +" ldr x0, %[a] \n\t" +" ldr x1, %[b] \n\t" +" mov x2, xzr \n\t" +" inch x2, ALL, MUL #2 \n\t" // Column-skip of A. +" mov x3, #10 \n\t" // Row-skip of B. +" \n\t" +" ldr x5, %[c] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x8, 0x3 \n\t" // Tag C address. +" lsl x8, x8, #56 \n\t" +" orr x5, x5, x8 \n\t" +" mov x8, 0x2 \n\t" // Tag B address. +" lsl x8, x8, #56 \n\t" +" orr x1, x1, x8 \n\t" +" mov x8, 0x1 \n\t" // Tag A address. +" lsl x8, x8, #56 \n\t" +" orr x0, x0, x8 \n\t" +#endif +" \n\t" +" mov x8, #2 \n\t" // Multiply some address skips by sizeof(float16_t). +" madd x2, x8, x2, xzr \n\t" // cs_a +" madd x3, x8, x3, xzr \n\t" // rs_b +" madd x7, x8, x7, xzr \n\t" // cs_c +" ptrue p0.b \n\t" +" \n\t" +" ldr x4, %[k_mker] \n\t" // Number of loops. +" ldr x8, %[k_left] \n\t" +" \n\t" +" LOAD_ABC: \n\t" +" cmp x4, #0 \n\t" // Don't preload if no microkernel there. +" b.eq END_CCOL_PRFM \n\t" + +" ld1rh z20.h, p0/z, [x1] \n\t" // Load 8/10 of first B row. +" ld1rh z21.h, p0/z, [x1, 2] \n\t" +" ld1rh z22.h, p0/z, [x1, 4] \n\t" +" ld1rh z23.h, p0/z, [x1, 6] \n\t" +" ld1rh z24.h, p0/z, [x1, 8] \n\t" +" ld1rh z25.h, p0/z, [x1, 10] \n\t" +" ld1rh z26.h, p0/z, [x1, 12] \n\t" +" ld1rh z27.h, p0/z, [x1, 14] \n\t" +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0) +" \n\t" +" CCOL_PRFM: \n\t" +" cmp x6, #1 \n\t" +" b.ne END_CCOL_PRFM \n\t" // Do not prefetch for generic C storage. 
+" mov x16, x5 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" END_CCOL_PRFM: \n\t" +" \n\t" +CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19) +" \n\t" +" cmp x4, #0 \n\t" // If no 4-microkernel can be applied +" b.eq K_LEFT_LOOP \n\t" +" \n\t" +" K_MKER_LOOP: \n\t" +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_3(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" \n\t" +" subs x4, x4, #1 \n\t" // Decrease counter before final replica. +" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem. +" \n\t" +" add x0, x0, x2 \n\t" // Forward A's address to the next column. +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p0,x0) +GEMM_2VX10_MKER_LOOP_PLAIN_C_4(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" b K_MKER_LOOP \n\t" +" \n\t" +" FIN_MKER_LOOP: \n\t" +GEMM_2VX10_MKER_LOOP_PLAIN_C_4_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x1,x3) +" add x0, x0, x2 \n\t" // Forward A to fill the blank. +" \n\t" +" K_LEFT_LOOP: \n\t" +" cmp x8, #0 \n\t" // End of execution. +" b.eq WRITE_MEM_PREP \n\t" +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p0,p0,x0) +" ld1rh z20.h, p0/z, [x1] \n\t" // Load 8/10 of first B row. +" ld1rh z21.h, p0/z, [x1, 2] \n\t" +" ld1rh z22.h, p0/z, [x1, 4] \n\t" +" ld1rh z23.h, p0/z, [x1, 6] \n\t" +" ld1rh z24.h, p0/z, [x1, 8] \n\t" +" ld1rh z25.h, p0/z, [x1, 10] \n\t" +" ld1rh z26.h, p0/z, [x1, 12] \n\t" +" ld1rh z27.h, p0/z, [x1, 14] \n\t" +" ld1rh z28.h, p0/z, [x1, 16] \n\t" +" ld1rh z29.h, p0/z, [x1, 18] \n\t" +GEMM_FMLA2(z0,z1,p0,z30,z31,z20) +GEMM_FMLA2(z2,z3,p0,z30,z31,z21) +GEMM_FMLA2(z4,z5,p0,z30,z31,z22) +GEMM_FMLA2(z6,z7,p0,z30,z31,z23) +GEMM_FMLA2(z8,z9,p0,z30,z31,z24) +GEMM_FMLA2(z10,z11,p0,z30,z31,z25) +GEMM_FMLA2(z12,z13,p0,z30,z31,z26) +GEMM_FMLA2(z14,z15,p0,z30,z31,z27) +GEMM_FMLA2(z16,z17,p0,z30,z31,z28) +GEMM_FMLA2(z18,z19,p0,z30,z31,z29) +" add x0, x0, x2 \n\t" // Forward A. +" add x1, x1, x3 \n\t" // Forward B. +" sub x8, x8, #1 \n\t" +" b K_LEFT_LOOP \n\t" // Next column / row. +" \n\t" +" WRITE_MEM_PREP: \n\t" +" \n\t" +" ldr x4, %[alpha] \n\t" // Load alpha & beta (address). +" ldr x8, %[beta] \n\t" +" ld1rh z30.h, p0/z, [x4] \n\t" // Load alpha & beta into vectors. 
+" ld1rh z31.h, p0/z, [x8] \n\t" +" fmov w4, h28 \n\t" // Copy alpha & beta to GP registers. +" fmov w8, h29 \n\t" +" \n\t" +" PREFETCH_ABNEXT: \n\t" +" ldr x0, %[a_next] \n\t" +" ldr x1, %[b_next] \n\t" +" prfm PLDL2KEEP, [x0] \n\t" +" prfm PLDL2KEEP, [x0, 256*1] \n\t" +" prfm PLDL2KEEP, [x0, 256*2] \n\t" +" prfm PLDL2KEEP, [x0, 256*3] \n\t" +" prfm PLDL2KEEP, [x0, 256*4] \n\t" +" prfm PLDL2KEEP, [x0, 256*5] \n\t" +" prfm PLDL2KEEP, [x0, 256*6] \n\t" +" prfm PLDL2KEEP, [x0, 256*7] \n\t" +" prfm PLDL2KEEP, [x0, 256*8] \n\t" +" prfm PLDL2KEEP, [x0, 256*9] \n\t" +" prfm PLDL2KEEP, [x0, 256*10] \n\t" +" prfm PLDL2KEEP, [x0, 256*11] \n\t" +" prfm PLDL2KEEP, [x0, 256*12] \n\t" +" prfm PLDL2KEEP, [x0, 256*13] \n\t" +" prfm PLDL2KEEP, [x0, 256*14] \n\t" +" prfm PLDL2KEEP, [x0, 256*15] \n\t" +" prfm PLDL2KEEP, [x1] \n\t" +" prfm PLDL2KEEP, [x1, 256*1] \n\t" +" prfm PLDL2KEEP, [x1, 256*2] \n\t" +" prfm PLDL2KEEP, [x1, 256*3] \n\t" +" prfm PLDL2KEEP, [x1, 256*4] \n\t" +" prfm PLDL2KEEP, [x1, 256*5] \n\t" +" prfm PLDL2KEEP, [x1, 256*6] \n\t" +" prfm PLDL2KEEP, [x1, 256*7] \n\t" +" prfm PLDL2KEEP, [x1, 256*8] \n\t" +" prfm PLDL2KEEP, [x1, 256*9] \n\t" +" \n\t" +" WRITE_MEM: \n\t" +" \n\t" +" fmov h28, #1.0 \n\t" +" fmov w16, h28 \n\t" +" cmp w16, w4 \n\t" +" b.eq UNIT_ALPHA \n\t" +" \n\t" +SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19,z30) +" \n\t" +" UNIT_ALPHA: \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x6, #1 \n\t" +" b.ne WRITE_MEM_G \n\t" +" \n\t" +" WRITE_MEM_C: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-29]. +GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x9,x7) +GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) +GEMM_C_LOAD_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,x9,x7) +" \n\t" +GEMM_C_STORE_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,x5,x7) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,x5,x7) +" b END_WRITE_MEM \n\t" +" \n\t" +" WRITE_MEM_G: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-30] - Z30 as index. +" mov x10, xzr \n\t" +" incb x10 \n\t" +" madd x10, x10, x6, xzr \n\t" // C-column's logical 1-vector skip. +" mov x28, #2 \n\t" +" madd x6, x28, x6, xzr \n\t" // Double index skip for half-precision case. +" index z30.s, wzr, w6 \n\t" // Skips passed to index is not multiplied by 8. +GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,x6,x9,x7,x10,x16) +" dup z31.h, w8 \n\t" // Restore beta destroyed by loading. +GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p0,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) +GEMM_C_LOAD_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,x6,x9,x7,x10,x16) +" \n\t" +" dup z31.h, w8 \n\t" // Restore beta destroyed by loading. +GEMM_C_STORE_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,x6,x5,x7,x10,x16) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p0,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,x6,x5,x7,x10,x16) +" \n\t" +" END_WRITE_MEM: \n\t" +" b END_EXEC \n\t" +" \n\t" +" END_ERROR: \n\t" +" mov x0, #1 \n\t" // Return error. +" END_EXEC: \n\t" +" mov x0, #0 \n\t" // Return normal. 
+: +: [a] "m" (a), + [b] "m" (b), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8", + "x9","x16","x10","x28", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); +} + diff --git a/kernels/armsve/3/sup/bli_gemmsup_armsve_ref.c b/kernels/armsve/3/sup/bli_gemmsup_armsve_ref.c new file mode 100644 index 0000000000..7d3d65b1f8 --- /dev/null +++ b/kernels/armsve/3/sup/bli_gemmsup_armsve_ref.c @@ -0,0 +1,450 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// Separate instantiation for ArmSVE reference kernels. +// Temporary workaround. Will be removed after upstream has switched to a better way +// of exposing gemmsup interface. + +// +// -- Row storage case --------------------------------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* NOTE: This microkernel can actually handle arbitrarily large + values of m, n, and k. */ \ +\ + if ( bli_is_noconj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by rows. 
*/ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_noconj( conja ) && bli_is_conj( conjb ) ) \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,axpyjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_conj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by rows. */ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dotjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else /* if ( bli_is_conj( conja ) && bli_is_conj( conjb ) ) */ \ + { \ + /* Traverse c by rows. 
*/ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict ci = &c[ i*rs_c ]; \ + ctype* restrict ai = &a[ i*rs_a ]; \ +\ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cij = &ci[ j*cs_c ]; \ + ctype* restrict bj = &b [ j*cs_b ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* Conjugate the result to simulate conj(a^T) * conj(b). */ \ + PASTEMAC(ch,conjs)( ab ); \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC2( gemmsup_r, _armsve, _ref2 ) + +// +// -- Column storage case ------------------------------------------------------ +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, arch, suf ) \ +\ +void PASTEMAC3(ch,opname,arch,suf) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict data, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* NOTE: This microkernel can actually handle arbitrarily large + values of m, n, and k. */ \ +\ + if ( bli_is_noconj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_noconj( conja ) && bli_is_conj( conjb ) ) \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,axpyjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. 
If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else if ( bli_is_conj( conja ) && bli_is_noconj( conjb ) ) \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dotjs)( *aij, *bij, ab ); \ + } \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ + else /* if ( bli_is_conj( conja ) && bli_is_conj( conjb ) ) */ \ + { \ + /* Traverse c by columns. */ \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + ctype* restrict cj = &c[ j*cs_c ]; \ + ctype* restrict bj = &b[ j*cs_b ]; \ +\ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + ctype* restrict cij = &cj[ i*rs_c ]; \ + ctype* restrict ai = &a [ i*rs_a ]; \ + ctype ab; \ +\ + PASTEMAC(ch,set0s)( ab ); \ +\ + /* Perform a dot product to update the (i,j) element of c. */ \ + for ( dim_t l = 0; l < k; ++l ) \ + { \ + ctype* restrict aij = &ai[ l*cs_a ]; \ + ctype* restrict bij = &bj[ l*rs_b ]; \ +\ + PASTEMAC(ch,dots)( *aij, *bij, ab ); \ + } \ +\ + /* Conjugate the result to simulate conj(a^T) * conj(b). */ \ + PASTEMAC(ch,conjs)( ab ); \ +\ + /* If beta is one, add ab into c. If beta is zero, overwrite c + with the result in ab. Otherwise, scale by beta and accumulate + ab to c. */ \ + if ( PASTEMAC(ch,eq1)( *beta ) ) \ + { \ + PASTEMAC(ch,axpys)( *alpha, ab, *cij ); \ + } \ + else if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + PASTEMAC(ch,scal2s)( *alpha, ab, *cij ); \ + } \ + else \ + { \ + PASTEMAC(ch,axpbys)( *alpha, ab, *beta, *cij ); \ + } \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC2( gemmsup_c, _armsve, _ref2 ) + diff --git a/kernels/armsve/3/sup/bli_gemmsup_cv_armsve_asm_d2vx10_unindexed.c b/kernels/armsve/3/sup/bli_gemmsup_cv_armsve_asm_d2vx10_unindexed.c new file mode 100644 index 0000000000..3341b63d00 --- /dev/null +++ b/kernels/armsve/3/sup/bli_gemmsup_cv_armsve_asm_d2vx10_unindexed.c @@ -0,0 +1,528 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ +#include "blis.h" +#include + +// Double-precision composite instructions. +#include "../armsve_asm_macros_double.h" + +// 2vx10 microkernels. +#include "../armsve_asm_2vx10.h" + +// Prototype reference kernel. +GEMMSUP_KER_PROT( double, d, gemmsup_c_armsve_ref2 ) + +void __attribute__ ((noinline,optimize(0))) bli_dgemmsup_cv_armsve_2vx10_unindexed + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + static int called = 0; + if ( !called ) + { + fprintf(stderr, "rv called.\n"); + called = 1; + } + // c*c requires A to be stored in columns. + assert( rs_a0 == 1 ); + + dim_t n0_mker = n0 / 10; + dim_t n0_left = n0 % 10; + + if ( n0_left ) + { + // A[:, ::] + // B[::, n0_mker*10:n0] + // C[: , n0_mker*10:n0] + double *ai = a; + double *bi = b + n0_mker * 10 * cs_b0; + double *ci = c + n0_mker * 10 * cs_c0; + bli_dgemmsup_c_armsve_ref2 + ( + conja, conjb, + m0, n0_left, k0, + alpha, + ai, rs_a0, cs_a0, + bi, rs_b0, cs_b0, + beta, + ci, rs_c0, cs_c0, + data, + cntx + ); + } + // Return if it's a pure edge case. + if ( !n0_mker ) + return; + + // Determine VL. + uint64_t vlen2; + __asm__ ( + " mov x0, xzr \n\t" + " incd x0, ALL, MUL #2 \n\t" + " mov %[vlen2], x0 \n\t" + : [vlen2] "=r" (vlen2) + : + : "x0" + ); + + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + // uint64_t rs_a = 1; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t n_mker = n0_mker; + + dim_t m0_mker = m0 / vlen2; + dim_t m0_left = m0 % vlen2; + if ( m0_left ) + { + // Edge case on A side can be handled with one more (predicated) loop. + m0_mker++; + } else + m0_left = vlen2; + // uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_b = bli_auxinfo_ps_b( data ); + + for ( dim_t im0_mker = 0; im0_mker < m0_mker; ++im0_mker ) + { + uint64_t m_curr = vlen2; + if ( im0_mker == m0_mker - 1 ) + { + // Last m-loop. Maybe unnecessary. 
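+      // The trailing panel holds only m0_left (< vlen2) rows; the asm below
+      // builds the predicates p1/p2 from m_curr via whilelo, so out-of-range
+      // rows are simply masked off instead of needing a separate m-edge kernel.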
+ m_curr = m0_left; + } + double *ai = a + im0_mker * vlen2 * rs_a0; + double *bi = b; + double *ci = c + im0_mker * vlen2 * rs_c0; + + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + __asm__ volatile ( +" ldr x0, %[bi] \n\t" +" ldr x1, %[rs_b] \n\t" // Row-skip of B. +" ldr x2, %[cs_b] \n\t" // Column-skip of B (element skip of B[l, :]). +" ldr x3, %[ps_b] \n\t" // Panel-skip (10*k) of B. +" ldr x4, %[cs_a] \n\t" // Column-Skip of A. +" \n\t" // Element skip of A[:, l] is guaranteed to be 1. +" ldr x5, %[ci] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag C address. +" lsl x16, x16, #56 \n\t" +" orr x5, x5, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr x0, x0, x16 \n\t" +#endif +" \n\t" +" mov x8, #8 \n\t" // Multiply some address skips by sizeof(double). +" madd x1, x8, x1, xzr \n\t" // rs_b +" madd x2, x8, x2, xzr \n\t" // cs_b +" madd x3, x8, x3, xzr \n\t" // ps_b +" madd x4, x8, x4, xzr \n\t" // cs_a +" madd x7, x8, x7, xzr \n\t" // cs_c +" mov x8, #4 \n\t" +" madd x15, x8, x4, xzr \n\t" // Logical K=4 microkernel skip for A. +" \n\t" +#ifdef _A64FX +" mov x16, 0x20 \n\t" // Higher 6bit for Control#2: +" lsl x16, x16, #58 \n\t" // Valid|Strong|Strong|NoAlloc|Load|Strong +" orr x16, x16, x4 \n\t" // Stride. +" msr S3_3_C11_C6_2, x16 \n\t" // Write system register. +#endif +" \n\t" +" ldr x8, %[m_curr] \n\t" // Size of first dimension. +" mov x9, xzr \n\t" +" incd x9 \n\t" +" ptrue p0.d \n\t" +" whilelo p1.d, xzr, x8 \n\t" +" whilelo p2.d, x9, x8 \n\t" +" \n\t" +" ldr x8, %[n_mker] \n\t" // Number of N-loops. +" \n\t" +" ldr x20, %[ai] \n\t" // Parameters to be reloaded +" ldr x21, %[k_mker] \n\t" // within each millikernel loop. +" ldr x22, %[k_left] \n\t" +" ldr x23, %[alpha] \n\t" +" ldr x24, %[beta] \n\t" +" ldr x25, %[a_next] \n\t" +" ldr x26, %[b_next] \n\t" +" ldr x23, [x23] \n\t" // Directly load alpha and beta. +" ldr x24, [x24] \n\t" +" \n\t" +" MILLIKER_MLOOP: \n\t" +" \n\t" +" mov x11, x0 \n\t" // B's address. +// " ldr x10, %[ai] \n\t" // A's address. +" mov x10, x20 \n\t" +// " ldr x12, %[k_mker] \n\t" +" mov x12, x21 \n\t" +// " ldr x13, %[k_left] \n\t" +" mov x13, x22 \n\t" +#ifdef _A64FX +" mov x16, 0x3 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr x10, x10, x16 \n\t" +" mov x16, 0xa \n\t" // Control#2 for A address. +" lsl x16, x16, #60 \n\t" +" orr x10, x10, x16 \n\t" +#endif +" \n\t" +" cmp x12, #0 \n\t" // Don't preload if no microkernel there. +" b.eq END_CCOL_PRFM \n\t" +" \n\t" +" mov x14, x11 \n\t" +" ld1rd z20.d, p0/z, [x14] \n\t" // Load 8/10 of first B row. +" add x14, x14, x2 \n\t" +" ld1rd z21.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z22.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z23.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z24.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z25.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z26.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z27.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" prfm PLDL1KEEP, [x14] \n\t" // And prefetch the 2/10 left. +" add x14, x14, x2 \n\t" +" prfm PLDL1KEEP, [x14] \n\t" +" sub x14, x14, x2 \n\t" // Restore x14 to load edge. +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p1,p2,x10) +" add x16, x10, x4 \n\t" +" prfm PLDL1STRM, [x16] \n\t" // Prefetch 3/4 of A. 
+" add x16, x10, x4 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x10, x4 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" \n\t" +" CCOL_PRFM: \n\t" +" cmp x6, #1 \n\t" +" b.ne END_CCOL_PRFM \n\t" // Do not prefetch for generic C storage. +" mov x16, x5 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" add x16, x16, x7 \n\t" +" prfm PLDL1STRM, [x16] \n\t" +" END_CCOL_PRFM: \n\t" +" \n\t" +CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19) +" \n\t" +" cmp x12, #0 \n\t" // If no 4-microkernel can be applied +" b.eq K_LEFT_LOOP \n\t" +" \n\t" +" K_MKER_LOOP: \n\t" +" \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_C(z30,z31,p1,p2,x10,x15,x4,x16,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_G_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x11,x14,x1,x2) +" \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_C(z28,z29,p1,p2,x10,x15,x4,x16,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_G_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x11,x14,x1,x2) +" \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_C(z30,z31,p1,p2,x10,x15,x4,x16,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_G_3(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x11,x14,x1,x2) +" \n\t" +" subs x12, x12, #1 \n\t" // Decrease counter before final replica. +" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem. +" \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_C(z28,z29,p1,p2,x10,x15,x4,x16,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_G_4(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x11,x14,x1,x2) +" b K_MKER_LOOP \n\t" +" \n\t" +" FIN_MKER_LOOP: \n\t" +GEMM_2VX10_MKER_LOOP_PLAIN_G_4_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x11,x14,x1,x2) +" add x10, x10, x4 \n\t" // Forward A to fill the blank. +" \n\t" +" K_LEFT_LOOP: \n\t" +" cmp x13, #0 \n\t" // End of execution. +" b.eq WRITE_MEM_PREP \n\t" +" \n\t" +GEMM_ACOL_CONTIGUOUS_LOAD(z30,z31,p1,p2,x10) +" mov x14, x11 \n\t" +" ld1rd z20.d, p0/z, [x14] \n\t" // Load 10/10 B. 
+" add x14, x14, x2 \n\t" +" ld1rd z21.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z22.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z23.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z24.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z25.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z26.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z27.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z28.d, p0/z, [x14] \n\t" +" add x14, x14, x2 \n\t" +" ld1rd z29.d, p0/z, [x14] \n\t" +GEMM_FMLA2(z0,z1,p0,z30,z31,z20) +GEMM_FMLA2(z2,z3,p0,z30,z31,z21) +GEMM_FMLA2(z4,z5,p0,z30,z31,z22) +GEMM_FMLA2(z6,z7,p0,z30,z31,z23) +GEMM_FMLA2(z8,z9,p0,z30,z31,z24) +GEMM_FMLA2(z10,z11,p0,z30,z31,z25) +GEMM_FMLA2(z12,z13,p0,z30,z31,z26) +GEMM_FMLA2(z14,z15,p0,z30,z31,z27) +GEMM_FMLA2(z16,z17,p0,z30,z31,z28) +GEMM_FMLA2(z18,z19,p0,z30,z31,z29) +" add x10, x10, x4 \n\t" // Forward A. +" add x11, x11, x1 \n\t" // Forward B. +" sub x13, x13, #1 \n\t" +" b K_LEFT_LOOP \n\t" // Next column / row. +" \n\t" +" WRITE_MEM_PREP: \n\t" +" \n\t" +// " ldr x10, %[ai] \n\t" +" mov x10, x20 \n\t" +" add x11, x0, x3 \n\t" +" dup z30.d, x23 \n\t" // Broadcast alpha & beta into vectors. +" dup z31.d, x24 \n\t" +" \n\t" +" cmp x8, #1 \n\t" +" b.eq PREFETCH_ABNEXT \n\t" +" prfm PLDL1STRM, [x10] \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" add x11, x11, x2 \n\t" +" prfm PLDL1KEEP, [x11] \n\t" +" b WRITE_MEM \n\t" +" \n\t" +" PREFETCH_ABNEXT: \n\t" +// " ldr x1, %[a_next] \n\t" // Final Millikernel loop, x1 and x2 not needed. +" mov x1, x25 \n\t" +// " ldr x2, %[b_next] \n\t" +" mov x2, x26 \n\t" +" prfm PLDL2KEEP, [x1] \n\t" +" prfm PLDL2KEEP, [x1, 256*1] \n\t" +" prfm PLDL2KEEP, [x1, 256*2] \n\t" +" prfm PLDL2KEEP, [x1, 256*3] \n\t" +" prfm PLDL2KEEP, [x1, 256*4] \n\t" +" prfm PLDL2KEEP, [x1, 256*5] \n\t" +" prfm PLDL2KEEP, [x1, 256*6] \n\t" +" prfm PLDL2KEEP, [x1, 256*7] \n\t" +" prfm PLDL2KEEP, [x1, 256*8] \n\t" +" prfm PLDL2KEEP, [x1, 256*9] \n\t" +" prfm PLDL2KEEP, [x1, 256*10] \n\t" +" prfm PLDL2KEEP, [x1, 256*11] \n\t" +" prfm PLDL2KEEP, [x1, 256*12] \n\t" +" prfm PLDL2KEEP, [x1, 256*13] \n\t" +" prfm PLDL2KEEP, [x1, 256*14] \n\t" +" prfm PLDL2KEEP, [x1, 256*15] \n\t" +" prfm PLDL2KEEP, [x2] \n\t" +" prfm PLDL2KEEP, [x2, 256*1] \n\t" +" prfm PLDL2KEEP, [x2, 256*2] \n\t" +" prfm PLDL2KEEP, [x2, 256*3] \n\t" +" prfm PLDL2KEEP, [x2, 256*4] \n\t" +" prfm PLDL2KEEP, [x2, 256*5] \n\t" +" prfm PLDL2KEEP, [x2, 256*6] \n\t" +" prfm PLDL2KEEP, [x2, 256*7] \n\t" +" prfm PLDL2KEEP, [x2, 256*8] \n\t" +" prfm PLDL2KEEP, [x2, 256*9] \n\t" +" \n\t" +" WRITE_MEM: \n\t" +" \n\t" +" fmov d28, #1.0 \n\t" +" fmov x16, d28 \n\t" +" cmp x16, x23 \n\t" +" b.eq UNIT_ALPHA \n\t" +" \n\t" +SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19,z30) +" \n\t" +" UNIT_ALPHA: \n\t" +" mov x9, x5 \n\t" // C address for loading. +" \n\t" // C address for storing is x5 itself. +" cmp x6, #1 \n\t" +" b.ne WRITE_MEM_G \n\t" +" \n\t" +" WRITE_MEM_C: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-29]. 
+" mov x13, xzr \n\t" // C-column's physical 1-vector skip. +" incb x13 \n\t" +GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,x9,x7) +GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) +GEMM_C_LOAD_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,x9,x7) +" \n\t" +GEMM_C_STORE_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,x5,x7) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,x5,x7) +" b END_WRITE_MEM \n\t" +" \n\t" +" WRITE_MEM_G: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-30] - Z30 as index. +" mov x12, xzr \n\t" +" incb x12 \n\t" +" madd x13, x12, x6, xzr \n\t" // C-column's logical 1-vector skip. +" index z30.d, xzr, x6 \n\t" // Skips passed to index is not multiplied by 8. +GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p1,p2,x9,x7,x13,x16) +GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) +GEMM_C_LOAD_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p1,p2,x9,x7,x13,x16) +" \n\t" +GEMM_C_STORE_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p1,p2,x5,x7,x13,x16) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p1,p2,x5,x7,x13,x16) +" \n\t" +" END_WRITE_MEM: \n\t" +" subs x8, x8, #1 \n\t" +" b.eq END_EXEC \n\t" +" \n\t" // Address of C already forwarded to next column. +" add x0, x0, x3 \n\t" // Forward B's base address to the next logic panel. +" b MILLIKER_MLOOP \n\t" +" \n\t" +" END_ERROR: \n\t" +" mov x0, #1 \n\t" // Return error. +" END_EXEC: \n\t" +" mov x0, #0 \n\t" // Return normal. 
+: +: [bi] "m" (bi), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [ps_b] "m" (ps_b), + [cs_a] "m" (cs_a), + [ci] "m" (ci), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [m_curr] "m" (m_curr), + [n_mker] "m" (n_mker), + [ai] "m" (ai), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8", + "x9","x10","x11","x12","x13","x14","x15","x16","x17", + "x20","x21","x22","x23","x24","x25","x26", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); + } +} + +void bli_dgemmsup_rv_armsve_10x2v_unindexed + ( + conj_t conjat, + conj_t conjbt, + dim_t m0t, + dim_t n0t, + dim_t k0, + double* restrict alpha, + double* restrict at, inc_t rs_at0, inc_t cs_at0, + double* restrict bt, inc_t rs_bt0, inc_t cs_bt0, + double* restrict beta, + double* restrict ct, inc_t rs_ct0, inc_t cs_ct0, + auxinfo_t* restrict datat, + cntx_t* restrict cntx + ) +{ + auxinfo_t data; + bli_auxinfo_set_next_a( bli_auxinfo_next_b( datat ), &data ); + bli_auxinfo_set_next_b( bli_auxinfo_next_a( datat ), &data ); + bli_auxinfo_set_ps_a( bli_auxinfo_ps_b( datat ), &data ); + bli_auxinfo_set_ps_b( bli_auxinfo_ps_a( datat ), &data ); + bli_dgemmsup_cv_armsve_2vx10_unindexed + ( + conjbt, conjat, + n0t, m0t, k0, + alpha, + bt, cs_bt0, rs_bt0, + at, cs_at0, rs_at0, + beta, + ct, cs_ct0, rs_ct0, + &data, + cntx + ); +} + diff --git a/kernels/armsve/3/sup/bli_gemmsup_rv_armsve_asm_d2vx10_unindexed.c b/kernels/armsve/3/sup/bli_gemmsup_rv_armsve_asm_d2vx10_unindexed.c new file mode 100644 index 0000000000..6bcea73f5d --- /dev/null +++ b/kernels/armsve/3/sup/bli_gemmsup_rv_armsve_asm_d2vx10_unindexed.c @@ -0,0 +1,412 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +*/ +#include "blis.h" +#include + +// Double-precision composite instructions. +#include "../armsve_asm_macros_double.h" + +// 2vx10 microkernels. +#include "../armsve_asm_2vx10.h" + +// Prototype reference kernel. +GEMMSUP_KER_PROT( double, d, gemmsup_r_armsve_ref2 ) + +void __attribute__ ((optimize(0))) bli_dgemmsup_rv_armsve_2vx10_unindexed + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + static int called = 0; + if ( !called ) + { + fprintf(stderr, "rv called.\n"); + called = 1; + } + // r*r requires B to be stored in rows. + assert(cs_b0 == 1); + + dim_t n0_mker = n0 / 10; + dim_t n0_left = n0 % 10; + + if ( n0_left ) + { + // A[:, ::] + // B[::, n0_mker*10:n0] + // C[: , n0_mker*10:n0] + double *ai = a; + double *bi = b + n0_mker * 10 * cs_b0; + double *ci = c + n0_mker * 10 * cs_c0; + bli_dgemmsup_r_armsve_ref2 + ( + conja, conjb, + m0, n0_left, k0, + alpha, + ai, rs_a0, cs_a0, + bi, rs_b0, cs_b0, + beta, + ci, rs_c0, cs_c0, + data, + cntx + ); + } + // Return if it's a pure edge case. + if ( !n0_mker ) + return; + + // Determine VL. + uint64_t vlen2; + __asm__ ( + " mov x0, xzr \n\t" + " incd x0, ALL, MUL #2 \n\t" + " mov %[vlen2], x0 \n\t" + : [vlen2] "=r" (vlen2) + : + : "x0" + ); + + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + // uint64_t cs_b = 1; + + uint64_t k_mker = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t m_mker = m0 / vlen2; + uint64_t m_left = m0 % vlen2; + if ( m_left ) + { + // Edge case on A side can be handled with one more (predicated) loop. + m_mker++; + } else + m_left = vlen2; + uint64_t ps_a = bli_auxinfo_ps_a( data ); + // uint64_t ps_b = bli_auxinfo_ps_b( data ); + + for ( dim_t in0_mker = 0; in0_mker < n0_mker; ++in0_mker ) + { + double *ai = a; + double *bi = b + in0_mker * 10 * cs_b0; + double *ci = c + in0_mker * 10 * cs_c0; + + void* a_next = bli_auxinfo_next_a( data ); + void* b_next = bli_auxinfo_next_b( data ); + + __asm__ volatile ( +" ldr x0, %[ai] \n\t" +" ldr x1, %[rs_a] \n\t" // Row-skip of A (element skip of A[:, l]). +" ldr x2, %[cs_a] \n\t" // Column-skip of A. +" ldr x3, %[ps_a] \n\t" // Panel-skip (vlen2*k) of A. +" ldr x4, %[rs_b] \n\t" // Row-Skip of B. +" \n\t" // Element skip of B[l, :] is guaranteed to be 1. +" ldr x5, %[ci] \n\t" +" ldr x6, %[rs_c] \n\t" // Row-skip of C. +" ldr x7, %[cs_c] \n\t" // Column-skip of C. +#ifdef _A64FX +" mov x16, 0x1 \n\t" // Tag C address. +" lsl x16, x16, #56 \n\t" +" orr x5, x5, x16 \n\t" +" mov x16, 0x2 \n\t" // Tag A address. +" lsl x16, x16, #56 \n\t" +" orr x0, x0, x16 \n\t" +#endif +" \n\t" +" mov x8, #8 \n\t" // Multiply some address skips by sizeof(double). +" madd x2, x8, x2, xzr \n\t" // cs_a +" madd x3, x8, x3, xzr \n\t" // ps_a +" madd x4, x8, x4, xzr \n\t" // rs_b +" madd x7, x8, x7, xzr \n\t" // cs_c +" mov x8, xzr \n\t" +" incb x8 \n\t" +" madd x14, x8, x1, xzr \n\t" // A-column's logical 1-vector skip. +" mov x8, #4 \n\t" +" madd x15, x8, x2, xzr \n\t" // Logical K=4 microkernel skip for A. +// " mov x8, #4 \n\t" +// " madd x17, x8, x4, xzr \n\t" // Logical K=4 microkernel skip for B. +" \n\t" +" ldr x8, %[m_mker] \n\t" // Number of M-loops. 
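+" \n\t" // p1/p2 start all-true and are narrowed with whilelo on the final
+" \n\t" // (possibly partial) m-panel, so the same 2vx10 body also covers the
+" \n\t" // m_left remainder without a separate edge kernel.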
+" ptrue p0.d \n\t" +" ptrue p1.d \n\t" +" ptrue p2.d \n\t" +" \n\t" +" MILLIKER_MLOOP: \n\t" +" \n\t" +" cmp x8, #1 \n\t" +" b.ne UKER_BEGIN \n\t" +" \n\t" +" ldr x10, %[m_left] \n\t" // Final (incomplete) millikernel loop. +" mov x11, xzr \n\t" +" incd x11 \n\t" +" whilelo p1.d, xzr, x10 \n\t" // Overwrite p1/p2. +" whilelo p2.d, x11, x10 \n\t" +" \n\t" +" UKER_BEGIN: \n\t" +" mov x10, x0 \n\t" // A's address. +" ldr x11, %[bi] \n\t" // B's address. +" ldr x12, %[k_mker] \n\t" +" ldr x13, %[k_left] \n\t" +#ifdef _A64FX +" mov x16, 0x3 \n\t" // Tag B address. +" lsl x16, x16, #56 \n\t" +" orr x11, x11, x16 \n\t" +#endif +" \n\t" +" mov x16, x11 \n\t" // Prefetch first kernel of B. +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" add x16, x16, x4 \n\t" +" prfm PLDL1KEEP, [x16] \n\t" +" \n\t" +" ld1rd z20.d, p0/z, [x11] \n\t" // (Partial) first B row. +" ld1rd z21.d, p0/z, [x11, #8] \n\t" +" ld1rd z22.d, p0/z, [x11, #16] \n\t" +" ld1rd z23.d, p0/z, [x11, #24] \n\t" +" ld1rd z24.d, p0/z, [x11, #32] \n\t" +" ld1rd z25.d, p0/z, [x11, #40] \n\t" +" ld1rd z26.d, p0/z, [x11, #48] \n\t" +" ld1rd z27.d, p0/z, [x11, #56] \n\t" +" \n\t" +" index z29.d, xzr, x1 \n\t" // First A column. +" \n\t" // Skips passed to index is not multiplied by 8. +GEMM_ACOL_GATHER_LOAD(z28,z29,z29,p1,p2,x10,x14,x16) +" \n\t" +CLEAR_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19) +" \n\t" +" cmp x12, #0 \n\t" // If no 4-microkernel can be applied +" b.eq K_LEFT_LOOP \n\t" +" \n\t" +" K_MKER_LOOP: \n\t" // Unroll the 4-loop. +" \n\t" +" index z31.d, xzr, x1 \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_G(z30,z31,z31,p1,p2,x10,x15,x3,x2,x14,x16,noprfm,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_C_1(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x11,x4) +" \n\t" +" index z29.d, xzr, x1 \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_G(z28,z29,z29,p1,p2,x10,x15,x3,x2,x14,x16,noprfm,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_C_2(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x11,x4) +" \n\t" +" index z31.d, xzr, x1 \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_G(z30,z31,z31,p1,p2,x10,x15,x3,x2,x14,x16,noprfm,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_C_3(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z28,z29,z20,z21,z22,z23,z24,z25,z26,z27,x11,x4) +" \n\t" +" subs x12, x12, #1 \n\t" // Decrease counter before final replica. +" b.eq FIN_MKER_LOOP \n\t" // Branch early to avoid reading excess mem. +" \n\t" +" index z29.d, xzr, x1 \n\t" +GEMMSUP_ACOL_PREFETCH_NEXT_LOAD_G(z28,z29,z29,p1,p2,x10,x15,x3,x2,x14,x16,noprfm,noprfm) +GEMM_2VX10_MKER_LOOP_PLAIN_C_4(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x11,x4) +" b K_MKER_LOOP \n\t" +" \n\t" +" FIN_MKER_LOOP: \n\t" +GEMM_2VX10_MKER_LOOP_PLAIN_C_4_RESIDUAL(z0,z2,z4,z6,z8,z10,z12,z14,z16,z18,z1,z3,z5,z7,z9,z11,z13,z15,z17,z19,p0,z30,z31,z20,z21,z22,z23,z24,z25,z26,z27,x11,x4) +" add x10, x10, x2 \n\t" // Forward A to fill the blank. 
+" \n\t" +" K_LEFT_LOOP: \n\t" +" cmp x13, #0 \n\t" +" b.eq WRITE_MEM_PREP \n\t" +" \n\t" +" index z31.d, xzr, x1 \n\t" +GEMM_ACOL_GATHER_LOAD(z30,z31,z31,p1,p2,x10,x14,x16) +" ld1rd z20.d, p0/z, [x11] \n\t" +" ld1rd z21.d, p0/z, [x11, #8] \n\t" +" ld1rd z22.d, p0/z, [x11, #16] \n\t" +" ld1rd z23.d, p0/z, [x11, #24] \n\t" +" ld1rd z24.d, p0/z, [x11, #32] \n\t" +" ld1rd z25.d, p0/z, [x11, #40] \n\t" +" ld1rd z26.d, p0/z, [x11, #48] \n\t" +" ld1rd z27.d, p0/z, [x11, #56] \n\t" +" ld1rd z28.d, p0/z, [x11, #64] \n\t" +" ld1rd z29.d, p0/z, [x11, #72] \n\t" +GEMM_FMLA2(z0,z1,p0,z30,z31,z20) +GEMM_FMLA2(z2,z3,p0,z30,z31,z21) +GEMM_FMLA2(z4,z5,p0,z30,z31,z22) +GEMM_FMLA2(z6,z7,p0,z30,z31,z23) +GEMM_FMLA2(z8,z9,p0,z30,z31,z24) +GEMM_FMLA2(z10,z11,p0,z30,z31,z25) +GEMM_FMLA2(z12,z13,p0,z30,z31,z26) +GEMM_FMLA2(z14,z15,p0,z30,z31,z27) +GEMM_FMLA2(z16,z17,p0,z30,z31,z28) +GEMM_FMLA2(z18,z19,p0,z30,z31,z29) +" add x10, x10, x2 \n\t" // Forward A. +" add x11, x11, x4 \n\t" // Forward B. +" sub x13, x13, #1 \n\t" +" b K_LEFT_LOOP \n\t" // Next column / row. +" \n\t" +" WRITE_MEM_PREP: \n\t" +" \n\t" +" ldr x11, %[bi] \n\t" +" ldr x12, %[alpha] \n\t" // Load alpha & beta. +" ldr x13, %[beta] \n\t" +" ld1rd z30.d, p0/z, [x12] \n\t" +" ld1rd z31.d, p0/z, [x13] \n\t" +" ldr x12, [x12] \n\t" +" \n\t" +" cmp x8, #1 \n\t" +" b.eq PREFETCH_ABNEXT \n\t" +" prfm PLDL2STRM, [x11] \n\t" +" b WRITE_MEM \n\t" +" \n\t" +" PREFETCH_ABNEXT: \n\t" +" ldr x1, %[a_next] \n\t" // Final Millikernel loop, x1 and x2 not needed. +" ldr x2, %[b_next] \n\t" +" prfm PLDL2KEEP, [x1] \n\t" +" prfm PLDL2KEEP, [x1, 256*1] \n\t" +" prfm PLDL2KEEP, [x1, 256*2] \n\t" +" prfm PLDL2KEEP, [x1, 256*3] \n\t" +" prfm PLDL2KEEP, [x1, 256*4] \n\t" +" prfm PLDL2KEEP, [x1, 256*5] \n\t" +" prfm PLDL2KEEP, [x1, 256*6] \n\t" +" prfm PLDL2KEEP, [x1, 256*7] \n\t" +" prfm PLDL2KEEP, [x1, 256*8] \n\t" +" prfm PLDL2KEEP, [x1, 256*9] \n\t" +" prfm PLDL2KEEP, [x1, 256*10] \n\t" +" prfm PLDL2KEEP, [x1, 256*11] \n\t" +" prfm PLDL2KEEP, [x1, 256*12] \n\t" +" prfm PLDL2KEEP, [x1, 256*13] \n\t" +" prfm PLDL2KEEP, [x1, 256*14] \n\t" +" prfm PLDL2KEEP, [x1, 256*15] \n\t" +" prfm PLDL2KEEP, [x2] \n\t" +" prfm PLDL2KEEP, [x2, 256*1] \n\t" +" prfm PLDL2KEEP, [x2, 256*2] \n\t" +" prfm PLDL2KEEP, [x2, 256*3] \n\t" +" prfm PLDL2KEEP, [x2, 256*4] \n\t" +" prfm PLDL2KEEP, [x2, 256*5] \n\t" +" prfm PLDL2KEEP, [x2, 256*6] \n\t" +" prfm PLDL2KEEP, [x2, 256*7] \n\t" +" prfm PLDL2KEEP, [x2, 256*8] \n\t" +" prfm PLDL2KEEP, [x2, 256*9] \n\t" +" \n\t" +" WRITE_MEM: \n\t" +" \n\t" +" fmov d28, #1.0 \n\t" +" fmov x16, d28 \n\t" +" cmp x16, x12 \n\t" +" b.eq UNIT_ALPHA \n\t" +" \n\t" +SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19,z30) +" \n\t" +" UNIT_ALPHA: \n\t" +" mov x9, x5 \n\t" // C address for loading. +" mov x10, x5 \n\t" // C address for storing. +" cmp x6, #1 \n\t" +" b.ne WRITE_MEM_G \n\t" +" \n\t" +" WRITE_MEM_C: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-29]. +" mov x13, xzr \n\t" // C-column's physical 1-vector skip. 
+" incb x13 \n\t" +GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,x9,x7) +GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) +GEMM_C_LOAD_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,x9,x7) +" \n\t" +GEMM_C_STORE_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,x10,x7) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,x10,x7) +" b END_WRITE_MEM \n\t" +" \n\t" +" WRITE_MEM_G: \n\t" // Available scratch: Z[20-30]. +" \n\t" // Here used scratch: Z[20-30] - Z30 as index. +" mov x12, xzr \n\t" +" incb x12 \n\t" +" madd x13, x12, x6, xzr \n\t" // C-column's logical 1-vector skip. +" index z30.d, xzr, x6 \n\t" // Skips passed to index is not multiplied by 8. +GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p1,p2,x9,x7,x13,x16) +GEMM_C_FMAD_UKER(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p1,p2,z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z31) +GEMM_C_LOAD_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p1,p2,x9,x7,x13,x16) +" \n\t" +GEMM_C_STORE_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p1,p2,x10,x7,x13,x16) +GEMM_C_FMAD_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p1,p2,z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z31) +GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p1,p2,x10,x7,x13,x16) +" \n\t" +" END_WRITE_MEM: \n\t" +" subs x8, x8, #1 \n\t" +" b.eq END_EXEC \n\t" +" \n\t" +" add x0, x0, x3 \n\t" // Forward A's base address to the next logic panel. +" add x5, x5, x13 \n\t" // Forward C's base address to the next logic panel. +" add x5, x5, x13 \n\t" +" b MILLIKER_MLOOP \n\t" +" \n\t" +" END_ERROR: \n\t" +" mov x0, #1 \n\t" // Return error. +" END_EXEC: \n\t" +" mov x0, #0 \n\t" // Return normal. 
+: +: [ai] "m" (ai), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a] "m" (ps_a), + [rs_b] "m" (rs_b), + [ci] "m" (ci), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [m_mker] "m" (m_mker), + [m_left] "m" (m_left), + [bi] "m" (bi), + [k_mker] "m" (k_mker), + [k_left] "m" (k_left), + [alpha] "m" (alpha), + [beta] "m" (beta), + [a_next] "m" (a_next), + [b_next] "m" (b_next) +: "x0","x1","x2","x3","x4","x5","x6","x7","x8", + "x9","x10","x11","x12","x13","x14","x15","x16",//"x17", + "z0","z1","z2","z3","z4","z5","z6","z7", + "z8","z9","z10","z11","z12","z13","z14","z15", + "z16","z17","z18","z19", + "z20","z21","z22","z23", + "z24","z25","z26","z27", + "z28","z29","z30","z31" + ); + } +} + diff --git a/kernels/armsve/bli_kernels_armsve.h b/kernels/armsve/bli_kernels_armsve.h index a5934312a0..3ccd79b68e 100644 --- a/kernels/armsve/bli_kernels_armsve.h +++ b/kernels/armsve/bli_kernels_armsve.h @@ -33,5 +33,13 @@ */ GEMM_UKR_PROT( double, d, gemm_armsve256_asm_8x8 ) +GEMM_UKR_PROT( double, d, gemm_armsve_asm_2vx10_unindexed ) +GEMM_UKR_PROT( float, s, gemm_armsve_asm_2vx10_unindexed ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armsve_2vx10_unindexed ) +GEMMSUP_KER_PROT( double, d, gemmsup_cv_armsve_2vx10_unindexed ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_armsve_10x2v_unindexed ) PACKM_KER_PROT( double, d, packm_armsve256_asm_8xk ) +PACKM_KER_PROT( double, d, packm_armsve512_asm_16xk ) +PACKM_KER_PROT( double, d, packm_armsve512_asm_12xk ) +PACKM_KER_PROT( double, d, packm_armsve512_asm_10xk ) diff --git a/kernels/armv8a/3/armv8a_asm_utils.h b/kernels/armv8a/3/armv8a_asm_utils.h new file mode 100644 index 0000000000..7bf97d555c --- /dev/null +++ b/kernels/armv8a/3/armv8a_asm_utils.h @@ -0,0 +1,49 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, The University of Tokyo + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +*/ + +// Apple's local label requirements. 
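+//
+// Apple's toolchain handles local labels in inline asm differently (Mach-O
+// symbols beginning with "L" are treated as assembler-local), so the macros
+// below pick the label prefix per platform. Illustrative expansion:
+//
+//   LABEL(SLOOPKITER)  ->  " .SLOOPKITER: \n\t"   (GNU as)
+//                          " LSLOOPKITER: \n\t"   (__APPLE__)
+//   BNE(SLOOPKITER)    ->  "b.ne .SLOOPKITER \n\t" / "b.ne LSLOOPKITER \n\t"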
+#if defined(__APPLE__) +#define LABEL(str) " L" #str": \n\t" +#define BEQ(str) "b.eq L" #str" \n\t" +#define BNE(str) "b.ne L" #str" \n\t" +#define BRANCH(str) "b L" #str" \n\t" +#else +#define LABEL(str) " ." #str": \n\t" +#define BEQ(str) "b.eq ." #str" \n\t" +#define BNE(str) "b.ne ." #str" \n\t" +#define BRANCH(str) "b ." #str" \n\t" +#endif + diff --git a/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c b/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c index c01c67f5a0..dfdda863b1 100644 --- a/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c +++ b/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c @@ -34,6 +34,7 @@ */ #include "blis.h" +#include "armv8a_asm_utils.h" /* o 4x4 Single precision micro-kernel fully functional. @@ -81,37 +82,30 @@ __asm__ volatile " ldr x1,%[baddr] \n\t" // Load address of B. " ldr x2,%[caddr] \n\t" // Load address of C. " \n\t" -" ldr x3,%[a_next] \n\t" // Pointer to next block of A. -" ldr x4,%[b_next] \n\t" // Pointer to next pointer of B. -" \n\t" " ldr x5,%[k_iter] \n\t" // Number of unrolled iterations (k_iter). " ldr x6,%[k_left] \n\t" // Number of remaining iterations (k_left). " \n\t" -" ldr x7,%[alpha] \n\t" // Alpha address. -" ldr x8,%[beta] \n\t" // Beta address. -" \n\t" -" ldr x9,%[cs_c] \n\t" // Load cs_c. -" lsl x10,x9,#2 \n\t" // cs_c * sizeof(float) -- AUX. +" ldr x10,%[cs_c] \n\t" // Load cs_c. +" lsl x10,x10,#2 \n\t" // cs_c * sizeof(float) -- AUX. " \n\t" -" ldr x13,%[rs_c] \n\t" // Load rs_c. -" lsl x14,x13,#2 \n\t" // rs_c * sizeof(float). +" ldr x14,%[rs_c] \n\t" // Load rs_c. +" lsl x14,x14,#2 \n\t" // rs_c * sizeof(float). " \n\t" " add x16,x2,x10 \n\t" //Load address Column 1 of C " add x17,x16,x10 \n\t" //Load address Column 2 of C -" add x18,x17,x10 \n\t" //Load address Column 3 of C -" add x19,x18,x10 \n\t" //Load address Column 4 of C -" add x20,x19,x10 \n\t" //Load address Column 5 of C -" add x21,x20,x10 \n\t" //Load address Column 6 of C -" add x22,x21,x10 \n\t" //Load address Column 7 of C -" add x23,x22,x10 \n\t" //Load address Column 8 of C -" add x24,x23,x10 \n\t" //Load address Column 9 of C -" add x25,x24,x10 \n\t" //Load address Column 10 of C -" add x26,x25,x10 \n\t" //Load address Column 11 of C +" add x19,x17,x10 \n\t" //Load address Column 3 of C +" add x20,x19,x10 \n\t" //Load address Column 4 of C +" add x21,x20,x10 \n\t" //Load address Column 5 of C +" add x22,x21,x10 \n\t" //Load address Column 6 of C +" add x23,x22,x10 \n\t" //Load address Column 7 of C +" add x24,x23,x10 \n\t" //Load address Column 8 of C +" add x25,x24,x10 \n\t" //Load address Column 9 of C +" add x26,x25,x10 \n\t" //Load address Column 10 of C +" add x27,x26,x10 \n\t" //Load address Column 11 of C " \n\t" " prfm pldl1keep,[x2] \n\t" // Prefetch c. " prfm pldl1keep,[x16] \n\t" // Prefetch c. " prfm pldl1keep,[x17] \n\t" // Prefetch c. -" prfm pldl1keep,[x18] \n\t" // Prefetch c. " prfm pldl1keep,[x19] \n\t" // Prefetch c. " prfm pldl1keep,[x20] \n\t" // Prefetch c. " prfm pldl1keep,[x21] \n\t" // Prefetch c. @@ -120,6 +114,7 @@ __asm__ volatile " prfm pldl1keep,[x24] \n\t" // Prefetch c. " prfm pldl1keep,[x25] \n\t" // Prefetch c. " prfm pldl1keep,[x26] \n\t" // Prefetch c. +" prfm pldl1keep,[x27] \n\t" // Prefetch c. " \n\t" " dup v8.4s, wzr \n\t" // Vector for accummulating column 0 " prfm PLDL1KEEP, [x1, #192] \n\t" @@ -155,7 +150,7 @@ __asm__ volatile " dup v31.4s, wzr \n\t" // Vector for accummulating column 11 " \n\t" " cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. 
-" beq .SCONSIDERKLEFT \n\t" +BEQ(SCONSIDERKLEFT) " \n\t" " ldr q0, [x0] \n\t" " ldr q1, [x0, #16] \n\t" // Load a @@ -168,9 +163,9 @@ __asm__ volatile " add x1, x1, #48 \n\t" //update address of B " \n\t" " cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. -" beq .SLASTITER \n\t" // (as loop is do-while-like). +BEQ(SLASTITER) // (as loop is do-while-like). " \n\t" -" .SLOOPKITER: \n\t" // Body of the k_iter loop. +LABEL(SLOOPKITER) // Body of the k_iter loop. " \n\t" " ldr q5, [x0] \n\t" " fmla v8.4s, v0.4s,v2.s[0] \n\t" // Accummulate. @@ -316,9 +311,9 @@ __asm__ volatile " \n\t" //End It 4 " sub x5,x5,1 \n\t" // i-=1. " cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. -" bne .SLOOPKITER \n\t" +BNE(SLOOPKITER) " \n\t" -" .SLASTITER: \n\t" // Last iteration of k_iter loop. +LABEL(SLASTITER) // Last iteration of k_iter loop. " \n\t" " \n\t" " ldr q5, [x0] \n\t" @@ -454,11 +449,11 @@ __asm__ volatile " add x0, x0, #96 \n\t" " \n\t" //End It 4 " \n\t" -" .SCONSIDERKLEFT: \n\t" +LABEL(SCONSIDERKLEFT) " cmp x6,0 \n\t" // If k_left == 0, we are done. -" beq .SPOSTACCUM \n\t" // else, we enter the k_left loop. +BEQ(SPOSTACCUM) // else, we enter the k_left loop. " \n\t" -" .SLOOPKLEFT: \n\t" // Body of the left iterations +LABEL(SLOOPKLEFT) // Body of the left iterations " \n\t" " ldr q0, [x0],#16 \n\t" " ldr q1, [x0],#16 \n\t" // Load a @@ -497,17 +492,23 @@ __asm__ volatile " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. " \n\t" " cmp x6,0 \n\t" // Iterate again. -" bne .SLOOPKLEFT \n\t" // if i!=0. +BNE(SLOOPKLEFT) // if i!=0. +" \n\t" +LABEL(SPOSTACCUM) +" \n\t" +" ldr x0,%[alpha] \n\t" // Alpha address. +" ldr x1,%[beta] \n\t" // Beta address. " \n\t" -" .SPOSTACCUM: \n\t" +" ld1r {v6.4s},[x0] \n\t" // Load alpha. +" ld1r {v7.4s},[x1] \n\t" // Load beta " \n\t" -" ld1r {v6.4s},[x7] \n\t" // Load alpha. -" ld1r {v7.4s},[x8] \n\t" // Load beta +" ldr x0,%[a_next] \n\t" // Pointer to next block of A. +" ldr x1,%[b_next] \n\t" // Pointer to next pointer of B. " \n\t" -" cmp x13,#1 \n\t" // If rs_c != 1 (column-major) -" bne .SGENSTORED \n\t" +" cmp x14,#4 \n\t" // If rs_c != 1 (column-major) +BNE(SGENSTORED) " \n\t" -" .SCOLSTORED: \n\t" // C is column-major. +LABEL(SCOLSTORED) // C is column-major. " \n\t" " dup v0.4s, wzr \n\t" " dup v1.4s, wzr \n\t" @@ -517,7 +518,7 @@ __asm__ volatile " dup v5.4s, wzr \n\t" " \n\t" " fcmp s7,#0.0 \n\t" -" beq .SBETAZEROCOLSTOREDS1 \n\t" // Taking care of the beta==0 case. +BEQ(SBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. " \n\t" " ldr q0, [x2] \n\t" //Load column 0 of C " ldr q1, [x2, #16] \n\t" @@ -533,7 +534,7 @@ __asm__ volatile " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta " \n\t" -" .SBETAZEROCOLSTOREDS1: \n\t" +LABEL(SBETAZEROCOLSTOREDS1) " \n\t" " fmla v0.4s,v8.4s,v6.s[0] \n\t" // Scale by alpha " fmla v1.4s,v9.4s,v6.s[0] \n\t" // Scale by alpha @@ -557,14 +558,14 @@ __asm__ volatile " dup v13.4s, wzr \n\t" " \n\t" " fcmp s7,#0.0 \n\t" -" beq .SBETAZEROCOLSTOREDS2 \n\t" // Taking care of the beta==0 case. +BEQ(SBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. 
" \n\t" -" ldr q8, [x18] \n\t" //Load column 3 of C -" ldr q9, [x18, #16] \n\t" -" ldr q10, [x19] \n\t" //Load column 4 of C -" ldr q11, [x19, #16] \n\t" -" ldr q12, [x20] \n\t" //Load column 5 of C -" ldr q13, [x20, #16] \n\t" +" ldr q8, [x19] \n\t" //Load column 3 of C +" ldr q9, [x19, #16] \n\t" +" ldr q10, [x20] \n\t" //Load column 4 of C +" ldr q11, [x20, #16] \n\t" +" ldr q12, [x21] \n\t" //Load column 5 of C +" ldr q13, [x21, #16] \n\t" " \n\t" " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta @@ -573,7 +574,7 @@ __asm__ volatile " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta " \n\t" -" .SBETAZEROCOLSTOREDS2: \n\t" +LABEL(SBETAZEROCOLSTOREDS2) " \n\t" " fmla v8.4s, v14.4s,v6.s[0] \n\t" // Scale by alpha " fmla v9.4s, v15.4s,v6.s[0] \n\t" // Scale by alpha @@ -582,12 +583,12 @@ __asm__ volatile " fmla v12.4s,v18.4s,v6.s[0] \n\t" // Scale by alpha " fmla v13.4s,v19.4s,v6.s[0] \n\t" // Scale by alpha " \n\t" -" str q8, [x18] \n\t" //Store column 3 of C -" str q9, [x18, #16] \n\t" -" str q10, [x19] \n\t" //Store column 4 of C -" str q11, [x19, #16] \n\t" -" str q12, [x20] \n\t" //Store column 5 of C -" str q13, [x20, #16] \n\t" +" str q8, [x19] \n\t" //Store column 3 of C +" str q9, [x19, #16] \n\t" +" str q10, [x20] \n\t" //Store column 4 of C +" str q11, [x20, #16] \n\t" +" str q12, [x21] \n\t" //Store column 5 of C +" str q13, [x21, #16] \n\t" " \n\t" " dup v0.4s, wzr \n\t" " dup v1.4s, wzr \n\t" @@ -597,14 +598,14 @@ __asm__ volatile " dup v5.4s, wzr \n\t" " \n\t" " fcmp s7,#0.0 \n\t" -" beq .SBETAZEROCOLSTOREDS3 \n\t" // Taking care of the beta==0 case. +BEQ(SBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. " \n\t" -" ldr q0, [x21] \n\t" //Load column 6 of C -" ldr q1, [x21, #16] \n\t" -" ldr q2, [x22] \n\t" //Load column 7 of C -" ldr q3, [x22, #16] \n\t" -" ldr q4, [x23] \n\t" //Load column 8 of C -" ldr q5, [x23, #16] \n\t" +" ldr q0, [x22] \n\t" //Load column 6 of C +" ldr q1, [x22, #16] \n\t" +" ldr q2, [x23] \n\t" //Load column 7 of C +" ldr q3, [x23, #16] \n\t" +" ldr q4, [x24] \n\t" //Load column 8 of C +" ldr q5, [x24, #16] \n\t" " \n\t" " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta @@ -613,7 +614,7 @@ __asm__ volatile " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta " \n\t" -" .SBETAZEROCOLSTOREDS3: \n\t" +LABEL(SBETAZEROCOLSTOREDS3) " \n\t" " fmla v0.4s,v20.4s,v6.s[0] \n\t" // Scale by alpha " fmla v1.4s,v21.4s,v6.s[0] \n\t" // Scale by alpha @@ -622,12 +623,12 @@ __asm__ volatile " fmla v4.4s,v24.4s,v6.s[0] \n\t" // Scale by alpha " fmla v5.4s,v25.4s,v6.s[0] \n\t" // Scale by alpha " \n\t" -" str q0, [x21] \n\t" //Store column 6 of C -" str q1, [x21, #16] \n\t" -" str q2, [x22] \n\t" //Store column 7 of C -" str q3, [x22, #16] \n\t" -" str q4, [x23] \n\t" //Store column 8 of C -" str q5, [x23, #16] \n\t" +" str q0, [x22] \n\t" //Store column 6 of C +" str q1, [x22, #16] \n\t" +" str q2, [x23] \n\t" //Store column 7 of C +" str q3, [x23, #16] \n\t" +" str q4, [x24] \n\t" //Store column 8 of C +" str q5, [x24, #16] \n\t" " \n\t" " dup v8.4s, wzr \n\t" " dup v9.4s, wzr \n\t" @@ -637,14 +638,14 @@ __asm__ volatile " dup v13.4s, wzr \n\t" " \n\t" " fcmp s7,#0.0 \n\t" -" beq .SBETAZEROCOLSTOREDS4 \n\t" // Taking care of the beta==0 case. +BEQ(SBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. 
" \n\t" -" ldr q8, [x24] \n\t" //Load column 9 of C -" ldr q9, [x24, #16] \n\t" -" ldr q10, [x25] \n\t" //Load column 10 of C -" ldr q11, [x25, #16] \n\t" -" ldr q12, [x26] \n\t" //Load column 11 of C -" ldr q13, [x26, #16] \n\t" +" ldr q8, [x25] \n\t" //Load column 9 of C +" ldr q9, [x25, #16] \n\t" +" ldr q10, [x26] \n\t" //Load column 10 of C +" ldr q11, [x26, #16] \n\t" +" ldr q12, [x27] \n\t" //Load column 11 of C +" ldr q13, [x27, #16] \n\t" " \n\t" " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta @@ -653,10 +654,10 @@ __asm__ volatile " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta " \n\t" -" .SBETAZEROCOLSTOREDS4: \n\t" +LABEL(SBETAZEROCOLSTOREDS4) " \n\t" -" prfm pldl2keep,[x3] \n\t" -" prfm pldl2keep,[x4] \n\t" +" prfm pldl2keep,[x0] \n\t" +" prfm pldl2keep,[x1] \n\t" " \n\t" " fmla v8.4s, v26.4s,v6.s[0] \n\t" // Scale by alpha " fmla v9.4s, v27.4s,v6.s[0] \n\t" // Scale by alpha @@ -665,18 +666,18 @@ __asm__ volatile " fmla v12.4s,v30.4s,v6.s[0] \n\t" // Scale by alpha " fmla v13.4s,v31.4s,v6.s[0] \n\t" // Scale by alpha " \n\t" -" str q8, [x24] \n\t" //Store column 9 of C -" str q9, [x24, #16] \n\t" -" str q10, [x25] \n\t" //Store column 10 of C -" str q11, [x25, #16] \n\t" -" str q12, [x26] \n\t" //Store column 11 of C -" str q13, [x26, #16] \n\t" +" str q8, [x25] \n\t" //Store column 9 of C +" str q9, [x25, #16] \n\t" +" str q10, [x26] \n\t" //Store column 10 of C +" str q11, [x26, #16] \n\t" +" str q12, [x27] \n\t" //Store column 11 of C +" str q13, [x27, #16] \n\t" " \n\t" " \n\t" -" b .SEND \n\t" // Done (TODO: this obviously needs to be moved down to remove jump). +BRANCH(SEND) // Done. " \n\t" " \n\t" -" .SGENSTORED: \n\t" // C is general-stride stored. +LABEL(SGENSTORED) // C is general-stride stored. " \n\t" " \n\t" " dup v0.4s, wzr \n\t" @@ -687,40 +688,40 @@ __asm__ volatile " dup v5.4s, wzr \n\t" " \n\t" " fcmp s7,#0.0 \n\t" -" beq .SBETAZEROGENSTOREDS1 \n\t" // Taking care of the beta==0 case. -" \n\t" -" mov x27, x2 \n\t" -" \n\t" -" ld1 {v0.s}[0],[x27],x14 \n\t" // Load c00 into quad and increment by rs_c. -" ld1 {v0.s}[1],[x27],x14 \n\t" // Load c01 into quad and increment by rs_c. -" ld1 {v0.s}[2],[x27],x14 \n\t" // Load c02 into quad and increment by rs_c. -" ld1 {v0.s}[3],[x27],x14 \n\t" // Load c03 into quad and increment by rs_c. -" ld1 {v1.s}[0],[x27],x14 \n\t" // Load c04 into quad and increment by rs_c. -" ld1 {v1.s}[1],[x27],x14 \n\t" // Load c05 into quad and increment by rs_c. -" ld1 {v1.s}[2],[x27],x14 \n\t" // Load c06 into quad and increment by rs_c. -" ld1 {v1.s}[3],[x27],x14 \n\t" // Load c07 into quad and increment by rs_c. -" \n\t" -" mov x27, x16 \n\t" -" \n\t" -" ld1 {v2.s}[0],[x27],x14 \n\t" // Load c10 into quad and increment by rs_c. -" ld1 {v2.s}[1],[x27],x14 \n\t" // Load c11 into quad and increment by rs_c. -" ld1 {v2.s}[2],[x27],x14 \n\t" // Load c12 into quad and increment by rs_c. -" ld1 {v2.s}[3],[x27],x14 \n\t" // Load c13 into quad and increment by rs_c. -" ld1 {v3.s}[0],[x27],x14 \n\t" // Load c14 into quad and increment by rs_c. -" ld1 {v3.s}[1],[x27],x14 \n\t" // Load c15 into quad and increment by rs_c. -" ld1 {v3.s}[2],[x27],x14 \n\t" // Load c16 into quad and increment by rs_c. -" ld1 {v3.s}[3],[x27],x14 \n\t" // Load c17 into quad and increment by rs_c. -" \n\t" -" mov x27, x17 \n\t" -" \n\t" -" ld1 {v4.s}[0],[x27],x14 \n\t" // Load c20 into quad and increment by rs_c. 
-" ld1 {v4.s}[1],[x27],x14 \n\t" // Load c21 into quad and increment by rs_c. -" ld1 {v4.s}[2],[x27],x14 \n\t" // Load c22 into quad and increment by rs_c. -" ld1 {v4.s}[3],[x27],x14 \n\t" // Load c23 into quad and increment by rs_c. -" ld1 {v5.s}[0],[x27],x14 \n\t" // Load c24 into quad and increment by rs_c. -" ld1 {v5.s}[1],[x27],x14 \n\t" // Load c25 into quad and increment by rs_c. -" ld1 {v5.s}[2],[x27],x14 \n\t" // Load c26 into quad and increment by rs_c. -" ld1 {v5.s}[3],[x27],x14 \n\t" // Load c27 into quad and increment by rs_c. +BEQ(SBETAZEROGENSTOREDS1) // Taking care of the beta==0 case. +" \n\t" +" mov x5, x2 \n\t" +" \n\t" +" ld1 {v0.s}[0],[x5],x14 \n\t" // Load c00 into quad and increment by rs_c. +" ld1 {v0.s}[1],[x5],x14 \n\t" // Load c01 into quad and increment by rs_c. +" ld1 {v0.s}[2],[x5],x14 \n\t" // Load c02 into quad and increment by rs_c. +" ld1 {v0.s}[3],[x5],x14 \n\t" // Load c03 into quad and increment by rs_c. +" ld1 {v1.s}[0],[x5],x14 \n\t" // Load c04 into quad and increment by rs_c. +" ld1 {v1.s}[1],[x5],x14 \n\t" // Load c05 into quad and increment by rs_c. +" ld1 {v1.s}[2],[x5],x14 \n\t" // Load c06 into quad and increment by rs_c. +" ld1 {v1.s}[3],[x5],x14 \n\t" // Load c07 into quad and increment by rs_c. +" \n\t" +" mov x5, x16 \n\t" +" \n\t" +" ld1 {v2.s}[0],[x5],x14 \n\t" // Load c10 into quad and increment by rs_c. +" ld1 {v2.s}[1],[x5],x14 \n\t" // Load c11 into quad and increment by rs_c. +" ld1 {v2.s}[2],[x5],x14 \n\t" // Load c12 into quad and increment by rs_c. +" ld1 {v2.s}[3],[x5],x14 \n\t" // Load c13 into quad and increment by rs_c. +" ld1 {v3.s}[0],[x5],x14 \n\t" // Load c14 into quad and increment by rs_c. +" ld1 {v3.s}[1],[x5],x14 \n\t" // Load c15 into quad and increment by rs_c. +" ld1 {v3.s}[2],[x5],x14 \n\t" // Load c16 into quad and increment by rs_c. +" ld1 {v3.s}[3],[x5],x14 \n\t" // Load c17 into quad and increment by rs_c. +" \n\t" +" mov x5, x17 \n\t" +" \n\t" +" ld1 {v4.s}[0],[x5],x14 \n\t" // Load c20 into quad and increment by rs_c. +" ld1 {v4.s}[1],[x5],x14 \n\t" // Load c21 into quad and increment by rs_c. +" ld1 {v4.s}[2],[x5],x14 \n\t" // Load c22 into quad and increment by rs_c. +" ld1 {v4.s}[3],[x5],x14 \n\t" // Load c23 into quad and increment by rs_c. +" ld1 {v5.s}[0],[x5],x14 \n\t" // Load c24 into quad and increment by rs_c. +" ld1 {v5.s}[1],[x5],x14 \n\t" // Load c25 into quad and increment by rs_c. +" ld1 {v5.s}[2],[x5],x14 \n\t" // Load c26 into quad and increment by rs_c. +" ld1 {v5.s}[3],[x5],x14 \n\t" // Load c27 into quad and increment by rs_c. " \n\t" " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta @@ -729,7 +730,7 @@ __asm__ volatile " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta " \n\t" -" .SBETAZEROGENSTOREDS1: \n\t" +LABEL(SBETAZEROGENSTOREDS1) " \n\t" " fmla v0.4s, v8.4s,v6.s[0] \n\t" // Scale by alpha " fmla v1.4s, v9.4s,v6.s[0] \n\t" // Scale by alpha @@ -738,38 +739,38 @@ __asm__ volatile " fmla v4.4s,v12.4s,v6.s[0] \n\t" // Scale by alpha " fmla v5.4s,v13.4s,v6.s[0] \n\t" // Scale by alpha " \n\t" -" mov x27, x2 \n\t" -" \n\t" -" st1 {v0.s}[0],[x27],x14 \n\t" // Store c00 into quad and increment by rs_c. -" st1 {v0.s}[1],[x27],x14 \n\t" // Store c01 into quad and increment by rs_c. -" st1 {v0.s}[2],[x27],x14 \n\t" // Store c02 into quad and increment by rs_c. -" st1 {v0.s}[3],[x27],x14 \n\t" // Store c03 into quad and increment by rs_c. 
-" st1 {v1.s}[0],[x27],x14 \n\t" // Store c04 into quad and increment by rs_c. -" st1 {v1.s}[1],[x27],x14 \n\t" // Store c05 into quad and increment by rs_c. -" st1 {v1.s}[2],[x27],x14 \n\t" // Store c06 into quad and increment by rs_c. -" st1 {v1.s}[3],[x27],x14 \n\t" // Store c07 into quad and increment by rs_c. -" \n\t" -" mov x27, x16 \n\t" -" \n\t" -" st1 {v2.s}[0],[x27],x14 \n\t" // Store c10 into quad and increment by rs_c. -" st1 {v2.s}[1],[x27],x14 \n\t" // Store c11 into quad and increment by rs_c. -" st1 {v2.s}[2],[x27],x14 \n\t" // Store c12 into quad and increment by rs_c. -" st1 {v2.s}[3],[x27],x14 \n\t" // Store c13 into quad and increment by rs_c. -" st1 {v3.s}[0],[x27],x14 \n\t" // Store c14 into quad and increment by rs_c. -" st1 {v3.s}[1],[x27],x14 \n\t" // Store c15 into quad and increment by rs_c. -" st1 {v3.s}[2],[x27],x14 \n\t" // Store c16 into quad and increment by rs_c. -" st1 {v3.s}[3],[x27],x14 \n\t" // Store c17 into quad and increment by rs_c. -" \n\t" -" mov x27, x17 \n\t" -" \n\t" -" st1 {v4.s}[0],[x27],x14 \n\t" // Store c20 into quad and increment by rs_c. -" st1 {v4.s}[1],[x27],x14 \n\t" // Store c21 into quad and increment by rs_c. -" st1 {v4.s}[2],[x27],x14 \n\t" // Store c22 into quad and increment by rs_c. -" st1 {v4.s}[3],[x27],x14 \n\t" // Store c23 into quad and increment by rs_c. -" st1 {v5.s}[0],[x27],x14 \n\t" // Store c24 into quad and increment by rs_c. -" st1 {v5.s}[1],[x27],x14 \n\t" // Store c25 into quad and increment by rs_c. -" st1 {v5.s}[2],[x27],x14 \n\t" // Store c26 into quad and increment by rs_c. -" st1 {v5.s}[3],[x27],x14 \n\t" // Store c27 into quad and increment by rs_c. +" mov x5, x2 \n\t" +" \n\t" +" st1 {v0.s}[0],[x5],x14 \n\t" // Store c00 into quad and increment by rs_c. +" st1 {v0.s}[1],[x5],x14 \n\t" // Store c01 into quad and increment by rs_c. +" st1 {v0.s}[2],[x5],x14 \n\t" // Store c02 into quad and increment by rs_c. +" st1 {v0.s}[3],[x5],x14 \n\t" // Store c03 into quad and increment by rs_c. +" st1 {v1.s}[0],[x5],x14 \n\t" // Store c04 into quad and increment by rs_c. +" st1 {v1.s}[1],[x5],x14 \n\t" // Store c05 into quad and increment by rs_c. +" st1 {v1.s}[2],[x5],x14 \n\t" // Store c06 into quad and increment by rs_c. +" st1 {v1.s}[3],[x5],x14 \n\t" // Store c07 into quad and increment by rs_c. +" \n\t" +" mov x5, x16 \n\t" +" \n\t" +" st1 {v2.s}[0],[x5],x14 \n\t" // Store c10 into quad and increment by rs_c. +" st1 {v2.s}[1],[x5],x14 \n\t" // Store c11 into quad and increment by rs_c. +" st1 {v2.s}[2],[x5],x14 \n\t" // Store c12 into quad and increment by rs_c. +" st1 {v2.s}[3],[x5],x14 \n\t" // Store c13 into quad and increment by rs_c. +" st1 {v3.s}[0],[x5],x14 \n\t" // Store c14 into quad and increment by rs_c. +" st1 {v3.s}[1],[x5],x14 \n\t" // Store c15 into quad and increment by rs_c. +" st1 {v3.s}[2],[x5],x14 \n\t" // Store c16 into quad and increment by rs_c. +" st1 {v3.s}[3],[x5],x14 \n\t" // Store c17 into quad and increment by rs_c. +" \n\t" +" mov x5, x17 \n\t" +" \n\t" +" st1 {v4.s}[0],[x5],x14 \n\t" // Store c20 into quad and increment by rs_c. +" st1 {v4.s}[1],[x5],x14 \n\t" // Store c21 into quad and increment by rs_c. +" st1 {v4.s}[2],[x5],x14 \n\t" // Store c22 into quad and increment by rs_c. +" st1 {v4.s}[3],[x5],x14 \n\t" // Store c23 into quad and increment by rs_c. +" st1 {v5.s}[0],[x5],x14 \n\t" // Store c24 into quad and increment by rs_c. +" st1 {v5.s}[1],[x5],x14 \n\t" // Store c25 into quad and increment by rs_c. 
+" st1 {v5.s}[2],[x5],x14 \n\t" // Store c26 into quad and increment by rs_c. +" st1 {v5.s}[3],[x5],x14 \n\t" // Store c27 into quad and increment by rs_c. " \n\t" " dup v8.4s, wzr \n\t" " dup v9.4s, wzr \n\t" @@ -779,40 +780,40 @@ __asm__ volatile " dup v13.4s, wzr \n\t" " \n\t" " fcmp s7,#0.0 \n\t" -" beq .SBETAZEROGENSTOREDS2 \n\t" // Taking care of the beta==0 case. -" \n\t" -" mov x27, x18 \n\t" -" \n\t" -" ld1 {v8.s}[0],[x27],x14 \n\t" // Load c30 into quad and increment by rs_c. -" ld1 {v8.s}[1],[x27],x14 \n\t" // Load c31 into quad and increment by rs_c. -" ld1 {v8.s}[2],[x27],x14 \n\t" // Load c32 into quad and increment by rs_c. -" ld1 {v8.s}[3],[x27],x14 \n\t" // Load c33 into quad and increment by rs_c. -" ld1 {v9.s}[0],[x27],x14 \n\t" // Load c34 into quad and increment by rs_c. -" ld1 {v9.s}[1],[x27],x14 \n\t" // Load c35 into quad and increment by rs_c. -" ld1 {v9.s}[2],[x27],x14 \n\t" // Load c36 into quad and increment by rs_c. -" ld1 {v9.s}[3],[x27],x14 \n\t" // Load c37 into quad and increment by rs_c. -" \n\t" -" mov x27, x19 \n\t" -" \n\t" -" ld1 {v10.s}[0],[x27],x14 \n\t" // Load c40 into quad and increment by rs_c. -" ld1 {v10.s}[1],[x27],x14 \n\t" // Load c41 into quad and increment by rs_c. -" ld1 {v10.s}[2],[x27],x14 \n\t" // Load c42 into quad and increment by rs_c. -" ld1 {v10.s}[3],[x27],x14 \n\t" // Load c43 into quad and increment by rs_c. -" ld1 {v11.s}[0],[x27],x14 \n\t" // Load c44 into quad and increment by rs_c. -" ld1 {v11.s}[1],[x27],x14 \n\t" // Load c45 into quad and increment by rs_c. -" ld1 {v11.s}[2],[x27],x14 \n\t" // Load c46 into quad and increment by rs_c. -" ld1 {v11.s}[3],[x27],x14 \n\t" // Load c47 into quad and increment by rs_c. -" \n\t" -" mov x27, x20 \n\t" -" \n\t" -" ld1 {v12.s}[0],[x27],x14 \n\t" // Load c50 into quad and increment by rs_c. -" ld1 {v12.s}[1],[x27],x14 \n\t" // Load c51 into quad and increment by rs_c. -" ld1 {v12.s}[2],[x27],x14 \n\t" // Load c52 into quad and increment by rs_c. -" ld1 {v12.s}[3],[x27],x14 \n\t" // Load c53 into quad and increment by rs_c. -" ld1 {v13.s}[0],[x27],x14 \n\t" // Load c54 into quad and increment by rs_c. -" ld1 {v13.s}[1],[x27],x14 \n\t" // Load c55 into quad and increment by rs_c. -" ld1 {v13.s}[2],[x27],x14 \n\t" // Load c56 into quad and increment by rs_c. -" ld1 {v13.s}[3],[x27],x14 \n\t" // Load c57 into quad and increment by rs_c. +BEQ(SBETAZEROGENSTOREDS2) // Taking care of the beta==0 case. +" \n\t" +" mov x5, x19 \n\t" +" \n\t" +" ld1 {v8.s}[0],[x5],x14 \n\t" // Load c30 into quad and increment by rs_c. +" ld1 {v8.s}[1],[x5],x14 \n\t" // Load c31 into quad and increment by rs_c. +" ld1 {v8.s}[2],[x5],x14 \n\t" // Load c32 into quad and increment by rs_c. +" ld1 {v8.s}[3],[x5],x14 \n\t" // Load c33 into quad and increment by rs_c. +" ld1 {v9.s}[0],[x5],x14 \n\t" // Load c34 into quad and increment by rs_c. +" ld1 {v9.s}[1],[x5],x14 \n\t" // Load c35 into quad and increment by rs_c. +" ld1 {v9.s}[2],[x5],x14 \n\t" // Load c36 into quad and increment by rs_c. +" ld1 {v9.s}[3],[x5],x14 \n\t" // Load c37 into quad and increment by rs_c. +" \n\t" +" mov x5, x20 \n\t" +" \n\t" +" ld1 {v10.s}[0],[x5],x14 \n\t" // Load c40 into quad and increment by rs_c. +" ld1 {v10.s}[1],[x5],x14 \n\t" // Load c41 into quad and increment by rs_c. +" ld1 {v10.s}[2],[x5],x14 \n\t" // Load c42 into quad and increment by rs_c. +" ld1 {v10.s}[3],[x5],x14 \n\t" // Load c43 into quad and increment by rs_c. +" ld1 {v11.s}[0],[x5],x14 \n\t" // Load c44 into quad and increment by rs_c. 
+" ld1 {v11.s}[1],[x5],x14 \n\t" // Load c45 into quad and increment by rs_c. +" ld1 {v11.s}[2],[x5],x14 \n\t" // Load c46 into quad and increment by rs_c. +" ld1 {v11.s}[3],[x5],x14 \n\t" // Load c47 into quad and increment by rs_c. +" \n\t" +" mov x5, x21 \n\t" +" \n\t" +" ld1 {v12.s}[0],[x5],x14 \n\t" // Load c50 into quad and increment by rs_c. +" ld1 {v12.s}[1],[x5],x14 \n\t" // Load c51 into quad and increment by rs_c. +" ld1 {v12.s}[2],[x5],x14 \n\t" // Load c52 into quad and increment by rs_c. +" ld1 {v12.s}[3],[x5],x14 \n\t" // Load c53 into quad and increment by rs_c. +" ld1 {v13.s}[0],[x5],x14 \n\t" // Load c54 into quad and increment by rs_c. +" ld1 {v13.s}[1],[x5],x14 \n\t" // Load c55 into quad and increment by rs_c. +" ld1 {v13.s}[2],[x5],x14 \n\t" // Load c56 into quad and increment by rs_c. +" ld1 {v13.s}[3],[x5],x14 \n\t" // Load c57 into quad and increment by rs_c. " \n\t" " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta @@ -821,7 +822,7 @@ __asm__ volatile " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta " \n\t" -" .SBETAZEROGENSTOREDS2: \n\t" +LABEL(SBETAZEROGENSTOREDS2) " \n\t" " fmla v8.4s, v14.4s,v6.s[0] \n\t" // Scale by alpha " fmla v9.4s, v15.4s,v6.s[0] \n\t" // Scale by alpha @@ -830,38 +831,38 @@ __asm__ volatile " fmla v12.4s,v18.4s,v6.s[0] \n\t" // Scale by alpha " fmla v13.4s,v19.4s,v6.s[0] \n\t" // Scale by alpha " \n\t" -" mov x27, x18 \n\t" -" \n\t" -" st1 {v8.s}[0],[x27],x14 \n\t" // Store c30 into quad and increment by rs_c. -" st1 {v8.s}[1],[x27],x14 \n\t" // Store c31 into quad and increment by rs_c. -" st1 {v8.s}[2],[x27],x14 \n\t" // Store c32 into quad and increment by rs_c. -" st1 {v8.s}[3],[x27],x14 \n\t" // Store c33 into quad and increment by rs_c. -" st1 {v9.s}[0],[x27],x14 \n\t" // Store c34 into quad and increment by rs_c. -" st1 {v9.s}[1],[x27],x14 \n\t" // Store c35 into quad and increment by rs_c. -" st1 {v9.s}[2],[x27],x14 \n\t" // Store c36 into quad and increment by rs_c. -" st1 {v9.s}[3],[x27],x14 \n\t" // Store c37 into quad and increment by rs_c. -" \n\t" -" mov x27, x19 \n\t" -" \n\t" -" st1 {v10.s}[0],[x27],x14 \n\t" // Store c40 into quad and increment by rs_c. -" st1 {v10.s}[1],[x27],x14 \n\t" // Store c41 into quad and increment by rs_c. -" st1 {v10.s}[2],[x27],x14 \n\t" // Store c42 into quad and increment by rs_c. -" st1 {v10.s}[3],[x27],x14 \n\t" // Store c43 into quad and increment by rs_c. -" st1 {v11.s}[0],[x27],x14 \n\t" // Store c44 into quad and increment by rs_c. -" st1 {v11.s}[1],[x27],x14 \n\t" // Store c45 into quad and increment by rs_c. -" st1 {v11.s}[2],[x27],x14 \n\t" // Store c46 into quad and increment by rs_c. -" st1 {v11.s}[3],[x27],x14 \n\t" // Store c47 into quad and increment by rs_c. -" \n\t" -" mov x27, x20 \n\t" -" \n\t" -" st1 {v12.s}[0],[x27],x14 \n\t" // Store c50 into quad and increment by rs_c. -" st1 {v12.s}[1],[x27],x14 \n\t" // Store c51 into quad and increment by rs_c. -" st1 {v12.s}[2],[x27],x14 \n\t" // Store c52 into quad and increment by rs_c. -" st1 {v12.s}[3],[x27],x14 \n\t" // Store c53 into quad and increment by rs_c. -" st1 {v13.s}[0],[x27],x14 \n\t" // Store c54 into quad and increment by rs_c. -" st1 {v13.s}[1],[x27],x14 \n\t" // Store c55 into quad and increment by rs_c. -" st1 {v13.s}[2],[x27],x14 \n\t" // Store c56 into quad and increment by rs_c. -" st1 {v13.s}[3],[x27],x14 \n\t" // Store c57 into quad and increment by rs_c. 
+" mov x5, x19 \n\t" +" \n\t" +" st1 {v8.s}[0],[x5],x14 \n\t" // Store c30 into quad and increment by rs_c. +" st1 {v8.s}[1],[x5],x14 \n\t" // Store c31 into quad and increment by rs_c. +" st1 {v8.s}[2],[x5],x14 \n\t" // Store c32 into quad and increment by rs_c. +" st1 {v8.s}[3],[x5],x14 \n\t" // Store c33 into quad and increment by rs_c. +" st1 {v9.s}[0],[x5],x14 \n\t" // Store c34 into quad and increment by rs_c. +" st1 {v9.s}[1],[x5],x14 \n\t" // Store c35 into quad and increment by rs_c. +" st1 {v9.s}[2],[x5],x14 \n\t" // Store c36 into quad and increment by rs_c. +" st1 {v9.s}[3],[x5],x14 \n\t" // Store c37 into quad and increment by rs_c. +" \n\t" +" mov x5, x20 \n\t" +" \n\t" +" st1 {v10.s}[0],[x5],x14 \n\t" // Store c40 into quad and increment by rs_c. +" st1 {v10.s}[1],[x5],x14 \n\t" // Store c41 into quad and increment by rs_c. +" st1 {v10.s}[2],[x5],x14 \n\t" // Store c42 into quad and increment by rs_c. +" st1 {v10.s}[3],[x5],x14 \n\t" // Store c43 into quad and increment by rs_c. +" st1 {v11.s}[0],[x5],x14 \n\t" // Store c44 into quad and increment by rs_c. +" st1 {v11.s}[1],[x5],x14 \n\t" // Store c45 into quad and increment by rs_c. +" st1 {v11.s}[2],[x5],x14 \n\t" // Store c46 into quad and increment by rs_c. +" st1 {v11.s}[3],[x5],x14 \n\t" // Store c47 into quad and increment by rs_c. +" \n\t" +" mov x5, x21 \n\t" +" \n\t" +" st1 {v12.s}[0],[x5],x14 \n\t" // Store c50 into quad and increment by rs_c. +" st1 {v12.s}[1],[x5],x14 \n\t" // Store c51 into quad and increment by rs_c. +" st1 {v12.s}[2],[x5],x14 \n\t" // Store c52 into quad and increment by rs_c. +" st1 {v12.s}[3],[x5],x14 \n\t" // Store c53 into quad and increment by rs_c. +" st1 {v13.s}[0],[x5],x14 \n\t" // Store c54 into quad and increment by rs_c. +" st1 {v13.s}[1],[x5],x14 \n\t" // Store c55 into quad and increment by rs_c. +" st1 {v13.s}[2],[x5],x14 \n\t" // Store c56 into quad and increment by rs_c. +" st1 {v13.s}[3],[x5],x14 \n\t" // Store c57 into quad and increment by rs_c. " \n\t" " dup v0.4s, wzr \n\t" " dup v1.4s, wzr \n\t" @@ -871,40 +872,40 @@ __asm__ volatile " dup v5.4s, wzr \n\t" " \n\t" " fcmp s7,#0.0 \n\t" -" beq .SBETAZEROGENSTOREDS3 \n\t" // Taking care of the beta==0 case. -" \n\t" -" mov x27, x21 \n\t" -" \n\t" -" ld1 {v0.s}[0],[x27],x14 \n\t" // Load c60 into quad and increment by rs_c. -" ld1 {v0.s}[1],[x27],x14 \n\t" // Load c61 into quad and increment by rs_c. -" ld1 {v0.s}[2],[x27],x14 \n\t" // Load c62 into quad and increment by rs_c. -" ld1 {v0.s}[3],[x27],x14 \n\t" // Load c63 into quad and increment by rs_c. -" ld1 {v1.s}[0],[x27],x14 \n\t" // Load c64 into quad and increment by rs_c. -" ld1 {v1.s}[1],[x27],x14 \n\t" // Load c65 into quad and increment by rs_c. -" ld1 {v1.s}[2],[x27],x14 \n\t" // Load c66 into quad and increment by rs_c. -" ld1 {v1.s}[3],[x27],x14 \n\t" // Load c67 into quad and increment by rs_c. -" \n\t" -" mov x27, x22 \n\t" -" \n\t" -" ld1 {v2.s}[0],[x27],x14 \n\t" // Load c70 into quad and increment by rs_c. -" ld1 {v2.s}[1],[x27],x14 \n\t" // Load c71 into quad and increment by rs_c. -" ld1 {v2.s}[2],[x27],x14 \n\t" // Load c72 into quad and increment by rs_c. -" ld1 {v2.s}[3],[x27],x14 \n\t" // Load c73 into quad and increment by rs_c. -" ld1 {v3.s}[0],[x27],x14 \n\t" // Load c74 into quad and increment by rs_c. -" ld1 {v3.s}[1],[x27],x14 \n\t" // Load c75 into quad and increment by rs_c. -" ld1 {v3.s}[2],[x27],x14 \n\t" // Load c76 into quad and increment by rs_c. -" ld1 {v3.s}[3],[x27],x14 \n\t" // Load c77 into quad and increment by rs_c. 
-" \n\t" -" mov x27, x23 \n\t" -" \n\t" -" ld1 {v4.s}[0],[x27],x14 \n\t" // Load c80 into quad and increment by rs_c. -" ld1 {v4.s}[1],[x27],x14 \n\t" // Load c81 into quad and increment by rs_c. -" ld1 {v4.s}[2],[x27],x14 \n\t" // Load c82 into quad and increment by rs_c. -" ld1 {v4.s}[3],[x27],x14 \n\t" // Load c83 into quad and increment by rs_c. -" ld1 {v5.s}[0],[x27],x14 \n\t" // Load c84 into quad and increment by rs_c. -" ld1 {v5.s}[1],[x27],x14 \n\t" // Load c85 into quad and increment by rs_c. -" ld1 {v5.s}[2],[x27],x14 \n\t" // Load c86 into quad and increment by rs_c. -" ld1 {v5.s}[3],[x27],x14 \n\t" // Load c87 into quad and increment by rs_c. +BEQ(SBETAZEROGENSTOREDS3) // Taking care of the beta==0 case. +" \n\t" +" mov x5, x22 \n\t" +" \n\t" +" ld1 {v0.s}[0],[x5],x14 \n\t" // Load c60 into quad and increment by rs_c. +" ld1 {v0.s}[1],[x5],x14 \n\t" // Load c61 into quad and increment by rs_c. +" ld1 {v0.s}[2],[x5],x14 \n\t" // Load c62 into quad and increment by rs_c. +" ld1 {v0.s}[3],[x5],x14 \n\t" // Load c63 into quad and increment by rs_c. +" ld1 {v1.s}[0],[x5],x14 \n\t" // Load c64 into quad and increment by rs_c. +" ld1 {v1.s}[1],[x5],x14 \n\t" // Load c65 into quad and increment by rs_c. +" ld1 {v1.s}[2],[x5],x14 \n\t" // Load c66 into quad and increment by rs_c. +" ld1 {v1.s}[3],[x5],x14 \n\t" // Load c67 into quad and increment by rs_c. +" \n\t" +" mov x5, x23 \n\t" +" \n\t" +" ld1 {v2.s}[0],[x5],x14 \n\t" // Load c70 into quad and increment by rs_c. +" ld1 {v2.s}[1],[x5],x14 \n\t" // Load c71 into quad and increment by rs_c. +" ld1 {v2.s}[2],[x5],x14 \n\t" // Load c72 into quad and increment by rs_c. +" ld1 {v2.s}[3],[x5],x14 \n\t" // Load c73 into quad and increment by rs_c. +" ld1 {v3.s}[0],[x5],x14 \n\t" // Load c74 into quad and increment by rs_c. +" ld1 {v3.s}[1],[x5],x14 \n\t" // Load c75 into quad and increment by rs_c. +" ld1 {v3.s}[2],[x5],x14 \n\t" // Load c76 into quad and increment by rs_c. +" ld1 {v3.s}[3],[x5],x14 \n\t" // Load c77 into quad and increment by rs_c. +" \n\t" +" mov x5, x24 \n\t" +" \n\t" +" ld1 {v4.s}[0],[x5],x14 \n\t" // Load c80 into quad and increment by rs_c. +" ld1 {v4.s}[1],[x5],x14 \n\t" // Load c81 into quad and increment by rs_c. +" ld1 {v4.s}[2],[x5],x14 \n\t" // Load c82 into quad and increment by rs_c. +" ld1 {v4.s}[3],[x5],x14 \n\t" // Load c83 into quad and increment by rs_c. +" ld1 {v5.s}[0],[x5],x14 \n\t" // Load c84 into quad and increment by rs_c. +" ld1 {v5.s}[1],[x5],x14 \n\t" // Load c85 into quad and increment by rs_c. +" ld1 {v5.s}[2],[x5],x14 \n\t" // Load c86 into quad and increment by rs_c. +" ld1 {v5.s}[3],[x5],x14 \n\t" // Load c87 into quad and increment by rs_c. " \n\t" " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta @@ -913,7 +914,7 @@ __asm__ volatile " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta " \n\t" -" .SBETAZEROGENSTOREDS3: \n\t" +LABEL(SBETAZEROGENSTOREDS3) " \n\t" " fmla v0.4s,v20.4s,v6.s[0] \n\t" // Scale by alpha " fmla v1.4s,v21.4s,v6.s[0] \n\t" // Scale by alpha @@ -922,38 +923,38 @@ __asm__ volatile " fmla v4.4s,v24.4s,v6.s[0] \n\t" // Scale by alpha " fmla v5.4s,v25.4s,v6.s[0] \n\t" // Scale by alpha " \n\t" -" mov x27, x21 \n\t" -" \n\t" -" st1 {v0.s}[0],[x27],x14 \n\t" // Store c60 into quad and increment by rs_c. -" st1 {v0.s}[1],[x27],x14 \n\t" // Store c61 into quad and increment by rs_c. -" st1 {v0.s}[2],[x27],x14 \n\t" // Store c62 into quad and increment by rs_c. 
-" st1 {v0.s}[3],[x27],x14 \n\t" // Store c63 into quad and increment by rs_c. -" st1 {v1.s}[0],[x27],x14 \n\t" // Store c64 into quad and increment by rs_c. -" st1 {v1.s}[1],[x27],x14 \n\t" // Store c65 into quad and increment by rs_c. -" st1 {v1.s}[2],[x27],x14 \n\t" // Store c66 into quad and increment by rs_c. -" st1 {v1.s}[3],[x27],x14 \n\t" // Store c67 into quad and increment by rs_c. -" \n\t" -" mov x27, x22 \n\t" -" \n\t" -" st1 {v2.s}[0],[x27],x14 \n\t" // Store c70 into quad and increment by rs_c. -" st1 {v2.s}[1],[x27],x14 \n\t" // Store c71 into quad and increment by rs_c. -" st1 {v2.s}[2],[x27],x14 \n\t" // Store c72 into quad and increment by rs_c. -" st1 {v2.s}[3],[x27],x14 \n\t" // Store c73 into quad and increment by rs_c. -" st1 {v3.s}[0],[x27],x14 \n\t" // Store c74 into quad and increment by rs_c. -" st1 {v3.s}[1],[x27],x14 \n\t" // Store c75 into quad and increment by rs_c. -" st1 {v3.s}[2],[x27],x14 \n\t" // Store c76 into quad and increment by rs_c. -" st1 {v3.s}[3],[x27],x14 \n\t" // Store c77 into quad and increment by rs_c. -" \n\t" -" mov x27, x23 \n\t" -" \n\t" -" st1 {v4.s}[0],[x27],x14 \n\t" // Store c80 into quad and increment by rs_c. -" st1 {v4.s}[1],[x27],x14 \n\t" // Store c81 into quad and increment by rs_c. -" st1 {v4.s}[2],[x27],x14 \n\t" // Store c82 into quad and increment by rs_c. -" st1 {v4.s}[3],[x27],x14 \n\t" // Store c83 into quad and increment by rs_c. -" st1 {v5.s}[0],[x27],x14 \n\t" // Store c84 into quad and increment by rs_c. -" st1 {v5.s}[1],[x27],x14 \n\t" // Store c85 into quad and increment by rs_c. -" st1 {v5.s}[2],[x27],x14 \n\t" // Store c86 into quad and increment by rs_c. -" st1 {v5.s}[3],[x27],x14 \n\t" // Store c87 into quad and increment by rs_c. +" mov x5, x22 \n\t" +" \n\t" +" st1 {v0.s}[0],[x5],x14 \n\t" // Store c60 into quad and increment by rs_c. +" st1 {v0.s}[1],[x5],x14 \n\t" // Store c61 into quad and increment by rs_c. +" st1 {v0.s}[2],[x5],x14 \n\t" // Store c62 into quad and increment by rs_c. +" st1 {v0.s}[3],[x5],x14 \n\t" // Store c63 into quad and increment by rs_c. +" st1 {v1.s}[0],[x5],x14 \n\t" // Store c64 into quad and increment by rs_c. +" st1 {v1.s}[1],[x5],x14 \n\t" // Store c65 into quad and increment by rs_c. +" st1 {v1.s}[2],[x5],x14 \n\t" // Store c66 into quad and increment by rs_c. +" st1 {v1.s}[3],[x5],x14 \n\t" // Store c67 into quad and increment by rs_c. +" \n\t" +" mov x5, x23 \n\t" +" \n\t" +" st1 {v2.s}[0],[x5],x14 \n\t" // Store c70 into quad and increment by rs_c. +" st1 {v2.s}[1],[x5],x14 \n\t" // Store c71 into quad and increment by rs_c. +" st1 {v2.s}[2],[x5],x14 \n\t" // Store c72 into quad and increment by rs_c. +" st1 {v2.s}[3],[x5],x14 \n\t" // Store c73 into quad and increment by rs_c. +" st1 {v3.s}[0],[x5],x14 \n\t" // Store c74 into quad and increment by rs_c. +" st1 {v3.s}[1],[x5],x14 \n\t" // Store c75 into quad and increment by rs_c. +" st1 {v3.s}[2],[x5],x14 \n\t" // Store c76 into quad and increment by rs_c. +" st1 {v3.s}[3],[x5],x14 \n\t" // Store c77 into quad and increment by rs_c. +" \n\t" +" mov x5, x24 \n\t" +" \n\t" +" st1 {v4.s}[0],[x5],x14 \n\t" // Store c80 into quad and increment by rs_c. +" st1 {v4.s}[1],[x5],x14 \n\t" // Store c81 into quad and increment by rs_c. +" st1 {v4.s}[2],[x5],x14 \n\t" // Store c82 into quad and increment by rs_c. +" st1 {v4.s}[3],[x5],x14 \n\t" // Store c83 into quad and increment by rs_c. +" st1 {v5.s}[0],[x5],x14 \n\t" // Store c84 into quad and increment by rs_c. 
+" st1 {v5.s}[1],[x5],x14 \n\t" // Store c85 into quad and increment by rs_c. +" st1 {v5.s}[2],[x5],x14 \n\t" // Store c86 into quad and increment by rs_c. +" st1 {v5.s}[3],[x5],x14 \n\t" // Store c87 into quad and increment by rs_c. " \n\t" " dup v8.4s, wzr \n\t" " dup v9.4s, wzr \n\t" @@ -963,40 +964,40 @@ __asm__ volatile " dup v13.4s, wzr \n\t" " \n\t" " fcmp s7,#0.0 \n\t" -" beq .SBETAZEROGENSTOREDS4 \n\t" // Taking care of the beta==0 case. -" \n\t" -" mov x27, x24 \n\t" -" \n\t" -" ld1 {v8.s}[0],[x27],x14 \n\t" // Load c90 into quad and increment by rs_c. -" ld1 {v8.s}[1],[x27],x14 \n\t" // Load c91 into quad and increment by rs_c. -" ld1 {v8.s}[2],[x27],x14 \n\t" // Load c92 into quad and increment by rs_c. -" ld1 {v8.s}[3],[x27],x14 \n\t" // Load c93 into quad and increment by rs_c. -" ld1 {v9.s}[0],[x27],x14 \n\t" // Load c94 into quad and increment by rs_c. -" ld1 {v9.s}[1],[x27],x14 \n\t" // Load c95 into quad and increment by rs_c. -" ld1 {v9.s}[2],[x27],x14 \n\t" // Load c96 into quad and increment by rs_c. -" ld1 {v9.s}[3],[x27],x14 \n\t" // Load c97 into quad and increment by rs_c. -" \n\t" -" mov x27, x25 \n\t" -" \n\t" -" ld1 {v10.s}[0],[x27],x14 \n\t" // Load c100 into quad and increment by rs_c. -" ld1 {v10.s}[1],[x27],x14 \n\t" // Load c101 into quad and increment by rs_c. -" ld1 {v10.s}[2],[x27],x14 \n\t" // Load c102 into quad and increment by rs_c. -" ld1 {v10.s}[3],[x27],x14 \n\t" // Load c103 into quad and increment by rs_c. -" ld1 {v11.s}[0],[x27],x14 \n\t" // Load c104 into quad and increment by rs_c. -" ld1 {v11.s}[1],[x27],x14 \n\t" // Load c105 into quad and increment by rs_c. -" ld1 {v11.s}[2],[x27],x14 \n\t" // Load c106 into quad and increment by rs_c. -" ld1 {v11.s}[3],[x27],x14 \n\t" // Load c107 into quad and increment by rs_c. -" \n\t" -" mov x27, x26 \n\t" -" \n\t" -" ld1 {v12.s}[0],[x27],x14 \n\t" // Load c110 into quad and increment by rs_c. -" ld1 {v12.s}[1],[x27],x14 \n\t" // Load c111 into quad and increment by rs_c. -" ld1 {v12.s}[2],[x27],x14 \n\t" // Load c112 into quad and increment by rs_c. -" ld1 {v12.s}[3],[x27],x14 \n\t" // Load c113 into quad and increment by rs_c. -" ld1 {v13.s}[0],[x27],x14 \n\t" // Load c114 into quad and increment by rs_c. -" ld1 {v13.s}[1],[x27],x14 \n\t" // Load c115 into quad and increment by rs_c. -" ld1 {v13.s}[2],[x27],x14 \n\t" // Load c116 into quad and increment by rs_c. -" ld1 {v13.s}[3],[x27],x14 \n\t" // Load c117 into quad and increment by rs_c. +BEQ(SBETAZEROGENSTOREDS4) // Taking care of the beta==0 case. +" \n\t" +" mov x5, x25 \n\t" +" \n\t" +" ld1 {v8.s}[0],[x5],x14 \n\t" // Load c90 into quad and increment by rs_c. +" ld1 {v8.s}[1],[x5],x14 \n\t" // Load c91 into quad and increment by rs_c. +" ld1 {v8.s}[2],[x5],x14 \n\t" // Load c92 into quad and increment by rs_c. +" ld1 {v8.s}[3],[x5],x14 \n\t" // Load c93 into quad and increment by rs_c. +" ld1 {v9.s}[0],[x5],x14 \n\t" // Load c94 into quad and increment by rs_c. +" ld1 {v9.s}[1],[x5],x14 \n\t" // Load c95 into quad and increment by rs_c. +" ld1 {v9.s}[2],[x5],x14 \n\t" // Load c96 into quad and increment by rs_c. +" ld1 {v9.s}[3],[x5],x14 \n\t" // Load c97 into quad and increment by rs_c. +" \n\t" +" mov x5, x26 \n\t" +" \n\t" +" ld1 {v10.s}[0],[x5],x14 \n\t" // Load c100 into quad and increment by rs_c. +" ld1 {v10.s}[1],[x5],x14 \n\t" // Load c101 into quad and increment by rs_c. +" ld1 {v10.s}[2],[x5],x14 \n\t" // Load c102 into quad and increment by rs_c. +" ld1 {v10.s}[3],[x5],x14 \n\t" // Load c103 into quad and increment by rs_c. 
+" ld1 {v11.s}[0],[x5],x14 \n\t" // Load c104 into quad and increment by rs_c. +" ld1 {v11.s}[1],[x5],x14 \n\t" // Load c105 into quad and increment by rs_c. +" ld1 {v11.s}[2],[x5],x14 \n\t" // Load c106 into quad and increment by rs_c. +" ld1 {v11.s}[3],[x5],x14 \n\t" // Load c107 into quad and increment by rs_c. +" \n\t" +" mov x5, x27 \n\t" +" \n\t" +" ld1 {v12.s}[0],[x5],x14 \n\t" // Load c110 into quad and increment by rs_c. +" ld1 {v12.s}[1],[x5],x14 \n\t" // Load c111 into quad and increment by rs_c. +" ld1 {v12.s}[2],[x5],x14 \n\t" // Load c112 into quad and increment by rs_c. +" ld1 {v12.s}[3],[x5],x14 \n\t" // Load c113 into quad and increment by rs_c. +" ld1 {v13.s}[0],[x5],x14 \n\t" // Load c114 into quad and increment by rs_c. +" ld1 {v13.s}[1],[x5],x14 \n\t" // Load c115 into quad and increment by rs_c. +" ld1 {v13.s}[2],[x5],x14 \n\t" // Load c116 into quad and increment by rs_c. +" ld1 {v13.s}[3],[x5],x14 \n\t" // Load c117 into quad and increment by rs_c. " \n\t" " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta @@ -1005,10 +1006,10 @@ __asm__ volatile " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta " \n\t" -" .SBETAZEROGENSTOREDS4: \n\t" +LABEL(SBETAZEROGENSTOREDS4) " \n\t" -" prfm pldl2keep,[x3] \n\t" -" prfm pldl2keep,[x4] \n\t" +" prfm pldl2keep,[x0] \n\t" +" prfm pldl2keep,[x1] \n\t" " \n\t" " fmla v8.4s, v26.4s,v6.s[0] \n\t" // Scale by alpha " fmla v9.4s, v27.4s,v6.s[0] \n\t" // Scale by alpha @@ -1017,40 +1018,40 @@ __asm__ volatile " fmla v12.4s,v30.4s,v6.s[0] \n\t" // Scale by alpha " fmla v13.4s,v31.4s,v6.s[0] \n\t" // Scale by alpha " \n\t" -" mov x27, x24 \n\t" -" \n\t" -" st1 {v8.s}[0],[x27],x14 \n\t" // Store c90 into quad and increment by rs_c. -" st1 {v8.s}[1],[x27],x14 \n\t" // Store c91 into quad and increment by rs_c. -" st1 {v8.s}[2],[x27],x14 \n\t" // Store c92 into quad and increment by rs_c. -" st1 {v8.s}[3],[x27],x14 \n\t" // Store c93 into quad and increment by rs_c. -" st1 {v9.s}[0],[x27],x14 \n\t" // Store c94 into quad and increment by rs_c. -" st1 {v9.s}[1],[x27],x14 \n\t" // Store c95 into quad and increment by rs_c. -" st1 {v9.s}[2],[x27],x14 \n\t" // Store c96 into quad and increment by rs_c. -" st1 {v9.s}[3],[x27],x14 \n\t" // Store c97 into quad and increment by rs_c. -" \n\t" -" mov x27, x25 \n\t" -" \n\t" -" st1 {v10.s}[0],[x27],x14 \n\t" // Store c100 into quad and increment by rs_c. -" st1 {v10.s}[1],[x27],x14 \n\t" // Store c101 into quad and increment by rs_c. -" st1 {v10.s}[2],[x27],x14 \n\t" // Store c102 into quad and increment by rs_c. -" st1 {v10.s}[3],[x27],x14 \n\t" // Store c103 into quad and increment by rs_c. -" st1 {v11.s}[0],[x27],x14 \n\t" // Store c104 into quad and increment by rs_c. -" st1 {v11.s}[1],[x27],x14 \n\t" // Store c105 into quad and increment by rs_c. -" st1 {v11.s}[2],[x27],x14 \n\t" // Store c106 into quad and increment by rs_c. -" st1 {v11.s}[3],[x27],x14 \n\t" // Store c107 into quad and increment by rs_c. -" \n\t" -" mov x27, x26 \n\t" -" \n\t" -" st1 {v12.s}[0],[x27],x14 \n\t" // Store c110 into quad and increment by rs_c. -" st1 {v12.s}[1],[x27],x14 \n\t" // Store c111 into quad and increment by rs_c. -" st1 {v12.s}[2],[x27],x14 \n\t" // Store c112 into quad and increment by rs_c. -" st1 {v12.s}[3],[x27],x14 \n\t" // Store c113 into quad and increment by rs_c. -" st1 {v13.s}[0],[x27],x14 \n\t" // Store c114 into quad and increment by rs_c. 
-" st1 {v13.s}[1],[x27],x14 \n\t" // Store c115 into quad and increment by rs_c. -" st1 {v13.s}[2],[x27],x14 \n\t" // Store c116 into quad and increment by rs_c. -" st1 {v13.s}[3],[x27],x14 \n\t" // Store c147 into quad and increment by rs_c. -" \n\t" -" .SEND: \n\t" // Done! +" mov x5, x25 \n\t" +" \n\t" +" st1 {v8.s}[0],[x5],x14 \n\t" // Store c90 into quad and increment by rs_c. +" st1 {v8.s}[1],[x5],x14 \n\t" // Store c91 into quad and increment by rs_c. +" st1 {v8.s}[2],[x5],x14 \n\t" // Store c92 into quad and increment by rs_c. +" st1 {v8.s}[3],[x5],x14 \n\t" // Store c93 into quad and increment by rs_c. +" st1 {v9.s}[0],[x5],x14 \n\t" // Store c94 into quad and increment by rs_c. +" st1 {v9.s}[1],[x5],x14 \n\t" // Store c95 into quad and increment by rs_c. +" st1 {v9.s}[2],[x5],x14 \n\t" // Store c96 into quad and increment by rs_c. +" st1 {v9.s}[3],[x5],x14 \n\t" // Store c97 into quad and increment by rs_c. +" \n\t" +" mov x5, x26 \n\t" +" \n\t" +" st1 {v10.s}[0],[x5],x14 \n\t" // Store c100 into quad and increment by rs_c. +" st1 {v10.s}[1],[x5],x14 \n\t" // Store c101 into quad and increment by rs_c. +" st1 {v10.s}[2],[x5],x14 \n\t" // Store c102 into quad and increment by rs_c. +" st1 {v10.s}[3],[x5],x14 \n\t" // Store c103 into quad and increment by rs_c. +" st1 {v11.s}[0],[x5],x14 \n\t" // Store c104 into quad and increment by rs_c. +" st1 {v11.s}[1],[x5],x14 \n\t" // Store c105 into quad and increment by rs_c. +" st1 {v11.s}[2],[x5],x14 \n\t" // Store c106 into quad and increment by rs_c. +" st1 {v11.s}[3],[x5],x14 \n\t" // Store c107 into quad and increment by rs_c. +" \n\t" +" mov x5, x27 \n\t" +" \n\t" +" st1 {v12.s}[0],[x5],x14 \n\t" // Store c110 into quad and increment by rs_c. +" st1 {v12.s}[1],[x5],x14 \n\t" // Store c111 into quad and increment by rs_c. +" st1 {v12.s}[2],[x5],x14 \n\t" // Store c112 into quad and increment by rs_c. +" st1 {v12.s}[3],[x5],x14 \n\t" // Store c113 into quad and increment by rs_c. +" st1 {v13.s}[0],[x5],x14 \n\t" // Store c114 into quad and increment by rs_c. +" st1 {v13.s}[1],[x5],x14 \n\t" // Store c115 into quad and increment by rs_c. +" st1 {v13.s}[2],[x5],x14 \n\t" // Store c116 into quad and increment by rs_c. +" st1 {v13.s}[3],[x5],x14 \n\t" // Store c147 into quad and increment by rs_c. +" \n\t" +LABEL(SEND) // Done! " \n\t" :// output operands (none) :// input operands @@ -1066,13 +1067,11 @@ __asm__ volatile [a_next] "m" (a_next), // 9 [b_next] "m" (b_next) // 10 :// Register clobber list - "x0", "x1", "x2","x3","x4", - "x5", "x6", "x7", "x8", - "x9", "x10","x11","x12", - "x13","x14","x15", - "x16","x17","x18","x19", - "x20","x21","x22","x23", - "x24","x25","x26","x27", + "x0", "x1", "x2", + "x5", "x6", "x10","x14", + "x16","x17","x19","x20", + "x21","x22","x23","x24", + "x25","x26","x27", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10","v11", @@ -1134,20 +1133,14 @@ __asm__ volatile " ldr x1,%[baddr] \n\t" // Load address of B " ldr x2,%[caddr] \n\t" // Load address of C " \n\t" -" ldr x3,%[a_next] \n\t" // Move pointer -" ldr x4,%[b_next] \n\t" // Move pointer -" \n\t" " ldr x5,%[k_iter] \n\t" // Init guard (k_iter) " ldr x6,%[k_left] \n\t" // Init guard (k_iter) " \n\t" -" ldr x7,%[alpha] \n\t" // Alpha address -" ldr x8,%[beta] \n\t" // Beta address -" \n\t" -" ldr x9,%[cs_c] \n\t" // Load cs_c -" lsl x10,x9,#3 \n\t" // cs_c * sizeof(double) +" ldr x10,%[cs_c] \n\t" // Load cs_c +" lsl x10,x10,#3 \n\t" // cs_c * sizeof(double) " \n\t" -" ldr x13,%[rs_c] \n\t" // Load rs_c. 
-" lsl x14,x13,#3 \n\t" // rs_c * sizeof(double). +" ldr x14,%[rs_c] \n\t" // Load rs_c. +" lsl x14,x14,#3 \n\t" // rs_c * sizeof(double). " \n\t" " add x20,x2,x10 \n\t" //Load address Column 1 of C " add x21,x20,x10 \n\t" //Load address Column 2 of C @@ -1203,7 +1196,7 @@ __asm__ volatile " \n\t" " \n\t" " cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. -" beq .DCONSIDERKLEFT \n\t" +BEQ(DCONSIDERKLEFT) " \n\t" " ldr q0, [x0] \n\t" // Load a " ldr q1, [x0, #16] \n\t" @@ -1218,9 +1211,9 @@ __asm__ volatile " add x1, x1, #64 \n\t" //update address of B " \n\t" " cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. -" beq .DLASTITER \n\t" // (as loop is do-while-like). +BEQ(DLASTITER) // (as loop is do-while-like). " \n\t" -" DLOOP: \n\t" // Body +LABEL(DLOOP) // Body " \n\t" " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate " prfm PLDL1KEEP, [x1, #448] \n\t" //512-64=448 @@ -1394,9 +1387,9 @@ __asm__ volatile " \n\t" " sub x5,x5,1 \n\t" // i-=1 " cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. -" bne DLOOP \n\t" +BNE(DLOOP) " \n\t" -".DLASTITER: \n\t" +LABEL(DLASTITER) " \n\t" " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate @@ -1554,11 +1547,11 @@ __asm__ volatile " \n\t" //End it 4 " add x0, x0, #144 \n\t" " \n\t" -" .DCONSIDERKLEFT: \n\t" +LABEL(DCONSIDERKLEFT) " cmp x6,0 \n\t" // If k_left == 0, we are done. -" beq .DPOSTACCUM \n\t" // else, we enter the k_left loop. +BEQ(DPOSTACCUM) // else, we enter the k_left loop. " \n\t" -".DLOOPKLEFT: \n\t" +LABEL(DLOOPKLEFT) " \n\t" " ldr q0, [x0],#16 \n\t" " ldr q1, [x0],#16 \n\t" // Load a @@ -1605,17 +1598,23 @@ __asm__ volatile " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate " \n\t" " cmp x6,0 \n\t" // Iterate again. -" bne .DLOOPKLEFT \n\t" // if i!=0. +BNE(DLOOPKLEFT) // if i!=0. +" \n\t" +LABEL(DPOSTACCUM) " \n\t" -" .DPOSTACCUM: \n\t" +" ldr x0,%[alpha] \n\t" // Alpha address +" ldr x1,%[beta] \n\t" // Beta address +" \n\t" +" ld1r {v6.2d},[x0] \n\t" // Load alpha. +" ld1r {v7.2d},[x1] \n\t" // Load beta " \n\t" -" ld1r {v6.2d},[x7] \n\t" // Load alpha. -" ld1r {v7.2d},[x8] \n\t" // Load beta +" ldr x0,%[a_next] \n\t" // Next A address for later use. +" ldr x1,%[b_next] \n\t" // Next B address for later use. " \n\t" -" cmp x13,#1 \n\t" // If rs_c != 1 (column-major) -" bne .DGENSTORED \n\t" +" cmp x14,#8 \n\t" // If rs_c != 1 (column-major) +BNE(DGENSTORED) " \n\t" -" .DCOLSTORED: \n\t" // C is column-major. +LABEL(DCOLSTORED) // C is column-major. " \n\t" " dup v0.2d, xzr \n\t" " dup v1.2d, xzr \n\t" @@ -1625,7 +1624,7 @@ __asm__ volatile " dup v5.2d, xzr \n\t" " \n\t" " fcmp d7,#0.0 \n\t" -" beq .DBETAZEROCOLSTOREDS1 \n\t" // Taking care of the beta==0 case. +BEQ(DBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. " \n\t" " ldr q0, [x2] \n\t" //Load column 0 of C " ldr q1, [x2, #16] \n\t" @@ -1642,7 +1641,7 @@ __asm__ volatile " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta " \n\t" -" .DBETAZEROCOLSTOREDS1: \n\t" +LABEL(DBETAZEROCOLSTOREDS1) " \n\t" " fmla v0.2d,v8.2d,v6.d[0] \n\t" // Scale by alpha " fmla v1.2d,v9.2d,v6.d[0] \n\t" // Scale by alpha @@ -1667,7 +1666,7 @@ __asm__ volatile " dup v13.2d, xzr \n\t" " \n\t" " fcmp d7,#0.0 \n\t" -" beq .DBETAZEROCOLSTOREDS2 \n\t" // Taking care of the beta==0 case. +BEQ(DBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. 
" \n\t" " ldr q8, [x21] \n\t" //Load column 2 of C " ldr q9, [x21, #16] \n\t" @@ -1684,7 +1683,7 @@ __asm__ volatile " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta " \n\t" -" .DBETAZEROCOLSTOREDS2: \n\t" +LABEL(DBETAZEROCOLSTOREDS2) " \n\t" " fmla v8.2d, v14.2d,v6.d[0] \n\t" // Scale by alpha " fmla v9.2d, v15.2d,v6.d[0] \n\t" // Scale by alpha @@ -1709,7 +1708,7 @@ __asm__ volatile " dup v5.2d, xzr \n\t" " \n\t" " fcmp d7,#0.0 \n\t" -" beq .DBETAZEROCOLSTOREDS3 \n\t" // Taking care of the beta==0 case. +BEQ(DBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. " \n\t" " ldr q0, [x23] \n\t" //Load column 4 of C " ldr q1, [x23, #16] \n\t" @@ -1726,7 +1725,7 @@ __asm__ volatile " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta " \n\t" -" .DBETAZEROCOLSTOREDS3: \n\t" +LABEL(DBETAZEROCOLSTOREDS3) " \n\t" " fmla v0.2d,v20.2d,v6.d[0] \n\t" // Scale by alpha " fmla v1.2d,v21.2d,v6.d[0] \n\t" // Scale by alpha @@ -1751,7 +1750,7 @@ __asm__ volatile " dup v13.2d, xzr \n\t" " \n\t" " fcmp d7,#0.0 \n\t" -" beq .DBETAZEROCOLSTOREDS4 \n\t" // Taking care of the beta==0 case. +BEQ(DBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. " \n\t" " ldr q8, [x25] \n\t" //Load column 6 of C " ldr q9, [x25, #16] \n\t" @@ -1768,10 +1767,10 @@ __asm__ volatile " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta " \n\t" -" .DBETAZEROCOLSTOREDS4: \n\t" +LABEL(DBETAZEROCOLSTOREDS4) " \n\t" -" prfm pldl2keep,[x3] \n\t" -" prfm pldl2keep,[x4] \n\t" +" prfm pldl2keep,[x0] \n\t" +" prfm pldl2keep,[x1] \n\t" " \n\t" " fmla v8.2d, v26.2d,v6.d[0] \n\t" // Scale by alpha " fmla v9.2d, v27.2d,v6.d[0] \n\t" // Scale by alpha @@ -1788,9 +1787,9 @@ __asm__ volatile " str q12, [x26, #16] \n\t" " str q13, [x26, #32] \n\t" " \n\t" -" b .DEND \n\t" +BRANCH(DEND) " \n\t" -" .DGENSTORED: \n\t" // C is general-stride stored. +LABEL(DGENSTORED) // C is general-stride stored. " \n\t" " dup v0.2d, xzr \n\t" " dup v1.2d, xzr \n\t" @@ -1800,7 +1799,7 @@ __asm__ volatile " dup v5.2d, xzr \n\t" " \n\t" " fcmp d7,#0.0 \n\t" -" beq .DBETAZEROGENSTOREDS1 \n\t" // Taking care of the beta==0 case. +BEQ(DBETAZEROGENSTOREDS1) // Taking care of the beta==0 case. " \n\t" " mov x27, x2 \n\t" " \n\t" // Load address of C. @@ -1827,7 +1826,7 @@ __asm__ volatile " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta " \n\t" -" .DBETAZEROGENSTOREDS1: \n\t" +LABEL(DBETAZEROGENSTOREDS1) " \n\t" " fmla v0.2d,v8.2d,v6.d[0] \n\t" // Scale by alpha " fmla v1.2d,v9.2d,v6.d[0] \n\t" // Scale by alpha @@ -1862,7 +1861,7 @@ __asm__ volatile " dup v13.2d, xzr \n\t" " \n\t" " fcmp d7,#0.0 \n\t" -" beq .DBETAZEROGENSTOREDS2 \n\t" // Taking care of the beta==0 case. +BEQ(DBETAZEROGENSTOREDS2) // Taking care of the beta==0 case. " \n\t" " mov x27, x21 \n\t" // Load address of C. " \n\t" @@ -1889,7 +1888,7 @@ __asm__ volatile " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta " \n\t" -" .DBETAZEROGENSTOREDS2: \n\t" +LABEL(DBETAZEROGENSTOREDS2) " \n\t" " fmla v8.2d, v14.2d,v6.d[0] \n\t" // Scale by alpha " fmla v9.2d, v15.2d,v6.d[0] \n\t" // Scale by alpha @@ -1924,7 +1923,7 @@ __asm__ volatile " dup v5.2d, xzr \n\t" " \n\t" " fcmp d7,#0.0 \n\t" -" beq .DBETAZEROGENSTOREDS3 \n\t" // Taking care of the beta==0 case. +BEQ(DBETAZEROGENSTOREDS3) // Taking care of the beta==0 case. 
" \n\t" " mov x27, x23 \n\t" // Load address of C. " \n\t" @@ -1951,7 +1950,7 @@ __asm__ volatile " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta " \n\t" -" .DBETAZEROGENSTOREDS3: \n\t" +LABEL(DBETAZEROGENSTOREDS3) " \n\t" " fmla v0.2d,v20.2d,v6.d[0] \n\t" // Scale by alpha " fmla v1.2d,v21.2d,v6.d[0] \n\t" // Scale by alpha @@ -1986,7 +1985,7 @@ __asm__ volatile " dup v13.2d, xzr \n\t" " \n\t" " fcmp d7,#0.0 \n\t" -" beq .DBETAZEROGENSTOREDS4 \n\t" // Taking care of the beta==0 case. +BEQ(DBETAZEROGENSTOREDS4) // Taking care of the beta==0 case. " \n\t" " mov x27, x25 \n\t" " \n\t" @@ -2013,10 +2012,10 @@ __asm__ volatile " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta " \n\t" -" .DBETAZEROGENSTOREDS4: \n\t" +LABEL(DBETAZEROGENSTOREDS4) " \n\t" -" prfm pldl2keep,[x3] \n\t" -" prfm pldl2keep,[x4] \n\t" +" prfm pldl2keep,[x0] \n\t" +" prfm pldl2keep,[x1] \n\t" " \n\t" " fmla v8.2d, v26.2d,v6.d[0] \n\t" // Scale by alpha " fmla v9.2d, v27.2d,v6.d[0] \n\t" // Scale by alpha @@ -2043,7 +2042,7 @@ __asm__ volatile " st1 {v13.d}[0],[x27],x14 \n\t" // Store c74 into quad and increment by rs_c. " st1 {v13.d}[1],[x27],x14 \n\t" // Store c75 into quad and increment by rs_c. " \n\t" -" .DEND: \n\t" // Done! +LABEL(DEND) // Done! " \n\t" :// output operands (none) :// input operands @@ -2059,12 +2058,10 @@ __asm__ volatile [a_next] "m" (a_next), // 8 [b_next] "m" (b_next) // 9 :// Register clobber list - "x0","x1","x2","x3", - "x4","x5","x6", - "x7","x8","x9", - "x10","x11","x12","x13","x14","x16","x17", - "x20","x21","x22","x23","x24","x25","x26", - "x27", + "x0","x1","x2", + "x5","x6","x10", + "x14","x16","x17", + "x20","x21","x22","x23","x24","x25","x26","x27", "v0","v1","v2", "v3","v4","v5", "v6","v7","v8", diff --git a/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c b/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c index 403aaaaeef..8d0060b2f5 100644 --- a/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c +++ b/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -760,7 +761,8 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm7", "ymm8", + "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -1857,7 +1859,8 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", - "memory" + "xmm0", "xmm2", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) } @@ -2530,7 +2533,8 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", - "memory" + "xmm0", "xmm2", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) } diff --git a/kernels/haswell/1m/CMakeLists.txt b/kernels/haswell/1m/CMakeLists.txt index 56abd13aec..9130e97f15 100644 --- a/kernels/haswell/1m/CMakeLists.txt +++ b/kernels/haswell/1m/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved.## add_library(haswell_1m OBJECT diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c index 78e76589dc..255759aab4 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -107,7 +107,7 @@ void bli_cpackm_haswell_asm_3xk if ( cdim0 == mnr && !gs && !conja && unitk ) { begin_asm() - + mov(var(a), rax) // load address of a. mov(var(inca), r8) // load inca @@ -122,14 +122,14 @@ void bli_cpackm_haswell_asm_3xk mov(var(one), rdx) // load address of 1.0 constant vbroadcastss(mem(rdx, 0), ymm1) // load 1.0 and duplicate vxorps(ymm0, ymm0, ymm0) // set ymm0 to 0.0. - + mov(var(kappa), rcx) // load address of kappa vbroadcastss(mem(rcx, 0), ymm10) // load kappa_r and duplicate vbroadcastss(mem(rcx, 4), ymm11) // load kappa_i and duplicate - + // now branch on kappa == 1.0 - + vucomiss(xmm1, xmm10) // set ZF if kappa_r == 1.0. sete(r12b) // r12b = ( ZF == 1 ? 1 : 0 ); vucomiss(xmm0, xmm11) // set ZF if kappa_i == 0.0. @@ -143,7 +143,7 @@ void bli_cpackm_haswell_asm_3xk cmp(imm(8), r8) // set ZF if (8*inca) == 8. jz(.CCOLNONU) // jump to column storage case - + // -- kappa non-unit, row storage on A ------------------------------------- label(.CROWNONU) @@ -156,7 +156,7 @@ void bli_cpackm_haswell_asm_3xk label(.CCOLNONU) jmp(.CDONE) // jump to end. 
- + @@ -167,7 +167,7 @@ void bli_cpackm_haswell_asm_3xk // -- kappa unit, row storage on A ----------------------------------------- - + label(.CROWUNIT) //lea(mem(r8, r8, 2), r12) // r12 = 3*inca @@ -251,7 +251,7 @@ void bli_cpackm_haswell_asm_3xk // -- kappa unit, column storage on A -------------------------------------- label(.CCOLUNIT) - + lea(mem(r10, r10, 2), r13) // r13 = 3*lda mov(var(k_iter), rsi) // i = k_iter; @@ -315,8 +315,8 @@ void bli_cpackm_haswell_asm_3xk label(.CDONE) - - + + end_asm( : // output operands (none) @@ -372,7 +372,7 @@ void bli_cpackm_haswell_asm_3xk ( m_edge, n_edge, - p_edge, 1, ldp + p_edge, 1, ldp ); } } @@ -392,7 +392,7 @@ void bli_cpackm_haswell_asm_3xk ( m_edge, n_edge, - p_edge, 1, ldp + p_edge, 1, ldp ); } } diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c index 61ace6945d..39939bf407 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -107,7 +107,7 @@ void bli_cpackm_haswell_asm_8xk if ( cdim0 == mnr && !gs && !conja && unitk ) { begin_asm() - + mov(var(a), rax) // load address of a. mov(var(inca), r8) // load inca @@ -122,14 +122,14 @@ void bli_cpackm_haswell_asm_8xk mov(var(one), rdx) // load address of 1.0 constant vbroadcastss(mem(rdx, 0), ymm1) // load 1.0 and duplicate vxorps(ymm0, ymm0, ymm0) // set ymm0 to 0.0. - + mov(var(kappa), rcx) // load address of kappa vbroadcastss(mem(rcx, 0), ymm10) // load kappa_r and duplicate vbroadcastss(mem(rcx, 4), ymm11) // load kappa_i and duplicate - + // now branch on kappa == 1.0 - + vucomiss(xmm1, xmm10) // set ZF if kappa_r == 1.0. sete(r12b) // r12b = ( ZF == 1 ? 1 : 0 ); vucomiss(xmm0, xmm11) // set ZF if kappa_i == 0.0. @@ -143,7 +143,7 @@ void bli_cpackm_haswell_asm_8xk cmp(imm(8), r8) // set ZF if (8*inca) == 8. jz(.CCOLNONU) // jump to column storage case - + // -- kappa non-unit, row storage on A ------------------------------------- label(.CROWNONU) @@ -156,7 +156,7 @@ void bli_cpackm_haswell_asm_8xk label(.CCOLNONU) jmp(.CDONE) // jump to end. - + @@ -167,7 +167,7 @@ void bli_cpackm_haswell_asm_8xk // -- kappa unit, row storage on A ----------------------------------------- - + label(.CROWUNIT) lea(mem(r8, r8, 2), r12) // r12 = 3*inca @@ -271,7 +271,7 @@ void bli_cpackm_haswell_asm_8xk // -- kappa unit, column storage on A -------------------------------------- label(.CCOLUNIT) - + lea(mem(r10, r10, 2), r13) // r13 = 3*lda mov(var(k_iter), rsi) // i = k_iter; @@ -335,8 +335,8 @@ void bli_cpackm_haswell_asm_8xk label(.CDONE) - - + + end_asm( : // output operands (none) @@ -392,7 +392,7 @@ void bli_cpackm_haswell_asm_8xk ( m_edge, n_edge, - p_edge, 1, ldp + p_edge, 1, ldp ); } } @@ -410,7 +410,7 @@ void bli_cpackm_haswell_asm_8xk ( m_edge, n_edge, - p_edge, 1, ldp + p_edge, 1, ldp ); } } diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_d6xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_d6xk.c index e2982dbfeb..f45e24a64e 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_d6xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_d6xk.c @@ -5,7 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c index e3b00a71e7..b52c89faa1 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_s16xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_s16xk.c index b049fcdb5c..1282c5ae17 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_s16xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_s16xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_s6xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_s6xk.c index c05c36b66f..d5d8f52998 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_s6xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_s6xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_z3xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_z3xk.c index cb025c1f01..2d3c243c08 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_z3xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_z3xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -107,7 +107,7 @@ void bli_zpackm_haswell_asm_3xk if ( cdim0 == mnr && !gs && !conja && unitk ) { begin_asm() - + mov(var(a), rax) // load address of a. mov(var(inca), r8) // load inca @@ -124,14 +124,14 @@ void bli_zpackm_haswell_asm_3xk mov(var(one), rdx) // load address of 1.0 constant vbroadcastsd(mem(rdx, 0), ymm1) // load 1.0 and duplicate vxorpd(ymm0, ymm0, ymm0) // set ymm0 to 0.0. - + mov(var(kappa), rcx) // load address of kappa vbroadcastsd(mem(rcx, 0), ymm10) // load kappa_r and duplicate vbroadcastsd(mem(rcx, 8), ymm11) // load kappa_i and duplicate - + // now branch on kappa == 1.0 - + vucomisd(xmm1, xmm10) // set ZF if kappa_r == 1.0. 
sete(r12b) // r12b = ( ZF == 1 ? 1 : 0 ); vucomisd(xmm0, xmm11) // set ZF if kappa_i == 0.0. @@ -145,7 +145,7 @@ void bli_zpackm_haswell_asm_3xk cmp(imm(16), r8) // set ZF if (16*inca) == 16. jz(.ZCOLNONU) // jump to column storage case - + // -- kappa non-unit, row storage on A ------------------------------------- label(.ZROWNONU) @@ -158,7 +158,7 @@ void bli_zpackm_haswell_asm_3xk label(.ZCOLNONU) jmp(.ZDONE) // jump to end. - + @@ -169,7 +169,7 @@ void bli_zpackm_haswell_asm_3xk // -- kappa unit, row storage on A ----------------------------------------- - + label(.ZROWUNIT) //lea(mem(r8, r8, 2), r12) // r12 = 3*inca @@ -257,7 +257,7 @@ void bli_zpackm_haswell_asm_3xk // -- kappa unit, column storage on A -------------------------------------- label(.ZCOLUNIT) - + lea(mem(r10, r10, 2), r13) // r13 = 3*lda mov(var(k_iter), rsi) // i = k_iter; @@ -321,8 +321,8 @@ void bli_zpackm_haswell_asm_3xk label(.ZDONE) - - + + end_asm( : // output operands (none) @@ -378,7 +378,7 @@ void bli_zpackm_haswell_asm_3xk ( m_edge, n_edge, - p_edge, 1, ldp + p_edge, 1, ldp ); } } @@ -396,7 +396,7 @@ void bli_zpackm_haswell_asm_3xk ( m_edge, n_edge, - p_edge, 1, ldp + p_edge, 1, ldp ); } } diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_z4xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_z4xk.c index e407fedf9f..d663ccc9bc 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_z4xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_z4xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -107,7 +107,7 @@ void bli_zpackm_haswell_asm_4xk if ( cdim0 == mnr && !gs && !conja && unitk ) { begin_asm() - + mov(var(a), rax) // load address of a. mov(var(inca), r8) // load inca @@ -128,10 +128,10 @@ void bli_zpackm_haswell_asm_4xk mov(var(kappa), rcx) // load address of kappa vbroadcastsd(mem(rcx, 0), ymm10) // load kappa_r and duplicate vbroadcastsd(mem(rcx, 8), ymm11) // load kappa_i and duplicate - + // now branch on kappa == 1.0 - + vucomisd(xmm1, xmm10) // set ZF if kappa_r == 1.0. sete(r12b) // r12b = ( ZF == 1 ? 1 : 0 ); vucomisd(xmm0, xmm11) // set ZF if kappa_i == 0.0. @@ -145,7 +145,7 @@ void bli_zpackm_haswell_asm_4xk cmp(imm(16), r8) // set ZF if (16*inca) == 16. jz(.ZCOLNONU) // jump to column storage case - + // -- kappa non-unit, row storage on A ------------------------------------- label(.ZROWNONU) @@ -158,7 +158,7 @@ void bli_zpackm_haswell_asm_4xk label(.ZCOLNONU) jmp(.ZDONE) // jump to end. 
- + @@ -169,7 +169,7 @@ void bli_zpackm_haswell_asm_4xk // -- kappa unit, row storage on A ----------------------------------------- - + label(.ZROWUNIT) lea(mem(r8, r8, 2), r12) // r12 = 3*inca @@ -267,7 +267,7 @@ void bli_zpackm_haswell_asm_4xk // -- kappa unit, column storage on A -------------------------------------- label(.ZCOLUNIT) - + lea(mem(r10, r10, 2), r13) // r13 = 3*lda mov(var(k_iter), rsi) // i = k_iter; @@ -331,8 +331,8 @@ void bli_zpackm_haswell_asm_4xk label(.ZDONE) - - + + end_asm( : // output operands (none) @@ -388,7 +388,7 @@ void bli_zpackm_haswell_asm_4xk ( m_edge, n_edge, - p_edge, 1, ldp + p_edge, 1, ldp ); } } @@ -406,7 +406,7 @@ void bli_zpackm_haswell_asm_4xk ( m_edge, n_edge, - p_edge, 1, ldp + p_edge, 1, ldp ); } } diff --git a/kernels/haswell/3/CMakeLists.txt b/kernels/haswell/3/CMakeLists.txt deleted file mode 100644 index a42bdadf83..0000000000 --- a/kernels/haswell/3/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.## - -add_library(haswell_3 - OBJECT - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_haswell_asm_d6x8.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_haswell_asm_d8x6.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm_l_haswell_asm_d6x8.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm_u_haswell_asm_d6x8.c - ) - -target_compile_options(haswell_3 PRIVATE /arch:AVX2) -if(BUILD_SHARED_LIBS) - target_compile_definitions(haswell_3 PUBLIC -DBLIS_IS_BUILDING_LIBRARY) -endif() - -add_subdirectory(sup) diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c index f0a8fe34c3..7a3478cb29 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -101,7 +101,7 @@ void bli_sgemm_haswell_asm_6x16 uint64_t cs_c = cs_c0; begin_asm() - + //vzeroall() // zero all xmm/ymm registers. vxorps( ymm4, ymm4, ymm4) vmovaps( ymm4, ymm5) @@ -115,21 +115,21 @@ void bli_sgemm_haswell_asm_6x16 vmovaps( ymm4, ymm13) vmovaps( ymm4, ymm14) vmovaps( ymm4, ymm15) - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rbx) // initialize loop by pre-loading vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) - + lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c @@ -138,19 +138,19 @@ void bli_sgemm_haswell_asm_6x16 prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- - + + label(.SLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 64*4)) @@ -160,24 +160,24 @@ void bli_sgemm_haswell_asm_6x16 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 2*4), ymm2) vbroadcastss(mem(rax, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 4*4), ymm2) vbroadcastss(mem(rax, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, -2*32), ymm0) vmovaps(mem(rbx, -1*32), ymm1) - + // iteration 1 prefetch(0, mem(rax, 72*4)) @@ -187,51 +187,51 @@ void bli_sgemm_haswell_asm_6x16 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 8*4), ymm2) vbroadcastss(mem(rax, 9*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 10*4), ymm2) vbroadcastss(mem(rax, 11*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, 0*32), ymm0) vmovaps(mem(rbx, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 80*4)) - + vbroadcastss(mem(rax, 12*4), ymm2) vbroadcastss(mem(rax, 13*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 14*4), ymm2) vbroadcastss(mem(rax, 15*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 16*4), ymm2) vbroadcastss(mem(rax, 17*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, 2*32), ymm0) vmovaps(mem(rbx, 3*32), ymm1) - + // iteration 3 vbroadcastss(mem(rax, 18*4), ymm2) vbroadcastss(mem(rax, 19*4), ymm3) @@ -239,91 +239,91 @@ void bli_sgemm_haswell_asm_6x16 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 20*4), ymm2) vbroadcastss(mem(rax, 21*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 22*4), ymm2) vbroadcastss(mem(rax, 23*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(4*6*4), rax) // a += 4*6 (unroll x mr) add(imm(4*16*4), rbx) // b += 4*16 (unroll x nr) - + vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
- - + + label(.SLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 64*4)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 2*4), ymm2) vbroadcastss(mem(rax, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 4*4), ymm2) vbroadcastss(mem(rax, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(1*6*4), rax) // a += 1*6 (unroll x mr) add(imm(1*16*4), rbx) // b += 1*16 (unroll x nr) - + vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.SPOSTACCUM) - - - - + + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rax), ymm0) // load alpha and duplicate vbroadcastss(mem(rbx), ymm3) // load beta and duplicate - + vmulps(ymm0, ymm4, ymm4) // scale by alpha vmulps(ymm0, ymm5, ymm5) vmulps(ymm0, ymm6, ymm6) @@ -336,222 +336,222 @@ void bli_sgemm_haswell_asm_6x16 vmulps(ymm0, ymm13, ymm13) vmulps(ymm0, ymm14, ymm14) vmulps(ymm0, ymm15, ymm15) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) - + lea(mem(rcx, rsi, 8), rdx) // load address of c + 8*cs_c; lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; - + lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c; lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; - - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm3) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. jz(.SROWSTORED) // jump to row storage case - - + + cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. 
jz(.SCOLSTORED) // jump to column storage case - - - + + + label(.SGENSTORED) - - + + SGEMM_INPUT_GS_BETA_NZ vfmadd213ps(ymm4, ymm3, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + SGEMM_INPUT_GS_BETA_NZ vfmadd213ps(ymm6, ymm3, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + SGEMM_INPUT_GS_BETA_NZ vfmadd213ps(ymm8, ymm3, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + SGEMM_INPUT_GS_BETA_NZ vfmadd213ps(ymm10, ymm3, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + SGEMM_INPUT_GS_BETA_NZ vfmadd213ps(ymm12, ymm3, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + SGEMM_INPUT_GS_BETA_NZ vfmadd213ps(ymm14, ymm3, ymm0) SGEMM_OUTPUT_GS_BETA_NZ //add(rdi, rcx) // c += rs_c; - - + + mov(rdx, rcx) // rcx = c + 8*cs_c - - + + SGEMM_INPUT_GS_BETA_NZ vfmadd213ps(ymm5, ymm3, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + SGEMM_INPUT_GS_BETA_NZ vfmadd213ps(ymm7, ymm3, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + SGEMM_INPUT_GS_BETA_NZ vfmadd213ps(ymm9, ymm3, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + SGEMM_INPUT_GS_BETA_NZ vfmadd213ps(ymm11, ymm3, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + SGEMM_INPUT_GS_BETA_NZ vfmadd213ps(ymm13, ymm3, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + SGEMM_INPUT_GS_BETA_NZ vfmadd213ps(ymm15, ymm3, ymm0) SGEMM_OUTPUT_GS_BETA_NZ //add(rdi, rcx) // c += rs_c; - - - + + + jmp(.SDONE) // jump to end. - - - + + + label(.SROWSTORED) - - + + vfmadd231ps(mem(rcx), ymm3, ymm4) vmovups(ymm4, mem(rcx)) add(rdi, rcx) vfmadd231ps(mem(rdx), ymm3, ymm5) vmovups(ymm5, mem(rdx)) add(rdi, rdx) - - + + vfmadd231ps(mem(rcx), ymm3, ymm6) vmovups(ymm6, mem(rcx)) add(rdi, rcx) vfmadd231ps(mem(rdx), ymm3, ymm7) vmovups(ymm7, mem(rdx)) add(rdi, rdx) - - + + vfmadd231ps(mem(rcx), ymm3, ymm8) vmovups(ymm8, mem(rcx)) add(rdi, rcx) vfmadd231ps(mem(rdx), ymm3, ymm9) vmovups(ymm9, mem(rdx)) add(rdi, rdx) - - + + vfmadd231ps(mem(rcx), ymm3, ymm10) vmovups(ymm10, mem(rcx)) add(rdi, rcx) vfmadd231ps(mem(rdx), ymm3, ymm11) vmovups(ymm11, mem(rdx)) add(rdi, rdx) - - + + vfmadd231ps(mem(rcx), ymm3, ymm12) vmovups(ymm12, mem(rcx)) add(rdi, rcx) vfmadd231ps(mem(rdx), ymm3, ymm13) vmovups(ymm13, mem(rdx)) add(rdi, rdx) - - + + vfmadd231ps(mem(rcx), ymm3, ymm14) vmovups(ymm14, mem(rcx)) //add(rdi, rcx) vfmadd231ps(mem(rdx), ymm3, ymm15) vmovups(ymm15, mem(rdx)) //add(rdi, rdx) - - - + + + jmp(.SDONE) // jump to end. 
- - - + + + label(.SCOLSTORED) - - + + vbroadcastss(mem(rbx), ymm3) - + vunpcklps(ymm6, ymm4, ymm0) vunpcklps(ymm10, ymm8, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vfmadd231ps(mem(rcx), xmm3, xmm0) vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - + vextractf128(imm(0x1), ymm1, xmm2) vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - + + vunpckhps(ymm6, ymm4, ymm0) vunpckhps(ymm10, ymm8, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) vfmadd231ps(mem(rcx, r13, 2), xmm3, xmm2) vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - + vextractf128(imm(0x1), ymm1, xmm2) vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - + vunpcklps(ymm14, ymm12, ymm0) vextractf128(imm(0x1), ymm0, xmm2) vmovlpd(mem(r14), xmm1, xmm1) @@ -564,7 +564,7 @@ void bli_sgemm_haswell_asm_6x16 vfmadd231ps(xmm1, xmm3, xmm2) vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - + vunpckhps(ymm14, ymm12, ymm0) vextractf128(imm(0x1), ymm0, xmm2) vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) @@ -577,50 +577,50 @@ void bli_sgemm_haswell_asm_6x16 vfmadd231ps(xmm1, xmm3, xmm2) vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - + lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - + + + vunpcklps(ymm7, ymm5, ymm0) vunpcklps(ymm11, ymm9, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vfmadd231ps(mem(rcx), xmm3, xmm0) vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - + vextractf128(imm(0x1), ymm1, xmm2) vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - + + vunpckhps(ymm7, ymm5, ymm0) vunpckhps(ymm11, ymm9, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) vfmadd231ps(mem(rcx, r13, 2), xmm3, xmm2) vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - + vextractf128(imm(0x1), ymm1, xmm2) vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - + vunpcklps(ymm15, ymm13, ymm0) vextractf128(imm(0x1), ymm0, 
xmm2) vmovlpd(mem(r14), xmm1, xmm1) @@ -633,7 +633,7 @@ void bli_sgemm_haswell_asm_6x16 vfmadd231ps(xmm1, xmm3, xmm2) vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - + vunpckhps(ymm15, ymm13, ymm0) vextractf128(imm(0x1), ymm0, xmm2) vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) @@ -646,262 +646,264 @@ void bli_sgemm_haswell_asm_6x16 vfmadd231ps(xmm1, xmm3, xmm2) vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - + //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - + + + jmp(.SDONE) // jump to end. - - - + + + label(.SBETAZERO) - + cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. jz(.SROWSTORBZ) // jump to row storage case - + cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. jz(.SCOLSTORBZ) // jump to column storage case - - - + + + label(.SGENSTORBZ) - - + + vmovaps(ymm4, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovaps(ymm6, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovaps(ymm8, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovaps(ymm10, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovaps(ymm12, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovaps(ymm14, ymm0) SGEMM_OUTPUT_GS_BETA_NZ //add(rdi, rcx) // c += rs_c; - - + + mov(rdx, rcx) // rcx = c + 8*cs_c - - + + vmovaps(ymm5, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovaps(ymm7, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovaps(ymm9, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovaps(ymm11, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovaps(ymm13, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovaps(ymm15, ymm0) SGEMM_OUTPUT_GS_BETA_NZ //add(rdi, rcx) // c += rs_c; - - - + + + jmp(.SDONE) // jump to end. - - - + + + label(.SROWSTORBZ) - - + + vmovups(ymm4, mem(rcx)) add(rdi, rcx) vmovups(ymm5, mem(rdx)) add(rdi, rdx) - + vmovups(ymm6, mem(rcx)) add(rdi, rcx) vmovups(ymm7, mem(rdx)) add(rdi, rdx) - - + + vmovups(ymm8, mem(rcx)) add(rdi, rcx) vmovups(ymm9, mem(rdx)) add(rdi, rdx) - - + + vmovups(ymm10, mem(rcx)) add(rdi, rcx) vmovups(ymm11, mem(rdx)) add(rdi, rdx) - - + + vmovups(ymm12, mem(rcx)) add(rdi, rcx) vmovups(ymm13, mem(rdx)) add(rdi, rdx) - - + + vmovups(ymm14, mem(rcx)) //add(rdi, rcx) vmovups(ymm15, mem(rdx)) //add(rdi, rdx) - - - + + + jmp(.SDONE) // jump to end. 
- - - + + + label(.SCOLSTORBZ) - - + + vunpcklps(ymm6, ymm4, ymm0) vunpcklps(ymm10, ymm8, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - + vextractf128(imm(0x1), ymm1, xmm2) vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - + + vunpckhps(ymm6, ymm4, ymm0) vunpckhps(ymm10, ymm8, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - + vextractf128(imm(0x1), ymm1, xmm2) vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - + vunpcklps(ymm14, ymm12, ymm0) vextractf128(imm(0x1), ymm0, xmm2) vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - + vunpckhps(ymm14, ymm12, ymm0) vextractf128(imm(0x1), ymm0, xmm2) vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - + lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - + + + vunpcklps(ymm7, ymm5, ymm0) vunpcklps(ymm11, ymm9, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - + vextractf128(imm(0x1), ymm1, xmm2) vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - + + vunpckhps(ymm7, ymm5, ymm0) vunpckhps(ymm11, ymm9, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - + vextractf128(imm(0x1), ymm1, xmm2) vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - + vunpcklps(ymm15, ymm13, ymm0) vextractf128(imm(0x1), ymm0, xmm2) vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - + vunpckhps(ymm15, ymm13, ymm0) vextractf128(imm(0x1), ymm0, xmm2) vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - + //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - - - + + + + label(.SDONE) - + + vzeroupper() - end_asm( + + + 
end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -922,25 +924,15 @@ void bli_sgemm_haswell_asm_6x16 vmovlpd(mem(rcx), xmm0, xmm0) \ vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) \ vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) \ - vmovhpd(mem(rcx, r13, 1), xmm1, xmm1) \ - vperm2f128(imm(0x20), ymm1, ymm0, ymm0) /*\ - vmovlpd(mem(rcx, rsi, 4), xmm2, xmm2) \ - vmovhpd(mem(rcx, r15, 1), xmm2, xmm2) \ - vmovlpd(mem(rcx, r13, 2), xmm1, xmm1) \ - vmovhpd(mem(rcx, r10, 1), xmm1, xmm1) \ - vperm2f128(imm(0x20), ymm1, ymm2, ymm2)*/ + vmovhpd(mem(rcx, r8, 1), xmm1, xmm1) \ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0) #define DGEMM_OUTPUT_GS_BETA_NZ \ vextractf128(imm(1), ymm0, xmm1) \ vmovlpd(xmm0, mem(rcx)) \ vmovhpd(xmm0, mem(rcx, rsi, 1)) \ vmovlpd(xmm1, mem(rcx, rsi, 2)) \ - vmovhpd(xmm1, mem(rcx, r13, 1)) /*\ - vextractf128(imm(1), ymm2, xmm1) \ - vmovlpd(xmm2, mem(rcx, rsi, 4)) \ - vmovhpd(xmm2, mem(rcx, r15, 1)) \ - vmovlpd(xmm1, mem(rcx, r13, 2)) \ - vmovhpd(xmm1, mem(rcx, r10, 1))*/ + vmovhpd(xmm1, mem(rcx, r8, 1)) void bli_dgemm_haswell_asm_6x8 ( @@ -962,11 +954,18 @@ void bli_dgemm_haswell_asm_6x8 // different size than is expected by load instructions. uint64_t k_iter = (uint64_t)k0/4; uint64_t k_left = (uint64_t)k0%4; + uint64_t prefetch_iters = 30; + if ( k_iter > prefetch_iters ) { + k_iter -= prefetch_iters; + } else { + prefetch_iters = k_iter; + k_iter = 0; + } uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; begin_asm() - + //vzeroall() // zero all xmm/ymm registers. vxorpd( ymm4, ymm4, ymm4) // vzeroall is expensive @@ -982,61 +981,58 @@ void bli_dgemm_haswell_asm_6x8 vmovapd( ymm4, ymm14) vmovapd( ymm4, ymm15) - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. - //mov(%9, r15) // load address of b_next. 
- + add(imm(32*4), rbx) + add(imm(32*4), rax) // initialize loop by pre-loading vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) - - lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; - lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c - prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c - - - - - mov(var(k_iter), rsi) // i = k_iter; + + lea(mem(rdi, rdi, 2), r10) // r10 = 3*rs_c; + lea(mem(rcx, r10, 1), rdx) // rdx = c + 3*rs_c; + + + + + mov(var(k_pref), r8) // i = k_iter after prefetch + mov(var(k_iter), rsi) // i = k_iter before prefetch test(rsi, rsi) // check i via logical AND. - je(.DCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. - - + je(.DPOSTMAINLOOP) // if i == 0, jump to code that + // prefetches, followed by any post-prefetch iters + // and the k-left loop + + + align32 + label(.DLOOPKITER) // MAIN LOOP - - + + // iteration 0 - prefetch(0, mem(rax, 64*8)) - - vbroadcastsd(mem(rax, 0*8), ymm2) - vbroadcastsd(mem(rax, 1*8), ymm3) + + vbroadcastsd(mem(rax, 0*8-128), ymm2) + vbroadcastsd(mem(rax, 1*8-128), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - vbroadcastsd(mem(rax, 2*8), ymm2) - vbroadcastsd(mem(rax, 3*8), ymm3) + vbroadcastsd(mem(rax, 2*8-128), ymm2) + vbroadcastsd(mem(rax, 3*8-128), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - vbroadcastsd(mem(rax, 4*8), ymm2) - vbroadcastsd(mem(rax, 5*8), ymm3) + vbroadcastsd(mem(rax, 4*8-128), ymm2) + vbroadcastsd(mem(rax, 5*8-128), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) @@ -1046,24 +1042,22 @@ void bli_dgemm_haswell_asm_6x8 vmovapd(mem(rbx, -1*32), ymm1) // iteration 1 - prefetch(0, mem(rax, 72*8)) - - vbroadcastsd(mem(rax, 6*8), ymm2) - vbroadcastsd(mem(rax, 7*8), ymm3) + vbroadcastsd(mem(rax, 6*8-128), ymm2) + vbroadcastsd(mem(rax, 7*8-128), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - vbroadcastsd(mem(rax, 8*8), ymm2) - vbroadcastsd(mem(rax, 9*8), ymm3) + vbroadcastsd(mem(rax, 8*8-128), ymm2) + vbroadcastsd(mem(rax, 9*8-128), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - vbroadcastsd(mem(rax, 10*8), ymm2) - vbroadcastsd(mem(rax, 11*8), ymm3) + vbroadcastsd(mem(rax, 10*8-128), ymm2) + vbroadcastsd(mem(rax, 11*8-128), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) @@ -1073,24 +1067,22 @@ void bli_dgemm_haswell_asm_6x8 vmovapd(mem(rbx, 1*32), ymm1) // iteration 2 - prefetch(0, mem(rax, 80*8)) - - vbroadcastsd(mem(rax, 12*8), ymm2) - vbroadcastsd(mem(rax, 13*8), ymm3) + vbroadcastsd(mem(rax, 12*8-128), ymm2) + vbroadcastsd(mem(rax, 13*8-128), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - vbroadcastsd(mem(rax, 14*8), ymm2) - vbroadcastsd(mem(rax, 15*8), ymm3) + 
vbroadcastsd(mem(rax, 14*8-128), ymm2) + vbroadcastsd(mem(rax, 15*8-128), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - vbroadcastsd(mem(rax, 16*8), ymm2) - vbroadcastsd(mem(rax, 17*8), ymm3) + vbroadcastsd(mem(rax, 16*8-128), ymm2) + vbroadcastsd(mem(rax, 17*8-128), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) @@ -1100,22 +1092,22 @@ void bli_dgemm_haswell_asm_6x8 vmovapd(mem(rbx, 3*32), ymm1) // iteration 3 - vbroadcastsd(mem(rax, 18*8), ymm2) - vbroadcastsd(mem(rax, 19*8), ymm3) + vbroadcastsd(mem(rax, 18*8-128), ymm2) + vbroadcastsd(mem(rax, 19*8-128), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - vbroadcastsd(mem(rax, 20*8), ymm2) - vbroadcastsd(mem(rax, 21*8), ymm3) + vbroadcastsd(mem(rax, 20*8-128), ymm2) + vbroadcastsd(mem(rax, 21*8-128), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - vbroadcastsd(mem(rax, 22*8), ymm2) - vbroadcastsd(mem(rax, 23*8), ymm3) + vbroadcastsd(mem(rax, 22*8-128), ymm2) + vbroadcastsd(mem(rax, 23*8-128), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) @@ -1132,10 +1124,28 @@ void bli_dgemm_haswell_asm_6x8 jne(.DLOOPKITER) // iterate again if i != 0. + test(r8, r8) // If no post-prefetch iters to do, skip to kleft + je(.DCONSIDKLEFT) + label(.DPOSTMAINLOOP) + /* Prefetch C */ + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + + mov(r8, rsi) // i = k_iter after prefetch + xor(r8, r8) // Zero out r8, so we don't prefetch again + test(rsi, rsi) // check i via logical AND. 
+ jne(.DLOOPKITER) + + // All unrolled iters (and prefetches) done label(.DCONSIDKLEFT) mov(var(k_left), rsi) // i = k_left; @@ -1146,24 +1156,22 @@ void bli_dgemm_haswell_asm_6x8 label(.DLOOPKLEFT) // EDGE LOOP - prefetch(0, mem(rax, 64*8)) - - vbroadcastsd(mem(rax, 0*8), ymm2) - vbroadcastsd(mem(rax, 1*8), ymm3) + vbroadcastsd(mem(rax, 0*8-128), ymm2) + vbroadcastsd(mem(rax, 1*8-128), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - vbroadcastsd(mem(rax, 2*8), ymm2) - vbroadcastsd(mem(rax, 3*8), ymm3) + vbroadcastsd(mem(rax, 2*8-128), ymm2) + vbroadcastsd(mem(rax, 3*8-128), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - vbroadcastsd(mem(rax, 4*8), ymm2) - vbroadcastsd(mem(rax, 5*8), ymm3) + vbroadcastsd(mem(rax, 4*8-128), ymm2) + vbroadcastsd(mem(rax, 5*8-128), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) @@ -1213,11 +1221,9 @@ void bli_dgemm_haswell_asm_6x8 lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; - lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; + lea(mem(rcx, rdi, 4), r9) // load address of c + 4*rs_c; - lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; - //lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c; - //lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; + lea(mem(rsi, rsi, 2), r8) // r8 = 3*cs_c; // now avoid loading C if beta == 0 @@ -1321,51 +1327,33 @@ void bli_dgemm_haswell_asm_6x8 vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rdi, 1), ymm3, ymm6) + vfmadd231pd(mem(rdx, rdi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rdi, 2), ymm3, ymm8) + vfmadd231pd(mem(rdx, rdi, 2), ymm3, ymm9) + vmovupd(ymm4, mem(rcx)) vmovupd(ymm5, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm7) - vmovupd(ymm7, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm9) - vmovupd(ymm9, mem(rdx)) - add(rdi, rdx) + vmovupd(ymm6, mem(rcx, rdi, 1)) + vmovupd(ymm7, mem(rdx, rdi, 1)) + vmovupd(ymm8, mem(rcx, rdi, 2)) + vmovupd(ymm9, mem(rdx, rdi, 2)) + add(r10, rcx) // r10 = 3 * rdi + add(r10, rdx) vfmadd231pd(mem(rcx), ymm3, ymm10) - vmovupd(ymm10, mem(rcx)) - add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm11) + vfmadd231pd(mem(rcx, rdi, 1), ymm3, ymm12) + vfmadd231pd(mem(rdx, rdi, 1), ymm3, ymm13) + vfmadd231pd(mem(rcx, rdi, 2), ymm3, ymm14) + vfmadd231pd(mem(rdx, rdi, 2), ymm3, ymm15) + vmovupd(ymm10, mem(rcx)) vmovupd(ymm11, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm12) - vmovupd(ymm12, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm13) - vmovupd(ymm13, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm14) - vmovupd(ymm14, mem(rcx)) - //add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm15) - vmovupd(ymm15, mem(rdx)) - //add(rdi, rdx) + vmovupd(ymm12, mem(rcx, rdi, 1)) + vmovupd(ymm13, mem(rdx, rdi, 1)) + vmovupd(ymm14, mem(rcx, rdi, 2)) + vmovupd(ymm15, mem(rdx, rdi, 2)) @@ -1390,11 +1378,11 @@ void bli_dgemm_haswell_asm_6x8 vfmadd231pd(mem(rcx), ymm3, ymm4) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) - vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm10) + vfmadd231pd(mem(rcx, r8, 1), ymm3, ymm10) 
vmovupd(ymm4, mem(rcx)) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) - vmovupd(ymm10, mem(rcx, r13, 1)) + vmovupd(ymm10, mem(rcx, r8, 1)) lea(mem(rcx, rsi, 4), rcx) @@ -1403,16 +1391,16 @@ void bli_dgemm_haswell_asm_6x8 vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vfmadd231pd(mem(r14), xmm3, xmm0) - vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) - vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) - vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) + vfmadd231pd(mem(r9), xmm3, xmm0) + vfmadd231pd(mem(r9, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(r9, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(r9, r8, 1), xmm3, xmm4) + vmovupd(xmm0, mem(r9)) + vmovupd(xmm1, mem(r9, rsi, 1)) + vmovupd(xmm2, mem(r9, rsi, 2)) + vmovupd(xmm4, mem(r9, r8, 1)) - lea(mem(r14, rsi, 4), r14) + lea(mem(r9, rsi, 4), r9) vunpcklpd(ymm7, ymm5, ymm0) @@ -1429,11 +1417,11 @@ void bli_dgemm_haswell_asm_6x8 vfmadd231pd(mem(rcx), ymm3, ymm5) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) - vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm11) + vfmadd231pd(mem(rcx, r8, 1), ymm3, ymm11) vmovupd(ymm5, mem(rcx)) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) - vmovupd(ymm11, mem(rcx, r13, 1)) + vmovupd(ymm11, mem(rcx, r8, 1)) //lea(mem(rcx, rsi, 4), rcx) @@ -1442,16 +1430,16 @@ void bli_dgemm_haswell_asm_6x8 vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vfmadd231pd(mem(r14), xmm3, xmm0) - vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) - vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) - vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) + vfmadd231pd(mem(r9), xmm3, xmm0) + vfmadd231pd(mem(r9, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(r9, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(r9, r8, 1), xmm3, xmm4) + vmovupd(xmm0, mem(r9)) + vmovupd(xmm1, mem(r9, rsi, 1)) + vmovupd(xmm2, mem(r9, rsi, 2)) + vmovupd(xmm4, mem(r9, r8, 1)) - //lea(mem(r14, rsi, 4), r14) + //lea(mem(r9, rsi, 4), r9) @@ -1542,38 +1530,21 @@ void bli_dgemm_haswell_asm_6x8 vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) vmovupd(ymm5, mem(rdx)) - add(rdi, rdx) - - vmovupd(ymm6, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm7, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm9, mem(rdx)) - add(rdi, rdx) + vmovupd(ymm6, mem(rcx, rdi, 1)) + vmovupd(ymm7, mem(rdx, rdi, 1)) + vmovupd(ymm8, mem(rcx, rdi, 2)) + vmovupd(ymm9, mem(rdx, rdi, 2)) + add(r10, rcx) + add(r10, rdx) vmovupd(ymm10, mem(rcx)) - add(rdi, rcx) vmovupd(ymm11, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm12, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm13, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm14, mem(rcx)) - //add(rdi, rcx) - vmovupd(ymm15, mem(rdx)) - //add(rdi, rdx) + vmovupd(ymm12, mem(rcx, rdi, 1)) + vmovupd(ymm13, mem(rdx, rdi, 1)) + vmovupd(ymm14, mem(rcx, rdi, 2)) + vmovupd(ymm15, mem(rdx, rdi, 2)) jmp(.DDONE) // jump to end. 
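Aside on the dgemm restructuring above: the per-iteration prefetches of C are replaced by a single prefetch pass issued roughly prefetch_iters (= 30) unrolled iterations before the end of the k-loop, so the main loop is split into a pre-prefetch phase of k_iter - prefetch_iters iterations and a post-prefetch phase of prefetch_iters iterations (the r8 / k_pref count in the assembly). A minimal C sketch of that control flow follows; run_fma_iteration and prefetch_c_microtile are illustrative placeholders for the inline-assembly loop body and the C-tile prefetch, not BLIS functions.

#include <stdint.h>

/* Illustrative stand-ins for the 4x-unrolled FMA body and the C prefetch;
   in the kernel these are inline assembly, not function calls. */
static void run_fma_iteration( const double* a, const double* b ) { (void)a; (void)b; }
static void prefetch_c_microtile( const double* c )                { (void)c; }

void k_loop_with_delayed_c_prefetch
     (
       uint64_t      k_iter,   /* number of 4x-unrolled k iterations */
       const double* a,
       const double* b,
       const double* c
     )
{
    uint64_t prefetch_iters = 30;   /* iterations still to run after prefetching C */

    /* Split k_iter into a pre-prefetch and a post-prefetch phase,
       exactly as done in the C preamble of the kernel above. */
    if ( k_iter > prefetch_iters ) k_iter -= prefetch_iters;
    else                         { prefetch_iters = k_iter; k_iter = 0; }

    /* Phase 1: bulk of the unrolled iterations, no C traffic yet. */
    for ( uint64_t i = 0; i < k_iter; ++i )
        run_fma_iteration( a, b );

    /* Prefetch the 6x8 micro-tile of C about prefetch_iters unrolled
       iterations before the accumulate/store phase needs it. */
    prefetch_c_microtile( c );

    /* Phase 2: remaining unrolled iterations (the r8-counted loop in the asm). */
    for ( uint64_t i = 0; i < prefetch_iters; ++i )
        run_fma_iteration( a, b );
}

The apparent intent is to keep the C tile from being fetched (and possibly evicted again) too early in a long k-loop, while still hiding its load latency ahead of the beta-scaling and store phase.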
@@ -1595,7 +1566,7 @@ void bli_dgemm_haswell_asm_6x8 vmovupd(ymm4, mem(rcx)) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) - vmovupd(ymm10, mem(rcx, r13, 1)) + vmovupd(ymm10, mem(rcx, r8, 1)) lea(mem(rcx, rsi, 4), rcx) @@ -1604,12 +1575,12 @@ void bli_dgemm_haswell_asm_6x8 vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) + vmovupd(xmm0, mem(r9)) + vmovupd(xmm1, mem(r9, rsi, 1)) + vmovupd(xmm2, mem(r9, rsi, 2)) + vmovupd(xmm4, mem(r9, r8, 1)) - lea(mem(r14, rsi, 4), r14) + lea(mem(r9, rsi, 4), r9) vunpcklpd(ymm7, ymm5, ymm0) @@ -1624,7 +1595,7 @@ void bli_dgemm_haswell_asm_6x8 vmovupd(ymm5, mem(rcx)) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) - vmovupd(ymm11, mem(rcx, r13, 1)) + vmovupd(ymm11, mem(rcx, r8, 1)) //lea(mem(rcx, rsi, 4), rcx) @@ -1633,20 +1604,24 @@ void bli_dgemm_haswell_asm_6x8 vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) + vmovupd(xmm0, mem(r9)) + vmovupd(xmm1, mem(r9, rsi, 1)) + vmovupd(xmm2, mem(r9, rsi, 2)) + vmovupd(xmm4, mem(r9, r8, 1)) + + //lea(mem(r9, rsi, 4), r9) - //lea(mem(r14, rsi, 4), r14) label(.DDONE) + + vzeroupper() - end_asm( + + end_asm( : // output operands (none) : // input operands [k_iter] "m" (k_iter), // 0 @@ -1657,12 +1632,11 @@ void bli_dgemm_haswell_asm_6x8 [beta] "m" (beta), // 5 [c] "m" (c), // 6 [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [cs_c] "m" (cs_c), // 8 + [k_pref] "m" (prefetch_iters) // 9 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", - "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", @@ -1705,7 +1679,7 @@ void bli_dgemm_haswell_asm_6x8 vmulps(ymm1, ymm0, ymm0) \ vmulps(ymm2, ymm3, ymm3) \ vaddsubps(ymm3, ymm0, ymm0) - + #define CGEMM_OUTPUT_RS \ vmovups(ymm0, mem(rcx)) \ @@ -1732,69 +1706,69 @@ void bli_cgemm_haswell_asm_3x8 uint64_t cs_c = cs_c0; begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rbx) // initialize loop by pre-loading vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(scomplex) - + lea(mem(rcx, rdi, 1), r11) // r11 = c + 1*rs_c; lea(mem(rcx, rdi, 2), r12) // r12 = c + 2*rs_c; - + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c prefetch(0, mem(r11, 7*8)) // prefetch c + 1*rs_c prefetch(0, mem(r12, 7*8)) // prefetch c + 2*rs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.CCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- - + + label(.CLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 32*8)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 2*4), ymm2) vbroadcastss(mem(rax, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 4*4), ymm2) vbroadcastss(mem(rax, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, -2*32), ymm0) vmovaps(mem(rbx, -1*32), ymm1) - + // iteration 1 vbroadcastss(mem(rax, 6*4), ymm2) vbroadcastss(mem(rax, 7*4), ymm3) @@ -1802,51 +1776,51 @@ void bli_cgemm_haswell_asm_3x8 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 8*4), ymm2) vbroadcastss(mem(rax, 9*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 10*4), ymm2) vbroadcastss(mem(rax, 11*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, 0*32), ymm0) vmovaps(mem(rbx, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 38*8)) - + vbroadcastss(mem(rax, 12*4), ymm2) vbroadcastss(mem(rax, 13*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 14*4), ymm2) vbroadcastss(mem(rax, 15*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 16*4), ymm2) vbroadcastss(mem(rax, 17*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, 2*32), ymm0) vmovaps(mem(rbx, 3*32), ymm1) - + // iteration 3 vbroadcastss(mem(rax, 18*4), ymm2) vbroadcastss(mem(rax, 19*4), ymm3) @@ -1854,84 +1828,84 @@ void bli_cgemm_haswell_asm_3x8 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 20*4), ymm2) vbroadcastss(mem(rax, 21*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 22*4), ymm2) vbroadcastss(mem(rax, 23*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(4*3*8), rax) // a += 4*3 (unroll x mr) add(imm(4*8*8), rbx) // b += 4*8 (unroll x nr) - + vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.CLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.CCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.CPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
- - + + label(.CLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 32*8)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 2*4), ymm2) vbroadcastss(mem(rax, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 4*4), ymm2) vbroadcastss(mem(rax, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(1*3*8), rax) // a += 1*3 (unroll x mr) add(imm(1*8*8), rbx) // b += 1*8 (unroll x nr) - + vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.CLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.CPOSTACCUM) - - + + // permute even and odd elements // of ymm6/7, ymm10/11, ymm/14/15 vpermilps(imm(0xb1), ymm6, ymm6) @@ -1940,76 +1914,76 @@ void bli_cgemm_haswell_asm_3x8 vpermilps(imm(0xb1), ymm11, ymm11) vpermilps(imm(0xb1), ymm14, ymm14) vpermilps(imm(0xb1), ymm15, ymm15) - - + + // subtract/add even/odd elements vaddsubps(ymm6, ymm4, ymm4) vaddsubps(ymm7, ymm5, ymm5) - + vaddsubps(ymm10, ymm8, ymm8) vaddsubps(ymm11, ymm9, ymm9) - + vaddsubps(ymm14, ymm12, ymm12) vaddsubps(ymm15, ymm13, ymm13) - - - - + + + + mov(var(alpha), rax) // load address of alpha vbroadcastss(mem(rax), ymm0) // load alpha_r and duplicate vbroadcastss(mem(rax, 4), ymm1) // load alpha_i and duplicate - - + + vpermilps(imm(0xb1), ymm4, ymm3) vmulps(ymm0, ymm4, ymm4) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm4, ymm4) - + vpermilps(imm(0xb1), ymm5, ymm3) vmulps(ymm0, ymm5, ymm5) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm5, ymm5) - - + + vpermilps(imm(0xb1), ymm8, ymm3) vmulps(ymm0, ymm8, ymm8) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm8, ymm8) - + vpermilps(imm(0xb1), ymm9, ymm3) vmulps(ymm0, ymm9, ymm9) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm9, ymm9) - - + + vpermilps(imm(0xb1), ymm12, ymm3) vmulps(ymm0, ymm12, ymm12) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm12, ymm12) - + vpermilps(imm(0xb1), ymm13, ymm3) vmulps(ymm0, ymm13, ymm13) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm13, ymm13) - - - - - + + + + + mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rbx), ymm1) // load beta_r and duplicate vbroadcastss(mem(rbx, 4), ymm2) // load beta_i and duplicate - - - - + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(scomplex) lea(mem(, rsi, 4), rdx) // rdx = 4*cs_c; lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. @@ -2018,186 +1992,187 @@ void bli_cgemm_haswell_asm_3x8 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. 
jz(.CROWSTORED) // jump to row storage case - - - + + + label(.CGENSTORED) - - + + CGEMM_INPUT_SCALE_GS_BETA_NZ vaddps(ymm4, ymm0, ymm0) CGEMM_OUTPUT_GS add(rdx, rcx) // c += 4*cs_c; - - + + CGEMM_INPUT_SCALE_GS_BETA_NZ vaddps(ymm5, ymm0, ymm0) CGEMM_OUTPUT_GS mov(r11, rcx) // rcx = c + 1*rs_c - - - + + + CGEMM_INPUT_SCALE_GS_BETA_NZ vaddps(ymm8, ymm0, ymm0) CGEMM_OUTPUT_GS add(rdx, rcx) // c += 4*cs_c; - - + + CGEMM_INPUT_SCALE_GS_BETA_NZ vaddps(ymm9, ymm0, ymm0) CGEMM_OUTPUT_GS mov(r12, rcx) // rcx = c + 2*rs_c - - - + + + CGEMM_INPUT_SCALE_GS_BETA_NZ vaddps(ymm12, ymm0, ymm0) CGEMM_OUTPUT_GS add(rdx, rcx) // c += 4*cs_c; - - + + CGEMM_INPUT_SCALE_GS_BETA_NZ vaddps(ymm13, ymm0, ymm0) CGEMM_OUTPUT_GS - - - + + + jmp(.CDONE) // jump to end. - - - + + + label(.CROWSTORED) - - + + CGEMM_INPUT_SCALE_RS_BETA_NZ vaddps(ymm4, ymm0, ymm0) CGEMM_OUTPUT_RS add(rdx, rcx) // c += 4*cs_c; - - + + CGEMM_INPUT_SCALE_RS_BETA_NZ vaddps(ymm5, ymm0, ymm0) CGEMM_OUTPUT_RS mov(r11, rcx) // rcx = c + 1*rs_c - - - + + + CGEMM_INPUT_SCALE_RS_BETA_NZ vaddps(ymm8, ymm0, ymm0) CGEMM_OUTPUT_RS add(rdx, rcx) // c += 4*cs_c; - - + + CGEMM_INPUT_SCALE_RS_BETA_NZ vaddps(ymm9, ymm0, ymm0) CGEMM_OUTPUT_RS mov(r12, rcx) // rcx = c + 2*rs_c - - - + + + CGEMM_INPUT_SCALE_RS_BETA_NZ vaddps(ymm12, ymm0, ymm0) CGEMM_OUTPUT_RS add(rdx, rcx) // c += 4*cs_c; - - + + CGEMM_INPUT_SCALE_RS_BETA_NZ vaddps(ymm13, ymm0, ymm0) CGEMM_OUTPUT_RS - - - + + + jmp(.CDONE) // jump to end. - - - + + + label(.CBETAZERO) - + cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. jz(.CROWSTORBZ) // jump to row storage case - - - + + + label(.CGENSTORBZ) - - + + vmovaps(ymm4, ymm0) CGEMM_OUTPUT_GS add(rdx, rcx) // c += 2*cs_c; - - + + vmovaps(ymm5, ymm0) CGEMM_OUTPUT_GS mov(r11, rcx) // rcx = c + 1*rs_c - - - + + + vmovaps(ymm8, ymm0) CGEMM_OUTPUT_GS add(rdx, rcx) // c += 2*cs_c; - - + + vmovaps(ymm9, ymm0) CGEMM_OUTPUT_GS mov(r12, rcx) // rcx = c + 2*rs_c - - - + + + vmovaps(ymm12, ymm0) CGEMM_OUTPUT_GS add(rdx, rcx) // c += 2*cs_c; - - + + vmovaps(ymm13, ymm0) CGEMM_OUTPUT_GS - - - + + + jmp(.CDONE) // jump to end. 
- - - + + + label(.CROWSTORBZ) - - + + vmovups(ymm4, mem(rcx)) vmovups(ymm5, mem(rcx, rdx, 1)) - + vmovups(ymm8, mem(r11)) vmovups(ymm9, mem(r11, rdx, 1)) - + vmovups(ymm12, mem(r12)) vmovups(ymm13, mem(r12, rdx, 1)) - - - - - - + + + + label(.CDONE) - + + vzeroupper() - end_asm( + + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2223,7 +2198,7 @@ void bli_cgemm_haswell_asm_3x8 vmulpd(ymm1, ymm0, ymm0) \ vmulpd(ymm2, ymm3, ymm3) \ vaddsubpd(ymm3, ymm0, ymm0) - + // assumes values to output are in ymm0 #define ZGEMM_OUTPUT_GS \ vextractf128(imm(1), ymm0, xmm3) \ @@ -2240,7 +2215,6 @@ void bli_cgemm_haswell_asm_3x8 #define ZGEMM_OUTPUT_RS \ vmovupd(ymm0, mem(rcx)) \ - void bli_zgemm_haswell_asm_3x4 ( dim_t k0, @@ -2286,6 +2260,7 @@ void bli_zgemm_haswell_asm_3x4 vzeroall() // zero all xmm/ymm registers. + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. @@ -2347,6 +2322,8 @@ void bli_zgemm_haswell_asm_3x4 vmovapd(mem(rbx, -1*32), ymm1) // iteration 1 + prefetch(0, mem(rax, 36*16)) + vbroadcastsd(mem(rax, 6*8), ymm2) vbroadcastsd(mem(rax, 7*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) @@ -2372,7 +2349,7 @@ void bli_zgemm_haswell_asm_3x4 vmovapd(mem(rbx, 1*32), ymm1) // iteration 2 - prefetch(0, mem(rax, 38*16)) + prefetch(0, mem(rax, 40*16)) vbroadcastsd(mem(rax, 12*8), ymm2) vbroadcastsd(mem(rax, 13*8), ymm3) @@ -2585,35 +2562,48 @@ void bli_zgemm_haswell_asm_3x4 jz(.ZROWSTORED) // jump to row storage case + label(.ZGENSTORED) + + ZGEMM_INPUT_SCALE_GS_BETA_NZ vaddpd(ymm4, ymm0, ymm0) ZGEMM_OUTPUT_GS add(rdx, rcx) // c += 2*cs_c; + ZGEMM_INPUT_SCALE_GS_BETA_NZ vaddpd(ymm5, ymm0, ymm0) ZGEMM_OUTPUT_GS mov(r11, rcx) // rcx = c + 1*rs_c + + ZGEMM_INPUT_SCALE_GS_BETA_NZ vaddpd(ymm8, ymm0, ymm0) ZGEMM_OUTPUT_GS add(rdx, rcx) // c += 2*cs_c; + ZGEMM_INPUT_SCALE_GS_BETA_NZ vaddpd(ymm9, ymm0, ymm0) ZGEMM_OUTPUT_GS mov(r12, rcx) // rcx = c + 2*rs_c + + ZGEMM_INPUT_SCALE_GS_BETA_NZ vaddpd(ymm12, ymm0, ymm0) ZGEMM_OUTPUT_GS add(rdx, rcx) // c += 2*cs_c; + ZGEMM_INPUT_SCALE_GS_BETA_NZ vaddpd(ymm13, ymm0, ymm0) ZGEMM_OUTPUT_GS + + + jmp(.ZDONE) // jump to end. @@ -2693,39 +2683,51 @@ void bli_zgemm_haswell_asm_3x4 //CASE 3: Default case with multiplication // beta not equal to (+/-1) or zero, do normal multiplication. 
label(.GEN_BETA_NOT_REAL_ONE) + ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm4, ymm0, ymm0) ZGEMM_OUTPUT_RS add(rdx, rcx) // c += 2*cs_c; + ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm5, ymm0, ymm0) ZGEMM_OUTPUT_RS mov(r11, rcx) // rcx = c + 1*rs_c + + ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm8, ymm0, ymm0) ZGEMM_OUTPUT_RS add(rdx, rcx) // c += 2*cs_c; + ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm9, ymm0, ymm0) ZGEMM_OUTPUT_RS mov(r12, rcx) // rcx = c + 2*rs_c + + ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm12, ymm0, ymm0) ZGEMM_OUTPUT_RS add(rdx, rcx) // c += 2*cs_c; + ZGEMM_INPUT_SCALE_RS_BETA_NZ vaddpd(ymm13, ymm0, ymm0) ZGEMM_OUTPUT_RS + + + jmp(.ZDONE) // jump to end. label(.ZBETAZERO) + cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16. jz(.ZROWSTORBZ) // jump to row storage case @@ -2772,6 +2774,7 @@ void bli_zgemm_haswell_asm_3x4 label(.ZROWSTORBZ) + vmovupd(ymm4, mem(rcx)) vmovupd(ymm5, mem(rcx, rdx, 1)) @@ -2781,28 +2784,34 @@ void bli_zgemm_haswell_asm_3x4 vmovupd(ymm12, mem(r12)) vmovupd(ymm13, mem(r12, rdx, 1)) + + + label(.ZDONE) + vzeroupper() - end_asm( + + + end_asm( : // output operands (none) : // input operands - [alpha_mul_type] "m" (alpha_mul_type), - [beta_mul_type] "m" (beta_mul_type), - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [alpha_mul_type] "m" (alpha_mul_type), + [beta_mul_type] "m" (beta_mul_type), + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", diff --git a/kernels/haswell/3/bli_gemmtrsm_l_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemmtrsm_l_haswell_asm_d6x8.c index 939cab78f2..79e7cd0f28 100644 --- a/kernels/haswell/3/bli_gemmtrsm_l_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemmtrsm_l_haswell_asm_d6x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -82,22 +82,22 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 float* beta = bli_sm1; begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a10), rax) // load address of a. mov(var(b01), rbx) // load address of b. - + add(imm(32*4), rbx) // initialize loop by pre-loading vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - + mov(var(b11), rcx) // load address of b11 mov(imm(16), rdi) // set rs_b = PACKNR = 16 lea(mem(, rdi, 4), rdi) // rs_b *= sizeof(float) - + // NOTE: c11, rs_c, and cs_c aren't // needed for a while, but we load // them now to avoid stalling later. 
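The hunks that follow touch bli_sgemmtrsm_l_haswell_asm_6x16, whose post-accumulation phase (annotated "iteration 0" through "iteration 4" below) first forms b11 := alpha * b11 - a10 * b01 and then performs a forward substitution against the packed lower-triangular a11, multiplying by pre-inverted diagonal entries when BLIS_ENABLE_TRSM_PREINVERSION is defined and dividing otherwise. As a reading aid, here is a minimal scalar C sketch of that trsm phase; MR = 6, NR = 16, the column-major a11 layout, and the row stride rs_b (in floats) follow the kernel, while the function name trsm_l_block is only illustrative.

#include <stddef.h>

#define MR 6    /* rows of the micro-tile (order of a11)      */
#define NR 16   /* columns of the micro-tile (length of a row) */

/* Solve a11 * X = B in place, where a11 is MR x MR lower triangular,
 * stored column-major as packed, and B (b11) is MR x NR with its rows
 * rs_b floats apart. With preinversion, the diagonal of a11 already
 * holds 1/alpha_ii, so the divide becomes a multiply. */
static void trsm_l_block
     (
       const float* a11,
       float*       b11,
       size_t       rs_b,
       int          preinverted
     )
{
    for ( size_t i = 0; i < MR; ++i )
    {
        float* bi = b11 + i * rs_b;

        /* bi -= alpha_ij * bj for all j < i
           (the vmulps/vfmadd231ps/vsubps chains in the kernel). */
        for ( size_t j = 0; j < i; ++j )
        {
            const float  alpha_ij = a11[ i + j * MR ];
            const float* bj       = b11 + j * rs_b;
            for ( size_t n = 0; n < NR; ++n )
                bi[ n ] -= alpha_ij * bj[ n ];
        }

        /* Scale by the (possibly pre-inverted) diagonal element. */
        const float d = a11[ i + i * MR ];
        for ( size_t n = 0; n < NR; ++n )
            bi[ n ] = preinverted ? bi[ n ] * d : bi[ n ] / d;
    }
}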
@@ -106,45 +106,45 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 lea(mem(, r9 , 4), r9) // rs_c *= sizeof(float) mov(var(k_left)0, r10) // load cs_c lea(mem(, r10, 4), r10) // cs_c *= sizeof(float) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.SLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 64*4)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 2*4), ymm2) vbroadcastss(mem(rax, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 4*4), ymm2) vbroadcastss(mem(rax, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, -2*32), ymm0) vmovaps(mem(rbx, -1*32), ymm1) - + // iteration 1 vbroadcastss(mem(rax, 6*4), ymm2) vbroadcastss(mem(rax, 7*4), ymm3) @@ -152,51 +152,51 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 8*4), ymm2) vbroadcastss(mem(rax, 9*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 10*4), ymm2) vbroadcastss(mem(rax, 11*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, 0*32), ymm0) vmovaps(mem(rbx, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 76*4)) - + vbroadcastss(mem(rax, 12*4), ymm2) vbroadcastss(mem(rax, 13*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 14*4), ymm2) vbroadcastss(mem(rax, 15*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 16*4), ymm2) vbroadcastss(mem(rax, 17*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rbx, 2*32), ymm0) vmovaps(mem(rbx, 3*32), ymm1) - + // iteration 3 vbroadcastss(mem(rax, 18*4), ymm2) vbroadcastss(mem(rax, 19*4), ymm3) @@ -204,144 +204,144 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 20*4), ymm2) vbroadcastss(mem(rax, 21*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 22*4), ymm2) vbroadcastss(mem(rax, 23*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(4*6*4), rax) // a += 4*6 (unroll x mr) add(imm(4*16*4), rbx) // b += 4*16 (unroll x nr) - + vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. 
// else, we prepare to enter k_left loop. - - + + label(.SLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 64*4)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rax, 2*4), ymm2) vbroadcastss(mem(rax, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rax, 4*4), ymm2) vbroadcastss(mem(rax, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(1*6*4), rax) // a += 1*6 (unroll x mr) add(imm(1*16*4), rbx) // b += 1*16 (unroll x nr) - + vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.SPOSTACCUM) - + // ymm4..ymm15 = -a10 * b01 - - - + + + mov(var(alpha), rbx) // load address of alpha vbroadcastss(mem(rbx), ymm3) // load alpha and duplicate - - - - + + + + mov(imm(1), rsi) // load cs_b = 1 lea(mem(, rsi, 4), rsi) // cs_b *= sizeof(float) - + lea(mem(rcx, rsi, 8), rdx) // load address of b11 + 8*cs_b - + mov(rcx, r11) // save rcx = b11 for later mov(rdx, r14) // save rdx = b11+8*cs_b for later - - + + // b11 := alpha * b11 - a10 * b01 vfmsub231ps(mem(rcx), ymm3, ymm4) add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm5) add(rdi, rdx) - + vfmsub231ps(mem(rcx), ymm3, ymm6) add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm7) add(rdi, rdx) - + vfmsub231ps(mem(rcx), ymm3, ymm8) add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm9) add(rdi, rdx) - + vfmsub231ps(mem(rcx), ymm3, ymm10) add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm11) add(rdi, rdx) - + vfmsub231ps(mem(rcx), ymm3, ymm12) add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm13) add(rdi, rdx) - + vfmsub231ps(mem(rcx), ymm3, ymm14) //add(rdi, rcx) vfmsub231ps(mem(rdx), ymm3, ymm15) //add(rdi, rdx) - - - + + + // prefetch c11 - + #if 0 mov(r8, rcx) // load address of c11 from r8 // Note: r9 = rs_c * sizeof(float) - + lea(mem(r9 , r9 , 2), r13) // r13 = 3*rs_c; lea(mem(rcx, r13, 1), rdx) // rdx = c11 + 3*rs_c; - + prefetch(0, mem(rcx, 0*8)) // prefetch c11 + 0*rs_c prefetch(0, mem(rcx, r9, 1, 0*8)) // prefetch c11 + 1*rs_c prefetch(0, mem(rcx, r9 , 2, 0*8)) // prefetch c11 + 2*rs_c @@ -349,12 +349,12 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 prefetch(0, mem(rdx, r9, 1, 0*8)) // prefetch c11 + 4*rs_c prefetch(0, mem(rdx, r9 , 2, 0*8)) // prefetch c11 + 5*rs_c #endif - - - - + + + + // trsm computation begins here - + // Note: contents of b11 are stored as // ymm4 ymm5 = ( beta00..07 ) ( beta08..0F ) // ymm6 ymm7 = ( beta10..17 ) ( beta18..1F ) @@ -362,18 +362,18 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 // ymm10 ymm11 = ( beta30..37 ) ( beta38..3F ) // ymm12 ymm13 = ( beta40..47 ) ( beta48..4F ) // ymm14 ymm15 = ( beta50..57 ) ( beta58..5F ) - - + + mov(var(a11), rax) // load address of a11 - + mov(r11, rcx) // recall address of b11 mov(r14, rdx) // recall address of b11+8*cs_b // Note: rdi = rs_b - + // iteration 0 ------------- - + vbroadcastss(mem(0+0*6)*4(rax), ymm0) // ymm0 = (1/alpha00) - + #ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulps(ymm0, ymm4, ymm4) // ymm4 *= (1/alpha00) vmulps(ymm0, ymm5, ymm5) // ymm5 *= (1/alpha00) @@ -381,23 +381,23 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vdivps(ymm0, ymm4, ymm4) // ymm4 /= alpha00 vdivps(ymm0, ymm5, ymm5) // ymm5 /= alpha00 #endif - + vmovups(ymm4, mem(rcx)) // 
store ( beta00..beta07 ) = ymm4 vmovups(ymm5, mem(rdx)) // store ( beta08..beta0F ) = ymm5 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 1 ------------- - + vbroadcastss(mem(1+0*6)*4(rax), ymm0) // ymm0 = alpha10 vbroadcastss(mem(1+1*6)*4(rax), ymm1) // ymm1 = (1/alpha11) - + vmulps(ymm0, ymm4, ymm2) // ymm2 = alpha10 * ymm4 vmulps(ymm0, ymm5, ymm3) // ymm3 = alpha10 * ymm5 - + vsubps(ymm2, ymm6, ymm6) // ymm6 -= ymm2 vsubps(ymm3, ymm7, ymm7) // ymm7 -= ymm3 - + #ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulps(ymm1, ymm6, ymm6) // ymm6 *= (1/alpha11) vmulps(ymm1, ymm7, ymm7) // ymm7 *= (1/alpha11) @@ -405,28 +405,28 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vdivps(ymm1, ymm6, ymm6) // ymm6 /= alpha11 vdivps(ymm1, ymm7, ymm7) // ymm7 /= alpha11 #endif - + vmovups(ymm6, mem(rcx)) // store ( beta10..beta17 ) = ymm6 vmovups(ymm7, mem(rdx)) // store ( beta18..beta1F ) = ymm7 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 2 ------------- - + vbroadcastss(mem(2+0*6)*4(rax), ymm0) // ymm0 = alpha20 vbroadcastss(mem(2+1*6)*4(rax), ymm1) // ymm1 = alpha21 - + vmulps(ymm0, ymm4, ymm2) // ymm2 = alpha20 * ymm4 vmulps(ymm0, ymm5, ymm3) // ymm3 = alpha20 * ymm5 - + vbroadcastss(mem(2+2*6)*4(rax), ymm0) // ymm0 = (1/alpha22) - + vfmadd231ps(ymm1, ymm6, ymm2) // ymm2 += alpha21 * ymm6 vfmadd231ps(ymm1, ymm7, ymm3) // ymm3 += alpha21 * ymm7 - + vsubps(ymm2, ymm8, ymm8) // ymm8 -= ymm2 vsubps(ymm3, ymm9, ymm9) // ymm9 -= ymm3 - + #ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulps(ymm0, ymm8, ymm8) // ymm8 *= (1/alpha22) vmulps(ymm0, ymm9, ymm9) // ymm9 *= (1/alpha22) @@ -434,33 +434,33 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vdivps(ymm0, ymm8, ymm8) // ymm8 /= alpha22 vdivps(ymm0, ymm9, ymm9) // ymm9 /= alpha22 #endif - + vmovups(ymm8, mem(rcx)) // store ( beta20..beta27 ) = ymm8 vmovups(ymm9, mem(rdx)) // store ( beta28..beta2F ) = ymm9 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 3 ------------- - + vbroadcastss(mem(3+0*6)*4(rax), ymm0) // ymm0 = alpha30 vbroadcastss(mem(3+1*6)*4(rax), ymm1) // ymm1 = alpha31 - + vmulps(ymm0, ymm4, ymm2) // ymm2 = alpha30 * ymm4 vmulps(ymm0, ymm5, ymm3) // ymm3 = alpha30 * ymm5 - + vbroadcastss(mem(3+2*6)*4(rax), ymm0) // ymm0 = alpha32 - + vfmadd231ps(ymm1, ymm6, ymm2) // ymm2 += alpha31 * ymm6 vfmadd231ps(ymm1, ymm7, ymm3) // ymm3 += alpha31 * ymm7 - + vbroadcastss(mem(3+3*6)*4(rax), ymm1) // ymm0 = (1/alpha33) - + vfmadd231ps(ymm0, ymm8, ymm2) // ymm2 += alpha32 * ymm8 vfmadd231ps(ymm0, ymm9, ymm3) // ymm3 += alpha32 * ymm9 - + vsubps(ymm2, ymm10, ymm10) // ymm10 -= ymm2 vsubps(ymm3, ymm11, ymm11) // ymm11 -= ymm3 - + #ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulps(ymm1, ymm10, ymm10) // ymm10 *= (1/alpha33) vmulps(ymm1, ymm11, ymm11) // ymm11 *= (1/alpha33) @@ -468,38 +468,38 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vdivps(ymm1, ymm10, ymm10) // ymm10 /= alpha33 vdivps(ymm1, ymm11, ymm11) // ymm11 /= alpha33 #endif - + vmovups(ymm10, mem(rcx)) // store ( beta30..beta37 ) = ymm10 vmovups(ymm11, mem(rdx)) // store ( beta38..beta3F ) = ymm11 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 4 ------------- - + vbroadcastss(mem(4+0*6)*4(rax), ymm0) // ymm0 = alpha40 vbroadcastss(mem(4+1*6)*4(rax), ymm1) // ymm1 = alpha41 - + vmulps(ymm0, ymm4, ymm2) // ymm2 = alpha40 * ymm4 vmulps(ymm0, ymm5, ymm3) // ymm3 = alpha40 * ymm5 - + vbroadcastss(mem(4+2*6)*4(rax), ymm0) // ymm0 = alpha42 - + vfmadd231ps(ymm1, ymm6, ymm2) // ymm2 += alpha41 * ymm6 vfmadd231ps(ymm1, ymm7, ymm3) // ymm3 
+= alpha41 * ymm7 - + vbroadcastss(mem(4+3*6)*4(rax), ymm1) // ymm1 = alpha43 - + vfmadd231ps(ymm0, ymm8, ymm2) // ymm2 += alpha42 * ymm8 vfmadd231ps(ymm0, ymm9, ymm3) // ymm3 += alpha42 * ymm9 - + vbroadcastss(mem(4+4*6)*4(rax), ymm0) // ymm0 = (1/alpha44) - + vfmadd231ps(ymm1, ymm10, ymm2) // ymm2 += alpha43 * ymm10 vfmadd231ps(ymm1, ymm11, ymm3) // ymm3 += alpha43 * ymm11 - + vsubps(ymm2, ymm12, ymm12) // ymm12 -= ymm2 vsubps(ymm3, ymm13, ymm13) // ymm13 -= ymm3 - + #ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulps(ymm0, ymm12, ymm12) // ymm12 *= (1/alpha44) vmulps(ymm0, ymm13, ymm13) // ymm13 *= (1/alpha44) @@ -507,43 +507,43 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vdivps(ymm0, ymm12, ymm12) // ymm12 /= alpha44 vdivps(ymm0, ymm13, ymm13) // ymm13 /= alpha44 #endif - + vmovups(ymm12, mem(rcx)) // store ( beta40..beta47 ) = ymm12 vmovups(ymm13, mem(rdx)) // store ( beta48..beta4F ) = ymm13 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 5 ------------- - + vbroadcastss(mem(5+0*6)*4(rax), ymm0) // ymm0 = alpha50 vbroadcastss(mem(5+1*6)*4(rax), ymm1) // ymm1 = alpha51 - + vmulps(ymm0, ymm4, ymm2) // ymm2 = alpha50 * ymm4 vmulps(ymm0, ymm5, ymm3) // ymm3 = alpha50 * ymm5 - + vbroadcastss(mem(5+2*6)*4(rax), ymm0) // ymm0 = alpha52 - + vfmadd231ps(ymm1, ymm6, ymm2) // ymm2 += alpha51 * ymm6 vfmadd231ps(ymm1, ymm7, ymm3) // ymm3 += alpha51 * ymm7 - + vbroadcastss(mem(5+3*6)*4(rax), ymm1) // ymm1 = alpha53 - + vfmadd231ps(ymm0, ymm8, ymm2) // ymm2 += alpha52 * ymm8 vfmadd231ps(ymm0, ymm9, ymm3) // ymm3 += alpha52 * ymm9 - + vbroadcastss(mem(5+4*6)*4(rax), ymm0) // ymm0 = alpha54 - + vfmadd231ps(ymm1, ymm10, ymm2) // ymm2 += alpha53 * ymm10 vfmadd231ps(ymm1, ymm11, ymm3) // ymm3 += alpha53 * ymm11 - + vbroadcastss(mem(5+5*6)*4(rax), ymm1) // ymm1 = (1/alpha55) - + vfmadd231ps(ymm0, ymm12, ymm2) // ymm2 += alpha54 * ymm12 vfmadd231ps(ymm0, ymm13, ymm3) // ymm3 += alpha54 * ymm13 - + vsubps(ymm2, ymm14, ymm14) // ymm14 -= ymm2 vsubps(ymm3, ymm15, ymm15) // ymm15 -= ymm3 - + #ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulps(ymm1, ymm14, ymm14) // ymm14 *= (1/alpha55) vmulps(ymm1, ymm15, ymm15) // ymm15 *= (1/alpha55) @@ -551,189 +551,189 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vdivps(ymm1, ymm14, ymm14) // ymm14 /= alpha55 vdivps(ymm1, ymm15, ymm15) // ymm15 /= alpha55 #endif - + vmovups(ymm14, mem(rcx)) // store ( beta50..beta57 ) = ymm14 vmovups(ymm15, mem(rdx)) // store ( beta58..beta5F ) = ymm15 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - - - - - + + + + + mov(r8, rcx) // load address of c11 from r8 mov(r9, rdi) // load rs_c (in bytes) from r9 mov(r10, rsi) // load cs_c (in bytes) from r10 - + lea(mem(rcx, rsi, 8), rdx) // load address of c11 + 8*cs_c; lea(mem(rcx, rdi, 4), r14) // load address of c11 + 4*rs_c; - + // These are used in the macros below. lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c; lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; - - - + + + cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. jz(.SROWSTORED) // jump to row storage case - - - + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. jz(.SCOLSTORED) // jump to column storage case - - - + + + // if neither row- or column- // stored, use general case. 
label(.SGENSTORED) - - + + vmovaps(ymm4, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm6, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm8, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm10, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm12, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm14, ymm0) SGEMM_OUTPUT_GS_BETA_NZ - - + + mov(rdx, rcx) // rcx = c11 + 8*cs_c - - + + vmovaps(ymm5, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm7, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm9, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm11, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm13, ymm0) SGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovaps(ymm15, ymm0) SGEMM_OUTPUT_GS_BETA_NZ - - - + + + jmp(.SDONE) - - - + + + label(.SROWSTORED) - - + + vmovups(ymm4, mem(rcx)) add(rdi, rcx) vmovups(ymm5, mem(rdx)) add(rdi, rdx) - + vmovups(ymm6, mem(rcx)) add(rdi, rcx) vmovups(ymm7, mem(rdx)) add(rdi, rdx) - + vmovups(ymm8, mem(rcx)) add(rdi, rcx) vmovups(ymm9, mem(rdx)) add(rdi, rdx) - + vmovups(ymm10, mem(rcx)) add(rdi, rcx) vmovups(ymm11, mem(rdx)) add(rdi, rdx) - + vmovups(ymm12, mem(rcx)) add(rdi, rcx) vmovups(ymm13, mem(rdx)) add(rdi, rdx) - + vmovups(ymm14, mem(rcx)) //add(rdi, rcx) vmovups(ymm15, mem(rdx)) //add(rdi, rdx) - - + + jmp(.SDONE) - - - + + + label(.SCOLSTORED) - - + + vunpcklps(ymm6, ymm4, ymm0) vunpcklps(ymm10, ymm8, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) vmovups(xmm3, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - + + vunpckhps(ymm6, ymm4, ymm0) vunpckhps(ymm10, ymm8, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) vmovups(xmm3, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - + vunpcklps(ymm14, ymm12, ymm0) vunpckhps(ymm14, ymm12, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) vmovlpd(xmm1, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) @@ -742,46 +742,46 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) vmovlpd(xmm3, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) vmovhpd(xmm3, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - + lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - + + vunpcklps(ymm7, ymm5, ymm0) vunpcklps(ymm11, ymm9, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovups(xmm0, mem(rcx)) // store ( 
gamma08..gamma38 ) vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma09..gamma39 ) vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma0C..gamma3C ) vmovups(xmm3, mem(rcx, r15, 1)) // store ( gamma0D..gamma3D ) - + vunpckhps(ymm7, ymm5, ymm0) vunpckhps(ymm11, ymm9, ymm1) vshufps(imm(0x4e), ymm1, ymm0, ymm2) vblendps(imm(0xcc), ymm2, ymm0, ymm0) vblendps(imm(0x33), ymm2, ymm1, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma0A..gamma3A ) vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma0B..gamma3B ) vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma0E..gamma3E ) vmovups(xmm3, mem(rcx, r10, 1)) // store ( gamma0F..gamma3F ) - + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - + vunpcklps(ymm15, ymm13, ymm0) vunpckhps(ymm15, ymm13, ymm1) - + vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovlpd(xmm0, mem(r14)) // store ( gamma48..gamma58 ) vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma49..gamma59 ) vmovlpd(xmm1, mem(r14, rsi, 2)) // store ( gamma4A..gamma5A ) @@ -790,33 +790,34 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma4D..gamma5D ) vmovlpd(xmm3, mem(r14, r13, 2)) // store ( gamma4E..gamma5E ) vmovhpd(xmm3, mem(r14, r10, 1)) // store ( gamma4F..gamma5F ) - + //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - - + + + + label(.SDONE) - + vzeroupper() - + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a10] "m" (a10), // 2 - [b01] "m" (b01), // 3 - [beta] "m" (beta), // 4 - [alpha] "m" (alpha), // 5 - [a11] "m" (a11), // 6 - [b11] "m" (b11), // 7 - [c11] "m" (c11), // 8 - [rs_c] "m" (rs_c), // 9 - [cs_c] "m" (cs_c) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a10] "m" (a10), // 2 + [b01] "m" (b01), // 3 + [beta] "m" (beta), // 4 + [alpha] "m" (alpha), // 5 + [a11] "m" (a11), // 6 + [b11] "m" (b11), // 7 + [c11] "m" (c11), // 8 + [rs_c] "m" (rs_c), // 9 + [cs_c] "m" (cs_c) // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -844,17 +845,17 @@ void bli_sgemmtrsm_l_haswell_asm_6x16 vmovhpd(xmm1, mem(rcx, r10, 1))*/ void bli_dgemmtrsm_l_haswell_asm_6x8 -( - dim_t k0, - double* restrict alpha, - double* restrict a10, - double* restrict a11, - double* restrict b01, - double* restrict b11, - double* restrict c11, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx -) + ( + dim_t k0, + double* restrict alpha, + double* restrict a10, + double* restrict a11, + double* restrict b01, + double* restrict b11, + double* restrict c11, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_9); //void* a_next = bli_auxinfo_next_a( data ); @@ -870,22 +871,22 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 double* beta = bli_dm1; begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a10), rax) // load address of a. mov(var(b01), rbx) // load address of b. 
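For reference, the single- and double-precision gemmtrsm_l micro-kernels in this file implement the same fused operation: first accumulate the gemm part, b11 := alpha * b11 - a10 * b01, then solve the lower-triangular system tril(a11) * x = b11 row by row (the "iteration 0" through "iteration 5" blocks), storing the result back to b11 and to c11. A minimal scalar sketch of that update, using flat arrays laid out the way the kernel comments describe (a10 and a11 with column stride MR = 6, b01 and b11 with row stride NR) rather than the actual packed micro-panels, is:

// Scalar reference for the fused gemmtrsm_l micro-tile update
// ( b11 := inv( tril( a11 ) ) * ( alpha * b11 - a10 * b01 ) ).
// a10 is MR x k and a11 is MR x MR, both with column stride MR = 6;
// b01 is k x NR and b11 is MR x NR, both with row stride NR.
// Illustrative sketch only; argument names and layouts are assumptions.
static void gemmtrsm_l_ref
     (
       dim_t k, dim_t nr, double alpha,
       const double* a10, const double* a11,
       const double* b01, double*       b11
     )
{
	for ( dim_t i = 0; i < 6; ++i )
	for ( dim_t j = 0; j < nr; ++j )
	{
		double w = alpha * b11[ i*nr + j ];

		for ( dim_t p = 0; p < k; ++p )  // gemm part: subtract a10 * b01
			w -= a10[ i + p*6 ] * b01[ p*nr + j ];

		for ( dim_t l = 0; l < i; ++l )  // forward substitution over already-solved rows
			w -= a11[ i + l*6 ] * b11[ l*nr + j ];

		b11[ i*nr + j ] = w / a11[ i + i*6 ];  // becomes a multiply by the stored
		                                       // 1/alpha(i,i) under
		                                       // BLIS_ENABLE_TRSM_PREINVERSION
	}
}

When BLIS_ENABLE_TRSM_PREINVERSION is defined, the packed diagonal already holds 1/alpha(i,i), which is why the kernels select vmulps/vmulpd instead of vdivps/vdivpd around each store of a solved row.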
- + add(imm(32*4), rbx) // initialize loop by pre-loading vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - + mov(var(b11), rcx) // load address of b11 mov(imm(8), rdi) // set rs_b = PACKNR = 8 lea(mem(, rdi, 8), rdi) // rs_b *= sizeof(double) - + // NOTE: c11, rs_c, and cs_c aren't // needed for a while, but we load // them now to avoid stalling later. @@ -894,97 +895,99 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 lea(mem(, r9 , 8), r9) // rs_c *= sizeof(double) mov(var(k_left)0, r10) // load cs_c lea(mem(, r10, 8), r10) // cs_c *= sizeof(double) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rax, 0*8), ymm2) vbroadcastsd(mem(rax, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, -2*32), ymm0) vmovapd(mem(rbx, -1*32), ymm1) - + // iteration 1 + prefetch(0, mem(rax, 72*8)) + vbroadcastsd(mem(rax, 6*8), ymm2) vbroadcastsd(mem(rax, 7*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 8*8), ymm2) vbroadcastsd(mem(rax, 9*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 10*8), ymm2) vbroadcastsd(mem(rax, 11*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 0*32), ymm0) vmovapd(mem(rbx, 1*32), ymm1) - + // iteration 2 - prefetch(0, mem(rax, 76*8)) - + prefetch(0, mem(rax, 80*8)) + vbroadcastsd(mem(rax, 12*8), ymm2) vbroadcastsd(mem(rax, 13*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 14*8), ymm2) vbroadcastsd(mem(rax, 15*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 16*8), ymm2) vbroadcastsd(mem(rax, 17*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 2*32), ymm0) vmovapd(mem(rbx, 3*32), ymm1) - + // iteration 3 vbroadcastsd(mem(rax, 18*8), ymm2) vbroadcastsd(mem(rax, 19*8), ymm3) @@ -992,145 +995,145 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 20*8), ymm2) vbroadcastsd(mem(rax, 21*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 22*8), ymm2) vbroadcastsd(mem(rax, 23*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(4*6*8), 
rax) // a += 4*6 (unroll x mr) add(imm(4*8*8), rbx) // b += 4*8 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rax, 0*8), ymm2) vbroadcastsd(mem(rax, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(1*6*8), rax) // a += 1*6 (unroll x mr) add(imm(1*8*8), rbx) // b += 1*8 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + // ymm4..ymm15 = -a10 * b01 - - - - + + + + mov(var(alpha), rbx) // load address of alpha vbroadcastsd(mem(rbx), ymm3) // load alpha and duplicate - - - - + + + + mov(imm(1), rsi) // set cs_b = 1 lea(mem(, rsi, 8), rsi) // cs_b *= sizeof(double) - + lea(mem(rcx, rsi, 4), rdx) // load address of b11 + 4*cs_b - + mov(rcx, r11) // save rcx = b11 for later mov(rdx, r14) // save rdx = b11+4*cs_b for later - - + + // b11 := alpha * b11 - a10 * b01 vfmsub231pd(mem(rcx), ymm3, ymm4) add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm5) add(rdi, rdx) - + vfmsub231pd(mem(rcx), ymm3, ymm6) add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm7) add(rdi, rdx) - + vfmsub231pd(mem(rcx), ymm3, ymm8) add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm9) add(rdi, rdx) - + vfmsub231pd(mem(rcx), ymm3, ymm10) add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm11) add(rdi, rdx) - + vfmsub231pd(mem(rcx), ymm3, ymm12) add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm13) add(rdi, rdx) - + vfmsub231pd(mem(rcx), ymm3, ymm14) //add(rdi, rcx) vfmsub231pd(mem(rdx), ymm3, ymm15) //add(rdi, rdx) - - - + + + // prefetch c11 - + #if 0 mov(r8, rcx) // load address of c11 from r8 // Note: r9 = rs_c * sizeof(double) - + lea(mem(r9 , r9 , 2), r13) // r13 = 3*rs_c; lea(mem(rcx, r13, 1), rdx) // rdx = c11 + 3*rs_c; - + prefetch(0, mem(rcx, 7*8)) // prefetch c11 + 0*rs_c prefetch(0, mem(rcx, r9, 1, 7*8)) // prefetch c11 + 1*rs_c prefetch(0, mem(rcx, r9 , 2, 7*8)) // prefetch c11 + 2*rs_c @@ -1138,12 +1141,12 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 prefetch(0, mem(rdx, r9, 1, 7*8)) // prefetch c11 + 4*rs_c prefetch(0, mem(rdx, r9 , 2, 7*8)) // prefetch c11 + 5*rs_c #endif - - - - + + + + // trsm computation begins here - + // Note: contents of b11 are stored as // ymm4 ymm5 = ( beta00..03 ) ( beta04..07 ) // ymm6 ymm7 = ( beta10..13 ) ( beta14..17 ) @@ -1151,18 +1154,18 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 // ymm10 ymm11 = ( beta30..33 ) ( beta34..37 ) // ymm12 ymm13 = ( beta40..43 ) ( beta44..47 ) // ymm14 ymm15 = ( beta50..53 ) ( beta54..57 ) - - + + mov(var(a11), rax) // load address of a11 - + mov(r11, rcx) // recall address of b11 mov(r14, rdx) // recall address of b11+4*cs_b // Note: rdi = rs_b - + // 
iteration 0 ------------- - + vbroadcastsd(mem(0+0*6)*8(rax), ymm0) // ymm0 = (1/alpha00) - + #ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulpd(ymm0, ymm4, ymm4) // ymm4 *= (1/alpha00) vmulpd(ymm0, ymm5, ymm5) // ymm5 *= (1/alpha00) @@ -1170,23 +1173,23 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vdivpd(ymm0, ymm4, ymm4) // ymm4 /= alpha00 vdivpd(ymm0, ymm5, ymm5) // ymm5 /= alpha00 #endif - + vmovupd(ymm4, mem(rcx)) // store ( beta00..beta03 ) = ymm4 vmovupd(ymm5, mem(rdx)) // store ( beta04..beta07 ) = ymm5 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 1 ------------- - + vbroadcastsd(mem(1+0*6)*8(rax), ymm0) // ymm0 = alpha10 vbroadcastsd(mem(1+1*6)*8(rax), ymm1) // ymm1 = (1/alpha11) - + vmulpd(ymm0, ymm4, ymm2) // ymm2 = alpha10 * ymm4 vmulpd(ymm0, ymm5, ymm3) // ymm3 = alpha10 * ymm5 - + vsubpd(ymm2, ymm6, ymm6) // ymm6 -= ymm2 vsubpd(ymm3, ymm7, ymm7) // ymm7 -= ymm3 - + #ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulpd(ymm1, ymm6, ymm6) // ymm6 *= (1/alpha11) vmulpd(ymm1, ymm7, ymm7) // ymm7 *= (1/alpha11) @@ -1194,28 +1197,28 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vdivpd(ymm1, ymm6, ymm6) // ymm6 /= alpha11 vdivpd(ymm1, ymm7, ymm7) // ymm7 /= alpha11 #endif - + vmovupd(ymm6, mem(rcx)) // store ( beta10..beta13 ) = ymm6 vmovupd(ymm7, mem(rdx)) // store ( beta14..beta17 ) = ymm7 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 2 ------------- - + vbroadcastsd(mem(2+0*6)*8(rax), ymm0) // ymm0 = alpha20 vbroadcastsd(mem(2+1*6)*8(rax), ymm1) // ymm1 = alpha21 - + vmulpd(ymm0, ymm4, ymm2) // ymm2 = alpha20 * ymm4 vmulpd(ymm0, ymm5, ymm3) // ymm3 = alpha20 * ymm5 - + vbroadcastsd(mem(2+2*6)*8(rax), ymm0) // ymm0 = (1/alpha22) - + vfmadd231pd(ymm1, ymm6, ymm2) // ymm2 += alpha21 * ymm6 vfmadd231pd(ymm1, ymm7, ymm3) // ymm3 += alpha21 * ymm7 - + vsubpd(ymm2, ymm8, ymm8) // ymm8 -= ymm2 vsubpd(ymm3, ymm9, ymm9) // ymm9 -= ymm3 - + #ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulpd(ymm0, ymm8, ymm8) // ymm8 *= (1/alpha22) vmulpd(ymm0, ymm9, ymm9) // ymm9 *= (1/alpha22) @@ -1223,33 +1226,33 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vdivpd(ymm0, ymm8, ymm8) // ymm8 /= alpha22 vdivpd(ymm0, ymm9, ymm9) // ymm9 /= alpha22 #endif - + vmovupd(ymm8, mem(rcx)) // store ( beta20..beta23 ) = ymm8 vmovupd(ymm9, mem(rdx)) // store ( beta24..beta27 ) = ymm9 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 3 ------------- - + vbroadcastsd(mem(3+0*6)*8(rax), ymm0) // ymm0 = alpha30 vbroadcastsd(mem(3+1*6)*8(rax), ymm1) // ymm1 = alpha31 - + vmulpd(ymm0, ymm4, ymm2) // ymm2 = alpha30 * ymm4 vmulpd(ymm0, ymm5, ymm3) // ymm3 = alpha30 * ymm5 - + vbroadcastsd(mem(3+2*6)*8(rax), ymm0) // ymm0 = alpha32 - + vfmadd231pd(ymm1, ymm6, ymm2) // ymm2 += alpha31 * ymm6 vfmadd231pd(ymm1, ymm7, ymm3) // ymm3 += alpha31 * ymm7 - + vbroadcastsd(mem(3+3*6)*8(rax), ymm1) // ymm1 = (1/alpha33) - + vfmadd231pd(ymm0, ymm8, ymm2) // ymm2 += alpha32 * ymm8 vfmadd231pd(ymm0, ymm9, ymm3) // ymm3 += alpha32 * ymm9 - + vsubpd(ymm2, ymm10, ymm10) // ymm10 -= ymm2 vsubpd(ymm3, ymm11, ymm11) // ymm11 -= ymm3 - + #ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulpd(ymm1, ymm10, ymm10) // ymm10 *= (1/alpha33) vmulpd(ymm1, ymm11, ymm11) // ymm11 *= (1/alpha33) @@ -1257,38 +1260,38 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vdivpd(ymm1, ymm10, ymm10) // ymm10 /= alpha33 vdivpd(ymm1, ymm11, ymm11) // ymm11 /= alpha33 #endif - + vmovupd(ymm10, mem(rcx)) // store ( beta30..beta33 ) = ymm10 vmovupd(ymm11, mem(rdx)) // store ( beta34..beta37 ) = ymm11 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx 
+= rs_b - + // iteration 4 ------------- - + vbroadcastsd(mem(4+0*6)*8(rax), ymm0) // ymm0 = alpha40 vbroadcastsd(mem(4+1*6)*8(rax), ymm1) // ymm1 = alpha41 - + vmulpd(ymm0, ymm4, ymm2) // ymm2 = alpha40 * ymm4 vmulpd(ymm0, ymm5, ymm3) // ymm3 = alpha40 * ymm5 - + vbroadcastsd(mem(4+2*6)*8(rax), ymm0) // ymm0 = alpha42 - + vfmadd231pd(ymm1, ymm6, ymm2) // ymm2 += alpha41 * ymm6 vfmadd231pd(ymm1, ymm7, ymm3) // ymm3 += alpha41 * ymm7 - + vbroadcastsd(mem(4+3*6)*8(rax), ymm1) // ymm1 = alpha43 - + vfmadd231pd(ymm0, ymm8, ymm2) // ymm2 += alpha42 * ymm8 vfmadd231pd(ymm0, ymm9, ymm3) // ymm3 += alpha42 * ymm9 - + vbroadcastsd(mem(4+4*6)*8(rax), ymm0) // ymm4 = (1/alpha44) - + vfmadd231pd(ymm1, ymm10, ymm2) // ymm2 += alpha43 * ymm10 vfmadd231pd(ymm1, ymm11, ymm3) // ymm3 += alpha43 * ymm11 - + vsubpd(ymm2, ymm12, ymm12) // ymm12 -= ymm2 vsubpd(ymm3, ymm13, ymm13) // ymm13 -= ymm3 - + #ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulpd(ymm0, ymm12, ymm12) // ymm12 *= (1/alpha44) vmulpd(ymm0, ymm13, ymm13) // ymm13 *= (1/alpha44) @@ -1296,43 +1299,43 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vdivpd(ymm0, ymm12, ymm12) // ymm12 /= alpha44 vdivpd(ymm0, ymm13, ymm13) // ymm13 /= alpha44 #endif - + vmovupd(ymm12, mem(rcx)) // store ( beta40..beta43 ) = ymm12 vmovupd(ymm13, mem(rdx)) // store ( beta44..beta47 ) = ymm13 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - + // iteration 5 ------------- - + vbroadcastsd(mem(5+0*6)*8(rax), ymm0) // ymm0 = alpha50 vbroadcastsd(mem(5+1*6)*8(rax), ymm1) // ymm1 = alpha51 - + vmulpd(ymm0, ymm4, ymm2) // ymm2 = alpha50 * ymm4 vmulpd(ymm0, ymm5, ymm3) // ymm3 = alpha50 * ymm5 - + vbroadcastsd(mem(5+2*6)*8(rax), ymm0) // ymm0 = alpha52 - + vfmadd231pd(ymm1, ymm6, ymm2) // ymm2 += alpha51 * ymm6 vfmadd231pd(ymm1, ymm7, ymm3) // ymm3 += alpha51 * ymm7 - + vbroadcastsd(mem(5+3*6)*8(rax), ymm1) // ymm1 = alpha53 - + vfmadd231pd(ymm0, ymm8, ymm2) // ymm2 += alpha52 * ymm8 vfmadd231pd(ymm0, ymm9, ymm3) // ymm3 += alpha52 * ymm9 - + vbroadcastsd(mem(5+4*6)*8(rax), ymm0) // ymm0 = alpha54 - + vfmadd231pd(ymm1, ymm10, ymm2) // ymm2 += alpha53 * ymm10 vfmadd231pd(ymm1, ymm11, ymm3) // ymm3 += alpha53 * ymm11 - + vbroadcastsd(mem(5+5*6)*8(rax), ymm1) // ymm1 = (1/alpha55) - + vfmadd231pd(ymm0, ymm12, ymm2) // ymm2 += alpha54 * ymm12 vfmadd231pd(ymm0, ymm13, ymm3) // ymm3 += alpha54 * ymm13 - + vsubpd(ymm2, ymm14, ymm14) // ymm14 -= ymm2 vsubpd(ymm3, ymm15, ymm15) // ymm15 -= ymm3 - + #ifdef BLIS_ENABLE_TRSM_PREINVERSION vmulpd(ymm1, ymm14, ymm14) // ymm14 *= (1/alpha55) vmulpd(ymm1, ymm15, ymm15) // ymm15 *= (1/alpha55) @@ -1340,150 +1343,150 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vdivpd(ymm1, ymm14, ymm14) // ymm14 /= alpha55 vdivpd(ymm1, ymm15, ymm15) // ymm15 /= alpha55 #endif - + vmovupd(ymm14, mem(rcx)) // store ( beta50..beta53 ) = ymm14 vmovupd(ymm15, mem(rdx)) // store ( beta54..beta57 ) = ymm15 add(rdi, rcx) // rcx += rs_b add(rdi, rdx) // rdx += rs_b - - - - + + + + mov(r8, rcx) // load address of c11 from r8 mov(r9, rdi) // load rs_c (in bytes) from r9 mov(r10, rsi) // load cs_c (in bytes) from r10 - + lea(mem(rcx, rsi, 4), rdx) // load address of c11 + 4*cs_c; lea(mem(rcx, rdi, 4), r14) // load address of c11 + 4*rs_c; - + // These are used in the macros below. lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; //lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c; //lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; - - - + + + cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. jz(.DROWSTORED) // jump to row storage case - - - + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORED) // jump to column storage case - - - + + + // if neither row- or column- // stored, use general case. label(.DGENSTORED) - - + + vmovapd(ymm4, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm6, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm8, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm10, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm12, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm14, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - + + mov(rdx, rcx) // rcx = c11 + 4*cs_c - - + + vmovapd(ymm5, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm7, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm9, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm11, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm13, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c11 += rs_c; - - + + vmovapd(ymm15, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - + + jmp(.DDONE) - - - + + + label(.DROWSTORED) - - + + vmovupd(ymm4, mem(rcx)) add(rdi, rcx) vmovupd(ymm5, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm6, mem(rcx)) add(rdi, rcx) vmovupd(ymm7, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm8, mem(rcx)) add(rdi, rcx) vmovupd(ymm9, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm10, mem(rcx)) add(rdi, rcx) vmovupd(ymm11, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm12, mem(rcx)) add(rdi, rcx) vmovupd(ymm13, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm14, mem(rcx)) //add(rdi, rcx) vmovupd(ymm15, mem(rdx)) //add(rdi, rdx) - - + + jmp(.DDONE) - - - + + + label(.DCOLSTORED) - - + + vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -1492,27 +1495,27 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm6) vperm2f128(imm(0x31), ymm2, ymm0, ymm8) vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - + vmovupd(ymm4, mem(rcx)) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, r13, 1)) - + lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm14, ymm12, ymm0) vunpckhpd(ymm14, ymm12, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovupd(xmm0, mem(r14)) vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm3, mem(r14, r13, 1)) - + lea(mem(r14, rsi, 4), r14) - - + + vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -1521,50 +1524,49 @@ void bli_dgemmtrsm_l_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm7) vperm2f128(imm(0x31), ymm2, ymm0, ymm9) vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - + vmovupd(ymm5, mem(rcx)) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, r13, 1)) - + //lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm15, ymm13, ymm0) vunpckhpd(ymm15, ymm13, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm3) - + vmovupd(xmm0, mem(r14)) vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm3, mem(r14, r13, 1)) - + //lea(mem(r14, rsi, 4), r14) - - - - - + + + + label(.DDONE) - + vzeroupper() - + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a10] "m" (a10), // 2 - [b01] "m" (b01), // 3 - [beta] "m" (beta), // 4 - [alpha] "m" (alpha), // 5 - [a11] "m" (a11), // 6 - [b11] "m" (b11), // 7 - [c11] "m" (c11), // 8 - [rs_c] "m" (rs_c), // 9 - [cs_c] "m" (cs_c) // 10 + 
[k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a10] "m" (a10), // 2 + [b01] "m" (b01), // 3 + [beta] "m" (beta), // 4 + [alpha] "m" (alpha), // 5 + [a11] "m" (a11), // 6 + [b11] "m" (b11), // 7 + [c11] "m" (c11), // 8 + [rs_c] "m" (rs_c), // 9 + [cs_c] "m" (cs_c) // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", diff --git a/kernels/haswell/3/bli_gemmtrsm_u_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemmtrsm_u_haswell_asm_d6x8.c index bd9d338b3c..e183df8e19 100644 --- a/kernels/haswell/3/bli_gemmtrsm_u_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemmtrsm_u_haswell_asm_d6x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -805,6 +805,8 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 vzeroupper() + + end_asm( : // output operands (none) : // input operands @@ -848,7 +850,7 @@ void bli_sgemmtrsm_u_haswell_asm_6x16 vmovhpd(xmm1, mem(rcx, r10, 1))*/ void bli_dgemmtrsm_u_haswell_asm_6x8 -( + ( dim_t k0, double* restrict alpha, double* restrict a10, @@ -858,7 +860,7 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 double* restrict c11, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx -) + ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_9); //void* a_next = bli_auxinfo_next_a( data ); @@ -938,6 +940,8 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 vmovapd(mem(rbx, -1*32), ymm1) // iteration 1 + prefetch(0, mem(rax, 72*8)) + vbroadcastsd(mem(rax, 6*8), ymm2) vbroadcastsd(mem(rax, 7*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) @@ -963,7 +967,7 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 vmovapd(mem(rbx, 1*32), ymm1) // iteration 2 - prefetch(0, mem(rax, 76*8)) + prefetch(0, mem(rax, 80*8)) vbroadcastsd(mem(rax, 12*8), ymm2) vbroadcastsd(mem(rax, 13*8), ymm3) @@ -1559,6 +1563,7 @@ void bli_dgemmtrsm_u_haswell_asm_6x8 vzeroupper() + end_asm( : // output operands (none) : // input operands diff --git a/kernels/haswell/3/sup/CMakeLists.txt b/kernels/haswell/3/sup/CMakeLists.txt deleted file mode 100644 index e5ed6183c2..0000000000 --- a/kernels/haswell/3/sup/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -##Copyright (C) 2020-2023, Advanced Micro Devices, Inc. 
All rights reserved.## - -add_library(haswell_3sup - OBJECT - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_d6x8m.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_d6x8n.c - #${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_s6x16m.c - #${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_s6x16n.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_d6x8m.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_d6x8n.c - #${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_s6x16m.c - #${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_s6x16n.c - ) -target_compile_options(haswell_3sup PRIVATE /arch:AVX2) -if(BUILD_SHARED_LIBS) - target_compile_definitions(haswell_3sup PUBLIC -DBLIS_IS_BUILDING_LIBRARY) -endif() -add_subdirectory(d6x8) -#add_subdirectory(s6x16) diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c index dc81b2d913..15ec2bef3d 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8n.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8n.c index 65c985ef1a..73180d8f79 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8n.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8n.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_s6x16m.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_s6x16m.c index 9962e1a95e..d03ab88477 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_s6x16m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_s6x16m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_s6x16n.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_s6x16n.c index 3af06075a8..ac9aa21939 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_s6x16n.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_s6x16n.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c index 05c240d2d1..328f901e3c 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,6 +38,405 @@ #define BLIS_ASM_SYNTAX_ATT #include "bli_x86_asm_macros.h" +static const int64_t mask_3[4] = {-1, -1, -1, 0}; +static const int64_t mask_1[4] = {-1, 0, 0, 0}; + +static void bli_dgemmsup_rv_haswell_asm_6x7m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +static void bli_dgemmsup_rv_haswell_asm_6x5m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +static void bli_dgemmsup_rv_haswell_asm_6x3m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +static void bli_dgemmsup_rv_haswell_asm_6x1m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +#define C_TRANSPOSE_6x7_TILE(R1, R2, R3, R4, R5, R6, R7, R8, R9, R10, R11, R12) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + /*Broadcasting Beta into ymm15 vector register*/\ + vbroadcastsd(mem(rbx), ymm15)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*R1, R2, R3, R4 holds final result*/ \ + vfmadd231pd(mem(rcx ), ymm15, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), ymm15, ymm(R4))\ + /*Storing it back to C matrix.*/ \ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + /*Moving to operate on last 2 rows of 6 rows.*/ \ + 
lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm3)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*0, 1, 2, 3 holds final result*/ \ + vfmadd231pd(mem(rdx ), xmm15, xmm0)\ + vfmadd231pd(mem(rdx, rsi, 1), xmm15, xmm1)\ + vfmadd231pd(mem(rdx, rsi, 2), xmm15, xmm2)\ + vfmadd231pd(mem(rdx, rax, 1), xmm15, xmm3)\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2))\ + vmovupd(xmm3, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R8), ymm(R7), ymm0)\ + vunpckhpd(ymm(R8), ymm(R7), ymm1)\ + vunpcklpd(ymm(R10), ymm(R9), ymm2)\ + vunpckhpd(ymm(R10), ymm(R9), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm5)\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm7)\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm9)\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm5)\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm7)\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm9)\ +\ + vmovupd(ymm5, mem(rcx ))\ + vmovupd(ymm7, mem(rcx, rsi, 1))\ + vmovupd(ymm9, mem(rcx, rsi, 2))\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R12), ymm(R11), ymm0)\ + vunpckhpd(ymm(R12), ymm(R11), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm4)\ +\ + vfmadd231pd(mem(rdx ), xmm15, xmm0)\ + vfmadd231pd(mem(rdx, rsi, 1), xmm15, xmm1)\ + vfmadd231pd(mem(rdx, rsi, 2), xmm15, xmm2)\ +\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_6x7_TILE_BZ(R1, R2, R3, R4, R5, R6, R7, R8, R9, R10, R11, R12) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + /*Storing transposed 4x4 tile back to C matrix*/\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm3)\ +\ + /*Storing transposed 2x4 tile back to C matrix*/\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2))\ + vmovupd(xmm3, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R8), ymm(R7), ymm0)\ + vunpckhpd(ymm(R8), ymm(R7), ymm1)\ + vunpcklpd(ymm(R10), ymm(R9), ymm2)\ + vunpckhpd(ymm(R10), ymm(R9), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm5)\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm7)\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm9)\ +\ + /*Storing transposed 4x3 tile back to C matrix*/\ + vmovupd(ymm5, mem(rcx ))\ + vmovupd(ymm7, mem(rcx, rsi, 1))\ + vmovupd(ymm9, mem(rcx, rsi, 2))\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R12), ymm(R11), ymm0)\ + vunpckhpd(ymm(R12), ymm(R11), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm4)\ +\ + /*Storing transposed 2x3 tile back to C matrix*/\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_6x5_TILE(R1, R2, 
R3, R4, R5, R6, R7, R8, R9, R10, R11, R12) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + /*Broadcasting Beta into ymm15 vector register*/\ + vbroadcastsd(mem(rbx), ymm15)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*R1, R2, R3, R4 holds final result*/ \ + vfmadd231pd(mem(rcx ), ymm15, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), ymm15, ymm(R4))\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm3)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*0, 1, 2, 3 holds final result*/ \ + vfmadd231pd(mem(rdx ), xmm15, xmm0)\ + vfmadd231pd(mem(rdx, rsi, 1), xmm15, xmm1)\ + vfmadd231pd(mem(rdx, rsi, 2), xmm15, xmm2)\ + vfmadd231pd(mem(rdx, rax, 1), xmm15, xmm3)\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2))\ + vmovupd(xmm3, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x1 tile*/ \ + vunpcklpd(ymm(R8), ymm(R7), ymm0)\ + vunpcklpd(ymm(R10), ymm(R9), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm5)\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm5)\ + vmovupd(ymm5, mem(rcx ))\ +\ + /*Transposing 2x1 tile*/ \ + vunpcklpd(ymm(R12), ymm(R11), ymm0)\ + vfmadd231pd(mem(rdx ), xmm15, xmm0)\ +\ + vmovupd(xmm0, mem(rdx )) + +#define C_TRANSPOSE_6x5_TILE_BZ(R1, R2, R3, R4, R5, R6, R7, R8, R9, R10, R11, R12) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + /*Storing transposed 4x4 tile back to C matrix*/\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm3)\ +\ + /*Storing transposed 4x2 tile back to C matrix*/\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2))\ + vmovupd(xmm3, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x1 tile*/ \ + vunpcklpd(ymm(R8), ymm(R7), ymm0)\ + vunpcklpd(ymm(R10), ymm(R9), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm5)\ +\ + /*Storing transposed 4x1 tile back to C matrix*/\ + vmovupd(ymm5, mem(rcx ))\ +\ + /*Transposing 2x1 tile*/ \ + vunpcklpd(ymm(R12), ymm(R11), ymm0)\ +\ + /*Storing transposed 2x1 tile back to C matrix*/\ + vmovupd(xmm0, mem(rdx )) + +#define C_TRANSPOSE_6x3_TILE(R1, R2, R3, R4, R5, R6) \ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), 
ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ +\ + vbroadcastsd(mem(rbx), ymm3)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*R1, R2, R3 holds final result*/ \ + vfmadd231pd(mem(rcx ), ymm3, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm(R3))\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*0, 1, 2 holds final result*/ \ + vfmadd231pd(mem(rdx ), xmm3, xmm0)\ + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1)\ + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2)\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_6x3_TILE_BZ(R1, R2, R3, R4, R5, R6) \ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm4)\ +\ + vmovupd(xmm0, mem(rdx ))\ + vmovupd(xmm1, mem(rdx, rsi, 1))\ + vmovupd(xmm2, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_6x1_TILE(R1, R2, R3, R4, R5, R6) \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ +\ + vbroadcastsd(mem(rbx), ymm3)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*R1, R2, R3 holds final result*/ \ + vfmadd231pd(mem(rcx ), ymm3, ymm(R1))\ + vmovupd(ymm(R1), mem(rcx ))\ +\ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*0, 1, 2 holds final result*/ \ + vfmadd231pd(mem(rdx ), xmm3, xmm0)\ + vmovupd(xmm0, mem(rdx ))\ + +#define C_TRANSPOSE_6x1_TILE_BZ(R1, R2, R3, R4, R5, R6) \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ +\ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ +\ + vmovupd(xmm0, mem(rdx ))\ /* rrr: -------- ------ -------- @@ -108,93 +507,82 @@ void bli_dgemmsup_rv_haswell_asm_6x8m double* restrict bj = b; double* restrict ai = a; - if ( 6 <= n_left ) - { - const dim_t nr_cur = 6; - - bli_dgemmsup_rv_haswell_asm_6x6m - ( - conja, conjb, m0, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; - } - if ( 4 <= n_left ) - { - const dim_t nr_cur = 4; - - bli_dgemmsup_rv_haswell_asm_6x4m - ( - conja, conjb, m0, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; - } - if ( 2 <= n_left ) - { - const dim_t nr_cur = 2; - - bli_dgemmsup_rv_haswell_asm_6x2m - ( - conja, 
conjb, m0, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; - } - if ( 1 == n_left ) + switch(n_left) { -#if 0 - const dim_t nr_cur = 1; - - bli_dgemmsup_r_haswell_ref - ( - conja, conjb, m0, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); -#else - dim_t ps_a0 = bli_auxinfo_ps_a( data ); - - if ( ps_a0 == 6 * rs_a0 ) + case 7: + { + bli_dgemmsup_rv_haswell_asm_6x7m + ( + conja, conjb, m0, n_left, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + case 6: { - // Since A is not packed, we can use one gemv. - bli_dgemv_ex + bli_dgemmsup_rv_haswell_asm_6x6m ( - BLIS_NO_TRANSPOSE, conjb, m0, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, - beta, cij, rs_c0, cntx, NULL + conja, conjb, m0, n_left, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx ); + break; } - else + case 5: { - const dim_t mr = 6; - - // Since A is packed into row panels, we must use a loop over - // gemv. - dim_t m_iter = ( m0 + mr - 1 ) / mr; - dim_t m_left = m0 % mr; - - double* restrict ai_ii = ai; - double* restrict cij_ii = cij; - - for ( dim_t ii = 0; ii < m_iter; ii += 1 ) - { - dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) - ? mr : m_left ); - - bli_dgemv_ex - ( - BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, - alpha, ai_ii, rs_a0, cs_a0, bj, rs_b0, - beta, cij_ii, rs_c0, cntx, NULL - ); - cij_ii += mr*rs_c0; ai_ii += ps_a0; - } + bli_dgemmsup_rv_haswell_asm_6x5m + ( + conja, conjb, m0, n_left, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + case 4: + { + bli_dgemmsup_rv_haswell_asm_6x4m + ( + conja, conjb, m0, n_left, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + case 3: + { + bli_dgemmsup_rv_haswell_asm_6x3m + ( + conja, conjb, m0, n_left, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + case 2: + { + bli_dgemmsup_rv_haswell_asm_6x2m + ( + conja, conjb, m0, n_left, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + case 1: + { + bli_dgemmsup_rv_haswell_asm_6x1m + ( + conja, conjb, m0, n_left, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + default: + { + break; } -#endif } return; } @@ -600,6 +988,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m label(.DROWSTORED) + lea(mem(rcx, rdi, 2), rbx) // load address of c + 2*rs_c; vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vmovupd(ymm4, mem(rcx, 0*32)) @@ -614,39 +1003,36 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) vmovupd(ymm7, mem(rcx, 1*32)) - add(rdi, rcx) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) - vmovupd(ymm8, mem(rcx, 0*32)) + vfmadd231pd(mem(rbx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rbx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) - vmovupd(ymm9, mem(rcx, 1*32)) - add(rdi, rcx) + vfmadd231pd(mem(rbx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rbx, 1*32)) + add(rdi, rbx) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) - vmovupd(ymm10, mem(rcx, 0*32)) + vfmadd231pd(mem(rbx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rbx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) - vmovupd(ymm11, mem(rcx, 1*32)) - add(rdi, rcx) + vfmadd231pd(mem(rbx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rbx, 1*32)) - 
vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) - vmovupd(ymm12, mem(rcx, 0*32)) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rdx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) - vmovupd(ymm13, mem(rcx, 1*32)) - add(rdi, rcx) + vfmadd231pd(mem(rdx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rdx, 1*32)) + add(rdi, rdx) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) - vmovupd(ymm14, mem(rcx, 0*32)) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rdx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) - vmovupd(ymm15, mem(rcx, 1*32)) - //add(rdi, rcx) + vfmadd231pd(mem(rdx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rdx, 1*32)) jmp(.DDONE) // jump to end. @@ -916,53 +1302,6 @@ void bli_dgemmsup_rv_haswell_asm_6x8m double* restrict ai = a + m_iter * ps_a; double* restrict bj = b; -#if 0 - // We add special handling for slightly inflated MR blocksizes - // at edge cases, up to a maximum of 9. - if ( 6 < m_left ) - { - dgemmsup_ker_ft ker_fp1 = NULL; - dgemmsup_ker_ft ker_fp2 = NULL; - dim_t mr1, mr2; - - if ( m_left == 7 ) - { - mr1 = 4; mr2 = 3; - ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8; - ker_fp2 = bli_dgemmsup_rv_haswell_asm_3x8; - } - else if ( m_left == 8 ) - { - mr1 = 4; mr2 = 4; - ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8; - ker_fp2 = bli_dgemmsup_rv_haswell_asm_4x8; - } - else // if ( m_left == 9 ) - { - mr1 = 4; mr2 = 5; - ker_fp1 = bli_dgemmsup_rv_haswell_asm_4x8; - ker_fp2 = bli_dgemmsup_rv_haswell_asm_5x8; - } - - ker_fp1 - ( - conja, conjb, mr1, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += mr1*rs_c0; ai += mr1*rs_a0; - - ker_fp2 - ( - conja, conjb, mr2, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - - return; - } -#endif - dgemmsup_ker_ft ker_fps[6] = { NULL, @@ -8129,7 +8468,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U ) } -void bli_dgemmsup_rv_haswell_asm_6x6m +static void bli_dgemmsup_rv_haswell_asm_6x7m ( conj_t conja, conj_t conjb, @@ -8146,6 +8485,31 @@ void bli_dgemmsup_rv_haswell_asm_6x6m ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 7 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 3 elements, +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. 
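The same remainder-handling idea, written with AVX2 intrinsics instead of the assembly macros used in this kernel, may be easier to follow; the helper name and pointer below are purely illustrative:

#include <immintrin.h>

// Load 7 contiguous doubles: one full 4-element load followed by a masked
// load that reads 3 more elements and zeroes the unused fourth lane,
// mirroring the mask_3 = { -1, -1, -1, 0 } vector described above.
static inline void load_7_doubles( const double* b, __m256d* lo, __m256d* hi )
{
	const __m256i m3 = _mm256_setr_epi64x( -1, -1, -1, 0 );

	*lo = _mm256_loadu_pd( b );              // elements 0..3
	*hi = _mm256_maskload_pd( b + 4, m3 );   // elements 4..6, lane 3 set to 0
}

The matching masked store intrinsic, _mm256_maskstore_pd, writes only the lanes whose mask bit is set, so neither direction touches memory beyond the masked elements.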
+// //void* a_next = bli_auxinfo_next_a( data ); //void* b_next = bli_auxinfo_next_b( data ); @@ -8168,13 +8532,15 @@ void bli_dgemmsup_rv_haswell_asm_6x6m uint64_t ps_a = bli_auxinfo_ps_a( data ); uint64_t ps_a8 = ps_a * sizeof( double ); - if ( m_iter == 0 ) goto consider_edge_cases; + int64_t const *mask_vec = mask_3; - // ------------------------------------------------------------------------- + if ( m_iter == 0 ) goto consider_edge_cases_7; + // ------------------------------------------------------------------------- begin_asm() - //vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load mask mov(var(a), r14) // load address of a. mov(var(rs_a), r8) // load rs_a @@ -8185,25 +8551,2205 @@ void bli_dgemmsup_rv_haswell_asm_6x6m lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a - //mov(var(b), rbx) // load address of b. mov(var(rs_b), r10) // load rs_b - //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) - //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - - // NOTE: We cannot pre-load elements of a or b - // because it could eventually, in the last - // unrolled iter or the cleanup loop, result - // in reading beyond the bounds allocated mem - // (the likely result: a segmentation fault). mov(var(c), r12) // load address of c mov(var(rs_c), rdi) // load rs_c lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + mov(var(m_iter), r11) // ii = m_iter; - // During preamble and loops: - // r12 = rcx = c + label(.DLOOP6X7I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm3) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 3*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + lea(mem(rdx, rsi, 2), rcx) // rcx = c + 5*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 6*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. 
+#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.DLOOPKITER) // MAIN LOOP + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements based on mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements based on mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements 
based on mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements based on mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + label(.DPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + + vmulpd(ymm0, ymm7, ymm7) // scale by alpha + vmulpd(ymm0, ymm8, ymm8) + + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
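The cmp(imm(8), rdi) above and the jz(.DCOLSTORED) that follows are a storage-layout test: rdi holds rs_c * sizeof(double), so the column-stored path (which transposes the register tile through the C_TRANSPOSE_6x7_* macros) is taken exactly when rs_c == 1, and the row-stored stores are used otherwise. A minimal C restatement of the check, with a hypothetical helper name chosen only for illustration:

#include <stdbool.h>
#include <stddef.h>

// Hypothetical helper mirroring the asm test: the .DCOLSTORED branch is
// taken when rs_c * sizeof(double) == 8 bytes, i.e. when rs_c == 1.
static bool c_tile_is_column_stored( size_t rs_c )
{
    return rs_c * sizeof( double ) == 8;
}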
+ jz(.DCOLSTORED) // jump to column storage case + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 2), rbx) // load address of c + 2*rs_c; + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + //Loads 4 element + vmovupd(ymm3, mem(rcx, 0*32)) + //Loads 3 elements based on mask_3 mask vector + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm5) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm6) + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + //-----------------------2 + + vfmadd231pd(mem(rbx, 0*32), ymm1, ymm7) + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm8) + + vmovupd(ymm7, mem(rbx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rbx, 1*32)) + + add(rdi, rbx) + //-----------------------3 + + vfmadd231pd(mem(rbx, 0*32), ymm1, ymm9) + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm10) + + vmovupd(ymm9, mem(rbx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rbx, 1*32)) + + //-----------------------4 + + vfmadd231pd(mem(rdx, 0*32), ymm1, ymm11) + vmaskmovpd(mem(rdx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm12) + + vmovupd(ymm11, mem(rdx, 0*32)) + vmaskmovpd(ymm12, ymm15, mem(rdx, 1*32)) + + add(rdi, rdx) + //-----------------------5 + + vfmadd231pd(mem(rdx, 0*32), ymm1, ymm13) + vmaskmovpd(mem(rdx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm14) + + vmovupd(ymm13, mem(rdx, 0*32)) + vmaskmovpd(ymm14, ymm15, mem(rdx, 1*32)) + + //-----------------------6 + + jmp(.DDONE) // jump to end. + + label(.DCOLSTORED) + C_TRANSPOSE_6x7_TILE(3, 5, 7, 9, 11, 13, 4, 6, 8, 10, 12, 14) + jmp(.RESETPARAM) + + label(.DBETAZERO) + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + label(.DROWSTORBZ) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------2 + + vmovupd(ymm7, mem(rcx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------3 + + vmovupd(ymm9, mem(rcx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------4 + + vmovupd(ymm11, mem(rcx, 0*32)) + vmaskmovpd(ymm12, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------5 + + vmovupd(ymm13, mem(rcx, 0*32)) + vmaskmovpd(ymm14, ymm15, mem(rcx, 1*32)) + + //-----------------------6 + + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_6x7_TILE_BZ(3, 5, 7, 9, 11, 13, 4, 6, 8, 10, 12, 14) + jmp(.RESETPARAM) + + label(.RESETPARAM) + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load mask + jmp(.DDONE) + + label(.DDONE) + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + mov(var(ps_a8), rax) // load ps_a8 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a8 + + dec(r11) // ii -= 1; + jne(.DLOOP6X7I) // iterate again if ii != 0. 
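Stepping back from the assembly: the .DLOOP6X7I loop walks C down the m dimension in 6-row panels (r12 advances by 6*rs_c and r14 by ps_a8 on each pass through .DDONE), and any m remainder is handed to the fixed-size 1x7 through 5x7 kernels via the function-pointer table under consider_edge_cases_7 below. A rough C sketch of that driver shape, using standard types and hypothetical signatures purely for illustration:

#include <stddef.h>

// Illustrative driver shape only; the real kernels take the full BLIS
// argument list and the asm keeps these pointers in r12/r14.
typedef void (*ker_ft)( size_t m, double* c, const double* a );

static void drive_m_panels( size_t m0, double* c, ptrdiff_t rs_c,
                            const double* a, ptrdiff_t ps_a,
                            ker_ft full_6x7, ker_ft edge_kers[6] )
{
    size_t m_iter = m0 / 6;
    size_t m_left = m0 % 6;

    for ( size_t ii = 0; ii < m_iter; ++ii )
    {
        full_6x7( 6, c, a );     // body of .DLOOP6X7I
        c += 6 * rs_c;           // c_ii = r12 += 6*rs_c
        a += ps_a;               // a_ii = r14 += ps_a8 (ps_a8 is in bytes in the asm)
    }

    if ( m_left )                // consider_edge_cases_7:
        edge_kers[ m_left ]( m_left, c, a );   // { NULL, 1x7, ..., 5x7 }
}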
+ + + label(.DRETURN) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [mask_vec] "m" (mask_vec), + [rs_c] "m" (rs_c), + [n0] "m" (n0), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", + "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", + "memory" + ) + + consider_edge_cases_7: + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = n0; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict ai = a + m_iter * ps_a; + double* restrict bj = b; + + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x7, + bli_dgemmsup_rv_haswell_asm_2x7, + bli_dgemmsup_rv_haswell_asm_3x7, + bli_dgemmsup_rv_haswell_asm_4x7, + bli_dgemmsup_rv_haswell_asm_5x7 + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +static void bli_dgemmsup_rv_haswell_asm_6x5m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_5); + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | 0 | 0 | 0 | ----> Mask vector( mask_1 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 0 | 0 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 5 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 1 element, +// kernel is using mask_1 which is set to -1, 0, 0, 0 static that the +// 1 element will be loaded and other 3 elements will be set to 0 in destination vector. +// + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. 
+ uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + int64_t const *mask_vec = mask_1; + + if ( m_iter == 0 ) goto consider_edge_cases_5; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load mask + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X5I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm3) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 3*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 3*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 3*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
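Each iteration of the main loop that follows performs a masked rank-1 update: one full 4-wide load of B, one masked load for the fifth column, then a broadcast per row of A feeding two FMAs, twelve FMAs per k step in total. A rough intrinsics rendering of a single k step (illustrative only, not the code the macros below expand to; compile with FMA support):

#include <immintrin.h>

// Illustrative: one k step of the 6x5 block update. c_acc[i][0] holds
// columns 0..3 of row i, c_acc[i][1] holds the masked tail (column 4).
static void rank1_6x5_step( __m256d c_acc[6][2],
                            const double* a_col, int rs_a,   // a_col[i*rs_a] = A(i,p)
                            const double* b_row,             // b_row[0..4]   = B(p,0..4)
                            __m256i mask1 )                  // { -1, 0, 0, 0 }
{
    __m256d b0 = _mm256_loadu_pd( b_row );                // B(p,0..3)
    __m256d b1 = _mm256_maskload_pd( b_row + 4, mask1 );  // B(p,4); other lanes zeroed

    for ( int i = 0; i < 6; ++i )
    {
        __m256d ai = _mm256_broadcast_sd( &a_col[ i * rs_a ] );
        c_acc[i][0] = _mm256_fmadd_pd( ai, b0, c_acc[i][0] );
        c_acc[i][1] = _mm256_fmadd_pd( ai, b1, c_acc[i][1] );
    }
}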
+ + label(.DLOOPKITER) // MAIN LOOP + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + 
vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + label(.DPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + + vmulpd(ymm0, ymm7, ymm7) // scale by alpha + vmulpd(ymm0, ymm8, ymm8) + + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORED) // jump to column storage case + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 2), rbx) // load address of c + 2*rs_c; + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + //Loads 4 element + vmovupd(ymm3, mem(rcx, 0*32)) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm5) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm6) + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + //-----------------------2 + + vfmadd231pd(mem(rbx, 0*32), ymm1, ymm7) + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm8) + + vmovupd(ymm7, mem(rbx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rbx, 1*32)) + + add(rdi, rbx) + //-----------------------3 + + vfmadd231pd(mem(rbx, 0*32), ymm1, ymm9) + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm10) + + vmovupd(ymm9, mem(rbx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rbx, 1*32)) + + //-----------------------4 + + vfmadd231pd(mem(rdx, 0*32), ymm1, ymm11) + vmaskmovpd(mem(rdx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm12) + + vmovupd(ymm11, mem(rdx, 0*32)) + vmaskmovpd(ymm12, ymm15, mem(rdx, 1*32)) + + add(rdi, rdx) + //-----------------------5 + + vfmadd231pd(mem(rdx, 0*32), ymm1, ymm13) + vmaskmovpd(mem(rdx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm14) + + vmovupd(ymm13, mem(rdx, 0*32)) + vmaskmovpd(ymm14, ymm15, mem(rdx, 1*32)) + + //-----------------------6 + + jmp(.DDONE) // jump to end. + + label(.DCOLSTORED) + + C_TRANSPOSE_6x5_TILE(3, 5, 7, 9, 11, 13, 4, 6, 8, 10, 12, 14) + jmp(.RESETPARAM) + + label(.DBETAZERO) + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + label(.DROWSTORBZ) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------2 + + vmovupd(ymm7, mem(rcx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------3 + + vmovupd(ymm9, mem(rcx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------4 + + vmovupd(ymm11, mem(rcx, 0*32)) + vmaskmovpd(ymm12, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------5 + + vmovupd(ymm13, mem(rcx, 0*32)) + vmaskmovpd(ymm14, ymm15, mem(rcx, 1*32)) + + //-----------------------6 + + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_6x5_TILE_BZ(3, 5, 7, 9, 11, 13, 4, 6, 8, 10, 12, 14) + jmp(.RESETPARAM) + + label(.RESETPARAM) + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load mask + jmp(.DDONE) + + label(.DDONE) + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + mov(var(ps_a8), rax) // load ps_a8 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a8 + + dec(r11) // ii -= 1; + jne(.DLOOP6X5I) // iterate again if ii != 0. 
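The epilogue inside the loop above follows the usual gemm convention: the accumulators are first scaled by alpha, then the vucomisd test sends beta == 0 to the store-only .DROWSTORBZ/.DCOLSTORBZ paths so C is never read, while a nonzero beta takes the read-modify-write paths built on vfmadd231pd. Element-wise, the update reduces to the following plain C (reference only, hypothetical helper name):

// Per-element form of the alpha/beta epilogue; ab[] holds the accumulated
// (not yet alpha-scaled) products for one row of the tile.
static void row_epilogue( double* c, const double* ab, int n,
                          double alpha, double beta )
{
    if ( beta == 0.0 )
    {
        for ( int j = 0; j < n; ++j )
            c[j] = alpha * ab[j];                 // .DROWSTORBZ: C is never loaded
    }
    else
    {
        for ( int j = 0; j < n; ++j )
            c[j] = alpha * ab[j] + beta * c[j];   // .DROWSTORED: fused via vfmadd231pd
    }
}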
+ + + label(.DRETURN) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [mask_vec] "m" (mask_vec), + [rs_c] "m" (rs_c), + [n0] "m" (n0), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", + "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", + "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", + "memory" + ) + + consider_edge_cases_5: + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = n0; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict ai = a + m_iter * ps_a; + double* restrict bj = b; + + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x5, + bli_dgemmsup_rv_haswell_asm_2x5, + bli_dgemmsup_rv_haswell_asm_3x5, + bli_dgemmsup_rv_haswell_asm_4x5, + bli_dgemmsup_rv_haswell_asm_5x5 + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_5); +} + +static void bli_dgemmsup_rv_haswell_asm_6x3m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// + +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. +// + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. 
+ uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + int64_t const *mask_vec = mask_3; + + if ( m_iter == 0 ) goto consider_edge_cases_nleft_3; + + begin_asm() + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load mask + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X3I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm14) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 2*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 2*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 2*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 2*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 2*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
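Because three columns fit in a single masked register, the main loop that follows needs only one vmaskmovpd load of B per k step, and the six broadcasts from A (issued in ymm2/ymm3 pairs) feed one FMA each, six FMAs per step instead of the twelve used by the 6x7 and 6x5 paths. A rough intrinsics rendering of one k step (illustrative only; mask3 is the { -1, -1, -1, 0 } pattern loaded from mask_3):

#include <immintrin.h>

// Illustrative: one k step of the 6x3 block update; each row needs only a
// single masked accumulator, so six FMAs suffice per step.
static void rank1_6x3_step( __m256d c_acc[6],
                            const double* a_col, int rs_a,  // a_col[i*rs_a] = A(i,p)
                            const double* b_row,            // b_row[0..2]   = B(p,0..2)
                            __m256i mask3 )                 // { -1, -1, -1, 0 }
{
    __m256d b0 = _mm256_maskload_pd( b_row, mask3 );  // lanes 0..2 loaded, lane 3 = 0

    for ( int i = 0; i < 6; ++i )
    {
        __m256d ai = _mm256_broadcast_sd( &a_col[ i * rs_a ] );
        c_acc[i] = _mm256_fmadd_pd( ai, b0, c_acc[i] );
    }
}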
+ + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + label(.DPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm14, ymm14) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 2), rbx) // load address of c + 4*rs_c; + + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm3, ymm4) + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm3, ymm6) + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm3, ymm8) + vmaskmovpd(ymm8, ymm15, mem(rbx, 0*32)) + add(rdi, rbx) + + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm3, ymm10) + vmaskmovpd(ymm10, ymm15, mem(rbx, 0*32)) + + vmaskmovpd(mem(rdx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm3, ymm12) + vmaskmovpd(ymm12, ymm15, mem(rdx, 0*32)) + add(rdi, rdx) + + vmaskmovpd(mem(rdx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm3, ymm14) + vmaskmovpd(ymm14, ymm15, mem(rdx, 0*32)) + + jmp(.DDONE) // jump to end. + + label(.DCOLSTORED) + + C_TRANSPOSE_6x3_TILE(4, 6, 8, 10, 12, 14) + jmp(.DDONE) + + label(.DBETAZERO) + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + label(.DROWSTORBZ) + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm10, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm12, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm14, ymm15, mem(rcx, 0*32)) + + + jmp(.DDONE) // jump to end. 
+ + label(.DCOLSTORBZ) + + C_TRANSPOSE_6x3_TILE_BZ(4, 6, 8, 10, 12, 14) + jmp(.DDONE) + + label(.DDONE) + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + mov(var(ps_a8), rax) // load ps_a8 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a8 + + dec(r11) // ii -= 1; + jne(.DLOOP6X3I) // iterate again if ii != 0. + + label(.DRETURN) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [n0] "m" (n0), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm8", "ymm10", "ymm12", "ymm14", + "ymm15", "memory" + ) + + consider_edge_cases_nleft_3: + if ( m_left ) + { + const dim_t nr_cur = n0; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict ai = a + m_iter * ps_a; + double* restrict bj = b; + + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x3, + bli_dgemmsup_rv_haswell_asm_2x3, + bli_dgemmsup_rv_haswell_asm_3x3, + bli_dgemmsup_rv_haswell_asm_4x3, + bli_dgemmsup_rv_haswell_asm_5x3 + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + + +static void bli_dgemmsup_rv_haswell_asm_6x1m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | 0 | 0 | 0 | ----> Mask vector( mask_1 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 0 | 0 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// + +// kernel is using mask_1 which is set to -1, 0, 0, 0 so that the +// 1 element will be loaded and 4th element will be set to 0 in destination vector. +// + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + int64_t const *mask_vec = mask_1; + + if ( m_iter == 0 ) goto consider_edge_cases_nleft_1; + + begin_asm() + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load mask + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X1I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm1) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 2*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 2*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 2*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 2*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 2*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
+ + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + //Loads 1 elements as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + //Loads 1 elements as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm1) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + //Loads 1 elements as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + + // ---------------------------------- iteration 3 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 1 elements as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm1) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + vaddpd(ymm5, ymm4, ymm4) + vaddpd(ymm7, ymm6, ymm6) + vaddpd(ymm9, ymm8, ymm8) + vaddpd(ymm11, ymm10, ymm10) + vaddpd(ymm13, ymm12, ymm12) + vaddpd(ymm1, ymm14, ymm14) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
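Unlike the other masked kernels, the unrolled loop above alternates between two accumulator sets (ymm4/6/8/10/12/14 on even k iterations, ymm5/7/9/11/13/1 on odd ones) and folds them together with the vaddpd block once the loop exits; the k_left edge loop below then adds into the primary set only. This breaks the dependency chain on each FMA destination, and the same change is applied to the 6x4 kernel later in this diff. A scalar C sketch of the idea (illustrative only):

#include <stdio.h>

// Two independent accumulators: consecutive updates no longer depend on the
// immediately preceding result, which helps hide FMA latency.
static double dot_dual_acc( const double* a, const double* b, int k )
{
    double acc0 = 0.0, acc1 = 0.0;
    int p = 0;
    for ( ; p + 1 < k; p += 2 )
    {
        acc0 += a[p]     * b[p];       // even iterations -> acc0
        acc1 += a[p + 1] * b[p + 1];   // odd  iterations -> acc1
    }
    if ( p < k )
        acc0 += a[p] * b[p];           // leftover iteration (the edge loop)
    return acc0 + acc1;                // merge, as the vaddpd block does
}

int main( void )
{
    const double a[5] = { 1, 2, 3, 4, 5 };
    const double b[5] = { 5, 4, 3, 2, 1 };
    printf( "%g\n", dot_dual_acc( a, b, 5 ) );  // prints 35
    return 0;
}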
+ + label(.DLOOPKLEFT) // EDGE LOOP + +#if 1 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm0, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + label(.DPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm14, ymm14) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 2), rbx) // load address of c + 2*rs_c; + + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm3, ymm4) + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm1) + vfmadd231pd(ymm1, ymm3, ymm6) + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm2) + vfmadd231pd(ymm2, ymm3, ymm8) + vmaskmovpd(ymm8, ymm15, mem(rbx, 0*32)) + add(rdi, rbx) + + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm4) + vfmadd231pd(ymm4, ymm3, ymm10) + vmaskmovpd(ymm10, ymm15, mem(rbx, 0*32)) + + vmaskmovpd(mem(rdx, 0*32), ymm15, ymm5) + vfmadd231pd(ymm5, ymm3, ymm12) + vmaskmovpd(ymm12, ymm15, mem(rdx, 0*32)) + add(rdi, rdx) + + vmaskmovpd(mem(rdx, 0*32), ymm15, ymm6) + vfmadd231pd(ymm6, ymm3, ymm14) + vmaskmovpd(ymm14, ymm15, mem(rdx, 0*32)) + + jmp(.DDONE) // jump to end. + + label(.DCOLSTORED) + + C_TRANSPOSE_6x1_TILE(4, 6, 8, 10, 12, 14) + jmp(.DDONE) + + label(.DBETAZERO) + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + label(.DROWSTORBZ) + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm10, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm12, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm14, ymm15, mem(rcx, 0*32)) + + + jmp(.DDONE) // jump to end. 
+ + label(.DCOLSTORBZ) + + C_TRANSPOSE_6x1_TILE_BZ(4, 6, 8, 10, 12, 14) + jmp(.DDONE) + + label(.DDONE) + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + mov(var(ps_a8), rax) // load ps_a8 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a8 + + dec(r11) // ii -= 1; + jne(.DLOOP6X1I) // iterate again if ii != 0. + + label(.DRETURN) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [n0] "m" (n0), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm13", "ymm14", "ymm15", "memory" + ) + + consider_edge_cases_nleft_1: + if ( m_left ) + { + const dim_t nr_cur = n0; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict ai = a + m_iter * ps_a; + double* restrict bj = b; + + dgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_dgemmsup_rv_haswell_asm_1x1, + bli_dgemmsup_rv_haswell_asm_2x1, + bli_dgemmsup_rv_haswell_asm_3x1, + bli_dgemmsup_rv_haswell_asm_4x1, + bli_dgemmsup_rv_haswell_asm_5x1 + }; + + dgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + +void bli_dgemmsup_rv_haswell_asm_6x6m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. 
+ mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c // r14 = rax = a // read rbx from var(b) near beginning of loop // r11 = m dim index ii @@ -8545,6 +11091,7 @@ void bli_dgemmsup_rv_haswell_asm_6x6m label(.DROWSTORED) + lea(mem(rcx, rdi, 2), rbx) // load address of c + 2*rs_c; vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vmovupd(ymm4, mem(rcx, 0*32)) @@ -8559,39 +11106,36 @@ void bli_dgemmsup_rv_haswell_asm_6x6m vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) vmovupd(xmm7, mem(rcx, 1*32)) - add(rdi, rcx) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) - vmovupd(ymm8, mem(rcx, 0*32)) + vfmadd231pd(mem(rbx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rbx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) - vmovupd(xmm9, mem(rcx, 1*32)) - add(rdi, rcx) + vfmadd231pd(mem(rbx, 1*32), xmm3, xmm9) + vmovupd(xmm9, mem(rbx, 1*32)) + add(rdi, rbx) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) - vmovupd(ymm10, mem(rcx, 0*32)) + vfmadd231pd(mem(rbx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rbx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), xmm3, xmm11) - vmovupd(xmm11, mem(rcx, 1*32)) - add(rdi, rcx) + vfmadd231pd(mem(rbx, 1*32), xmm3, xmm11) + vmovupd(xmm11, mem(rbx, 1*32)) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) - vmovupd(ymm12, mem(rcx, 0*32)) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rdx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), xmm3, xmm13) - vmovupd(xmm13, mem(rcx, 1*32)) - add(rdi, rcx) + vfmadd231pd(mem(rdx, 1*32), xmm3, xmm13) + vmovupd(xmm13, mem(rdx, 1*32)) + add(rdi, rdx) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) - vmovupd(ymm14, mem(rcx, 0*32)) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rdx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), xmm3, xmm15) - vmovupd(xmm15, mem(rcx, 1*32)) - //add(rdi, rcx) + vfmadd231pd(mem(rdx, 1*32), xmm3, xmm15) + vmovupd(xmm15, mem(rdx, 1*32)) jmp(.DDONE) // jump to end. @@ -8912,6 +11456,7 @@ void bli_dgemmsup_rv_haswell_asm_6x6m AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } + void bli_dgemmsup_rv_haswell_asm_6x4m ( conj_t conja, @@ -9004,11 +11549,17 @@ void bli_dgemmsup_rv_haswell_asm_6x4m // a latency of 1 cycle, while vzeroall // has a latency of 12 cycles. vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) #endif mov(var(b), rbx) // load address of b. 
@@ -9110,19 +11661,19 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm7) vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm11) vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm15) // ---------------------------------- iteration 2 @@ -9167,27 +11718,31 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm7) vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm11) vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm15) dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - + vaddpd(ymm5, ymm4, ymm4) + vaddpd(ymm7, ymm6, ymm6) + vaddpd(ymm9, ymm8, ymm8) + vaddpd(ymm11, ymm10, ymm10) + vaddpd(ymm13, ymm12, ymm12) + vaddpd(ymm15, ymm14, ymm14) @@ -9278,6 +11833,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m label(.DROWSTORED) + lea(mem(rcx, rdi, 2), rbx) // load address of c + 2*rs_c; vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vmovupd(ymm4, mem(rcx, 0*32)) @@ -9286,27 +11842,24 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) vmovupd(ymm6, mem(rcx, 0*32)) - add(rdi, rcx) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) - vmovupd(ymm8, mem(rcx, 0*32)) - add(rdi, rcx) + vfmadd231pd(mem(rbx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rbx, 0*32)) + add(rdi, rbx) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) - vmovupd(ymm10, mem(rcx, 0*32)) - add(rdi, rcx) + vfmadd231pd(mem(rbx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rbx, 0*32)) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) - vmovupd(ymm12, mem(rcx, 0*32)) - add(rdi, rcx) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rdx, 0*32)) + add(rdi, rdx) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) - vmovupd(ymm14, mem(rcx, 0*32)) - //add(rdi, rcx) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rdx, 0*32)) jmp(.DDONE) // jump to end. @@ -9483,8 +12036,8 @@ void bli_dgemmsup_rv_haswell_asm_6x4m "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", - "ymm6", "ymm8", "ymm10", "ymm12", "ymm14", - "memory" + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) consider_edge_cases: @@ -9664,11 +12217,17 @@ void bli_dgemmsup_rv_haswell_asm_6x2m // a latency of 1 cycle, while vzeroall // has a latency of 12 cycles. 
vxorpd(xmm4, xmm4, xmm4) + vmovapd( ymm4, ymm5) vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) #endif mov(var(b), rbx) // load address of b. @@ -9766,19 +12325,19 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(xmm0, xmm2, xmm4) - vfmadd231pd(xmm0, xmm3, xmm6) + vfmadd231pd(xmm0, xmm2, xmm5) + vfmadd231pd(xmm0, xmm3, xmm7) vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) - vfmadd231pd(xmm0, xmm2, xmm8) - vfmadd231pd(xmm0, xmm3, xmm10) + vfmadd231pd(xmm0, xmm2, xmm9) + vfmadd231pd(xmm0, xmm3, xmm11) vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; - vfmadd231pd(xmm0, xmm2, xmm12) - vfmadd231pd(xmm0, xmm3, xmm14) + vfmadd231pd(xmm0, xmm2, xmm13) + vfmadd231pd(xmm0, xmm3, xmm15) // ---------------------------------- iteration 2 @@ -9823,29 +12382,31 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(xmm0, xmm2, xmm4) - vfmadd231pd(xmm0, xmm3, xmm6) + vfmadd231pd(xmm0, xmm2, xmm5) + vfmadd231pd(xmm0, xmm3, xmm7) vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) - vfmadd231pd(xmm0, xmm2, xmm8) - vfmadd231pd(xmm0, xmm3, xmm10) + vfmadd231pd(xmm0, xmm2, xmm9) + vfmadd231pd(xmm0, xmm3, xmm11) vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; - vfmadd231pd(xmm0, xmm2, xmm12) - vfmadd231pd(xmm0, xmm3, xmm14) + vfmadd231pd(xmm0, xmm2, xmm13) + vfmadd231pd(xmm0, xmm3, xmm15) dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - + vaddpd(ymm5, ymm4, ymm4) + vaddpd(ymm7, ymm6, ymm6) + vaddpd(ymm9, ymm8, ymm8) + vaddpd(ymm11, ymm10, ymm10) + vaddpd(ymm13, ymm12, ymm12) + vaddpd(ymm15, ymm14, ymm14) label(.DCONSIDKLEFT) @@ -9933,6 +12494,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m label(.DROWSTORED) + lea(mem(rcx, rdi, 2), rbx) // load address of c + 2*rs_c; vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) vmovupd(xmm4, mem(rcx, 0*32)) @@ -9941,27 +12503,24 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) vmovupd(xmm6, mem(rcx, 0*32)) - add(rdi, rcx) - vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) - vmovupd(xmm8, mem(rcx, 0*32)) - add(rdi, rcx) + vfmadd231pd(mem(rbx, 0*32), xmm3, xmm8) + vmovupd(xmm8, mem(rbx, 0*32)) + add(rdi, rbx) - vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) - vmovupd(xmm10, mem(rcx, 0*32)) - add(rdi, rcx) + vfmadd231pd(mem(rbx, 0*32), xmm3, xmm10) + vmovupd(xmm10, mem(rbx, 0*32)) - vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) - vmovupd(xmm12, mem(rcx, 0*32)) - add(rdi, rcx) + vfmadd231pd(mem(rdx, 0*32), xmm3, xmm12) + vmovupd(xmm12, mem(rdx, 0*32)) + add(rdi, rdx) - vfmadd231pd(mem(rcx, 0*32), xmm3, xmm14) - vmovupd(xmm14, mem(rcx, 0*32)) - //add(rdi, rcx) + vfmadd231pd(mem(rdx, 0*32), xmm3, xmm14) + vmovupd(xmm14, mem(rdx, 0*32)) jmp(.DDONE) // jump to end. 
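The rewritten 6x4m and 6x2m main loops above accumulate even k iterations into ymm4/6/8/10/12/14 and odd ones into ymm5/7/9/11/13/15, then fold the two sets with the trailing vaddpd sequence after the loop. A rough scalar analogue of that two-chain accumulation is sketched below, written only to illustrate the dependency-chain idea; dot_split is not a BLIS function.

#include "blis.h"

// Illustration only: two independent accumulation chains folded once at the
// end, mirroring vaddpd(ymm5, ymm4, ymm4) ... vaddpd(ymm15, ymm14, ymm14).
static double dot_split( const double* x, const double* y, dim_t k )
{
	double acc0 = 0.0, acc1 = 0.0;        // chain 0: even k, chain 1: odd k
	dim_t  l    = 0;
	for ( ; l + 1 < k; l += 2 )
	{
		acc0 += x[ l     ] * y[ l     ];
		acc1 += x[ l + 1 ] * y[ l + 1 ];
	}
	if ( l < k ) acc0 += x[ l ] * y[ l ]; // remainder handled by one chain
	return acc0 + acc1;                   // fold, like the vaddpd cleanup
}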
@@ -10118,8 +12677,8 @@ void bli_dgemmsup_rv_haswell_asm_6x2m "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", - "ymm6", "ymm8", "ymm10", "ymm12", "ymm14", - "memory" + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) consider_edge_cases: @@ -10206,5 +12765,3 @@ void bli_dgemmsup_rv_haswell_asm_6x2m } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } - - diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8n.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8n.c index 4cdc763b67..9a9e362d29 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8n.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8n.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c index d1c251bcbd..e7aa4792d0 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16n.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16n.c index af4ab52a02..bb639fad8b 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16n.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16n.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/d6x8/CMakeLists.txt b/kernels/haswell/3/sup/d6x8/CMakeLists.txt deleted file mode 100644 index c74dff9372..0000000000 --- a/kernels/haswell/3/sup/d6x8/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -##Copyright (C) 2020-2023, Advanced Micro Devices, Inc. 
All rights reserved.## - -add_library(haswell_3supd6x8 - OBJECT -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_r_haswell_ref_dMx1.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_dMx1.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_dMx2.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_dMx4.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_dMx8.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_dMx2.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_dMx4.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_dMx6.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_dMx8.c - ) - -target_compile_options(haswell_3supd6x8 PRIVATE /arch:AVX2) -if(BUILD_SHARED_LIBS) - target_compile_definitions(haswell_3supd6x8 PUBLIC -DBLIS_IS_BUILDING_LIBRARY) -endif() diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_r_haswell_ref_dMx1.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_r_haswell_ref_dMx1.c index 69d543a99d..23dadc4004 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_r_haswell_ref_dMx1.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_r_haswell_ref_dMx1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c index 6d9dd365ee..14c8e53b3c 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx2.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx2.c index 94a8e9639e..5c73057817 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx2.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c index 01e2d0a3dd..2fb0c12e7f 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -1342,17 +1342,6 @@ void bli_dgemmsup_rd_haswell_asm_1x4 vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) - - //vhaddpd( ymm8, ymm5, ymm0 ) - //vextractf128(imm(1), ymm0, xmm1 ) - //vaddpd( xmm0, xmm1, xmm0 ) - - //vhaddpd( ymm14, ymm11, ymm2 ) - //vextractf128(imm(1), ymm2, xmm1 ) - //vaddpd( xmm2, xmm1, xmm2 ) - - //vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) - // xmm4[0:3] = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx8.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx8.c index 9b97a40a45..d643c2e996 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx8.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx1.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx1.c new file mode 100644 index 0000000000..42fa8c50a1 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx1.c @@ -0,0 +1,1984 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +#define C_TRANSPOSE_5x1_TILE(R1, R2, R3, R4, R5)\ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ +\ + vbroadcastsd(mem(rbx), ymm3)\ +\ + vfmadd231pd(mem(rcx ), ymm3, ymm(R1))\ + vmovupd(ymm(R1), mem(rcx ))\ +\ + vmovlpd(mem(rdx ), xmm0, xmm0)\ +\ + vfmadd213pd(ymm(R5), ymm3, ymm0)\ + vmovlpd(xmm0, mem(rdx ))\ + +#define C_TRANSPOSE_5x1_TILE_BZ(R1, R2, R3, R4, R5)\ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ +\ + vmovlpd(xmm(R5), mem(rdx ))\ + + +#define C_TRANSPOSE_4x1_TILE(R1, R2, R3, R4)\ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ +\ + vbroadcastsd(mem(rbx), ymm3)\ +\ + vfmadd231pd(mem(rcx ), ymm3, ymm(R1))\ + vmovupd(ymm(R1), mem(rcx ))\ + +#define C_TRANSPOSE_4x1_TILE_BZ(R1, R2, R3, R4)\ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ +\ + vmovupd(ymm(R1), mem(rcx )) + +#define C_TRANSPOSE_3x1_TILE(R1, R2, R3)\ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpcklpd(ymm(10), ymm(R3), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ +\ + vextractf128(imm(0x1), ymm(R1), xmm12)\ +\ + vbroadcastsd(mem(rbx), ymm3)\ +\ + vfmadd231pd(mem(rcx ), xmm3, xmm(R1))\ + vmovupd(xmm(R1), mem(rcx ))\ +\ + vfmadd231sd(mem(rdx ), xmm3, xmm12)\ + vmovsd(xmm12, mem(rdx )) + +#define C_TRANSPOSE_3x1_TILE_BZ(R1, R2, R3)\ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpcklpd(ymm(10), ymm(R3), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ +\ + vextractf128(imm(0x1), ymm(R1), xmm12)\ +\ + vmovupd(xmm(R1), mem(rcx ))\ +\ + vmovlpd(xmm(12), mem(rdx )) + +#define C_TRANSPOSE_2x1_TILE(R1, R2)\ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ +\ + vbroadcastsd(mem(rbx), ymm3)\ + vfmadd231pd(mem(rcx ), xmm3, xmm0)\ + vmovupd(xmm0, mem(rcx )) + + +#define C_TRANSPOSE_2x1_TILE_BZ(R1, R2)\ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ +\ + vmovupd(xmm0, mem(rcx )) + +#define C_TRANSPOSE_1x1_TILE(R1)\ + vmovlpd(mem(rcx ), xmm0, xmm0)\ +\ + vbroadcastsd(mem(rbx), ymm3)\ + vfmadd213pd(ymm(R1), ymm3, ymm0)\ +\ + vmovlpd(xmm0, mem(rcx )) + +#define C_TRANSPOSE_1x1_TILE_BZ(R1)\ + vmovlpd(xmm(R1), mem(rcx )) + +static const int64_t mask_1[4] = {-1, 0, 0, 0}; + + +void bli_dgemmsup_rv_haswell_asm_5x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. 
+// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | 0 | 0 | 0 | ----> Mask vector( mask_1 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 0 | 0 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// kernel is using mask_1 which is set to -1, 0, 0, 0 so that the +// 1 element will be loaded. +// + int64_t const *mask_vec = mask_1; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + lea(mem(r9, r9, 2), r15) // r15 = 3*cs_a + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 2*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 2*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 2*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
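Just above, rdx is pointed 16 columns of A ahead of rax (rdx = a + 16*cs_a); the main loop that follows issues a prefetch through it in most unrolled iterations before advancing it. A hedged C-level sketch of the same intent is below; __builtin_prefetch is a GCC/Clang builtin used purely for illustration, since the kernel emits the prefetch instruction directly.

#include "blis.h"

// Illustration only: prefetch the column of A that is 16 cs_a strides ahead
// of the column currently being broadcast.
static inline void prefetch_a_ahead( const double* a_cur, inc_t cs_a )
{
	__builtin_prefetch( a_cur + 16 * cs_a, 0 /* read */, 3 /* keep in cache */ );
}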
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm8) + vfmadd231pd(ymm1, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm1, ymm2, ymm13) + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm8) + vfmadd231pd(ymm1, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(rdx, r15, 1, 5*8)) // a_prefetch += 3*cs_a; + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm1, ymm2, ymm13) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + vaddpd(ymm5, ymm4, ymm4) + vaddpd(ymm7, ymm6, ymm6) + vaddpd(ymm9, ymm8, ymm8) + vaddpd(ymm11, ymm10, ymm10) + vaddpd(ymm13, ymm12, ymm12) + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm8) + vfmadd231pd(ymm1, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 1), rax) // load address of c + 1*rs_c; + lea(mem(rcx, rdi, 2), rbx) // load address of c + 2*rs_c; + lea(mem(rbx, rdi, 1), r8) // load address of c + 3*rs_c; + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vmaskmovpd(mem(rax, 0*32), ymm15, ymm2) + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm3) + vmaskmovpd(mem(r8, 0*32), ymm15, ymm5) + vmaskmovpd(mem(rdx, 0*32), ymm15, ymm7) + + vfmadd231pd(ymm0, ymm1, ymm4) + vfmadd231pd(ymm2, ymm1, ymm6) + vfmadd231pd(ymm3, ymm1, ymm8) + vfmadd231pd(ymm5, ymm1, ymm10) + vfmadd231pd(ymm7, ymm1, ymm12) + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rax, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rbx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(r8, 0*32)) + vmaskmovpd(ymm12, ymm15, mem(rdx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + C_TRANSPOSE_5x1_TILE(4, 6, 8, 10, 12) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm10, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm12, ymm15, mem(rcx, 0*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_5x1_TILE_BZ(4, 6, 8, 10, 12) + jmp(.DDONE) // jump to end. 
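When C is column-stored, the C_TRANSPOSE_*x1_TILE macros invoked above (C_TRANSPOSE_5x1_TILE just before this point) first pack element 0 of each row accumulator into one register so a single column of C can be updated with one load, FMA, and store. The intrinsics sketch below shows that four-register gather step; gather_lane0_4x1 is an illustrative name, not part of the kernel, and the fifth row in the 5x1 case is handled separately as a scalar.

#include <immintrin.h>

// Illustration of the vunpcklpd + vinsertf128 shuffle used by the macros:
// lane 0 of four row accumulators is packed into one ymm register.
static inline __m256d gather_lane0_4x1( __m256d r0, __m256d r1,
                                        __m256d r2, __m256d r3 )
{
	__m256d lo01 = _mm256_unpacklo_pd( r0, r1 );             // { r0[0], r1[0], r0[2], r1[2] }
	__m256d lo23 = _mm256_unpacklo_pd( r2, r3 );             // { r2[0], r3[0], r2[2], r3[2] }
	return _mm256_insertf128_pd( lo01,
	                             _mm256_castpd256_pd128( lo23 ),
	                             1 );                        // { r0[0], r1[0], r2[0], r3[0] }
}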
+ + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm13", "ymm15", "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | 0 | 0 | 0 | ----> Mask vector( mask_1 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 0 | 0 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// kernel is using mask_1 which is set to -1, 0, 0, 0 so that the +// 1 element will be loaded. +// + int64_t const *mask_vec = mask_1; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 2*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 2*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 2*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vbroadcastsd(mem(rax, r13, 1), ymm13) + vfmadd231pd(ymm1, ymm12, ymm8) + vfmadd231pd(ymm1, ymm13, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vbroadcastsd(mem(rax, r13, 1), ymm13) + vfmadd231pd(ymm1, ymm12, ymm9) + vfmadd231pd(ymm1, ymm13, ymm11) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vbroadcastsd(mem(rax, r13, 1), ymm13) + vfmadd231pd(ymm1, ymm12, ymm8) + vfmadd231pd(ymm1, ymm13, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 2), rdx) // a_prefetch += 2*cs_a; + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 3*cs_a; + prefetch(0, mem(rdx, 4*8)) + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vbroadcastsd(mem(rax, r13, 1), ymm13) + vfmadd231pd(ymm1, ymm12, ymm9) + vfmadd231pd(ymm1, ymm13, ymm11) + + add(r9, rax) 
// a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + vaddpd(ymm5, ymm4, ymm4) + vaddpd(ymm7, ymm6, ymm6) + vaddpd(ymm9, ymm8, ymm8) + vaddpd(ymm11, ymm10, ymm10) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vbroadcastsd(mem(rax, r13, 1), ymm13) + vfmadd231pd(ymm1, ymm12, ymm8) + vfmadd231pd(ymm1, ymm13, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 1), rax) // load address of c + 1*rs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + lea(mem(rdx, rdi, 1), rbx) // load address of c + 3*rs_c; + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vmaskmovpd(mem(rax, 0*32), ymm15, ymm2) + vmaskmovpd(mem(rdx, 0*32), ymm15, ymm3) + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm14) + + vfmadd231pd(ymm0, ymm1, ymm4) + vfmadd231pd(ymm2, ymm1, ymm6) + vfmadd231pd(ymm3, ymm1, ymm8) + vfmadd231pd(ymm14, ymm1, ymm10) + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rax, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rdx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rbx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_4x1_TILE(4, 6, 8, 10) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm10, ymm15, mem(rcx, 0*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_4x1_TILE_BZ(4, 6, 8, 10) + jmp(.DDONE) // jump to end. 
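All of the n = 1 kernels in this file drive vmaskmovpd with mask_1 = { -1, 0, 0, 0 }, so only lane 0 of each 256-bit load or store touches memory and the 1-wide edge case never reads or writes past the single valid element of a row of C. The intrinsics sketch below shows the equivalent masked beta-update of one C element; it illustrates the semantics only, and masked_axpby_1col is a made-up helper name, not kernel code.

#include <immintrin.h>

// Semantics sketch only: masked load, beta*c + ab, masked store on lane 0.
static inline void masked_axpby_1col( const double* ab, double beta, double* c )
{
	__m256i m = _mm256_set_epi64x( 0, 0, 0, -1LL );               // enable low lane only
	__m256d x = _mm256_maskload_pd( ab, m );                      // ab[0], upper lanes zeroed
	__m256d y = _mm256_maskload_pd( c,  m );                      // c[0] loaded under the same mask
	__m256d r = _mm256_fmadd_pd( y, _mm256_set1_pd( beta ), x );  // beta*c + ab
	_mm256_maskstore_pd( c, m, r );                               // writes c[0] only
}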
+ + label(.DDONE) + vzeroupper() + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm13", "ymm14", "ymm15", "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | 0 | 0 | 0 | ----> Mask vector( mask_1 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 0 | 0 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// kernel is using mask_1 which is set to -1, 0, 0, 0 so that the +// 1 element will be loaded. +// + int64_t const *mask_vec = mask_1; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 2*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 2*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 2*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 3*8)) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vfmadd231pd(ymm1, ymm12, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vfmadd231pd(ymm1, ymm12, ymm11) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vfmadd231pd(ymm1, ymm12, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 2), rdx) // a_prefetch += 2*cs_a; + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 3*cs_a; + prefetch(0, mem(rdx, 4*8)) + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vfmadd231pd(ymm1, ymm12, ymm11) + + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + vaddpd(ymm9, ymm4, ymm4) + vaddpd(ymm10, ymm6, ymm6) + vaddpd(ymm11, ymm8, ymm8) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. 
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vfmadd231pd(ymm1, ymm12, ymm8) + + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 1), rax) // load address of c + 1*rs_c; + + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vmaskmovpd(mem(rax, 0*32), ymm15, ymm2) + vmaskmovpd(mem(rdx, 0*32), ymm15, ymm3) + + vfmadd231pd(ymm0, ymm1, ymm4) + vfmadd231pd(ymm2, ymm1, ymm6) + vfmadd231pd(ymm3, ymm1, ymm8) + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rax, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rdx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_3x1_TILE(4, 6, 8) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_3x1_TILE_BZ(4, 6, 8) + jmp(.DDONE) // jump to end. 
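The vucomisd/.DBETAZERO test above is what lets these kernels avoid reading C when beta is exactly zero, so uninitialized or NaN-filled C buffers are handled safely. A plain-C sketch of that branch for one n = 1 column of the tile is below; store_tile_col is illustrative, not BLIS code, and ab stands for the already alpha-scaled accumulators.

#include "blis.h"

// Illustration only: the two store paths selected by the beta == 0 check.
static void store_tile_col( double* c, inc_t rs_c, const double* ab,
                            dim_t m, double beta )
{
	if ( beta == 0.0 )                      // .DROWSTORBZ / .DCOLSTORBZ: C written, never read
		for ( dim_t i = 0; i < m; ++i ) c[ i*rs_c ] = ab[ i ];
	else                                    // .DROWSTORED / .DCOLSTORED: C loaded and scaled
		for ( dim_t i = 0; i < m; ++i ) c[ i*rs_c ] = beta * c[ i*rs_c ] + ab[ i ];
}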
+ + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm15", "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | 0 | 0 | 0 | ----> Mask vector( mask_1 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 0 | 0 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// kernel is using mask_1 which is set to -1, 0, 0, 0 so that the +// 1 element will be loaded. +// + int64_t const *mask_vec = mask_1; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 2*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 2*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 2*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 2*8)) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm9) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm10) + vbroadcastsd(mem(rax, r8, 1), ymm11) + vfmadd231pd(ymm9, ymm10, ymm7) + vfmadd231pd(ymm9, ymm11, ymm8) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 2*8)) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 2), rdx) // a_prefetch += 2*cs_a; + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 3*cs_a; + prefetch(0, mem(rdx, 4*8)) + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm9) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm10) + vbroadcastsd(mem(rax, r8, 1), ymm11) + vfmadd231pd(ymm9, ymm10, ymm7) + vfmadd231pd(ymm9, ymm11, ymm8) + + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + vaddpd(ymm7, ymm4, ymm4) + vaddpd(ymm8, ymm6, ymm6) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 1), rdx) // load address of c + 1*rs_c; + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vmaskmovpd(mem(rdx, 0*32), ymm15, ymm2) + + vfmadd231pd(ymm0, ymm1, ymm4) + vfmadd231pd(ymm2, ymm1, ymm6) + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rdx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_2x1_TILE(4, 6) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + label(.DCOLSTORBZ) + + C_TRANSPOSE_2x1_TILE_BZ(4, 6) + jmp(.DDONE) // jump to end. 
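+	// Note (clarifying comment): all four store paths (row-/column-stored C,
+	// with and without beta) converge on .DDONE below, which issues
+	// vzeroupper to avoid AVX-SSE transition penalties before returning.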
+ + label(.DDONE) + vzeroupper() + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", + "ymm11", "ymm12", "ymm15", "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x1 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | 0 | 0 | 0 | ----> Mask vector( mask_1 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 0 | 0 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// kernel is using mask_1 which is set to -1, 0, 0, 0 so that the +// 1 element will be loaded. +// + int64_t const *mask_vec = mask_1; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 1*8)) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm7) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm8) + vfmadd231pd(ymm7, ymm8, ymm5) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 1*8)) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 2), rdx) // a_prefetch += 2*cs_a; + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 3*cs_a; + prefetch(0, mem(rdx, 4*8)) + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm7) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm8) + vfmadd231pd(ymm7, ymm8, ymm5) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + vaddpd(ymm5, ymm4, ymm4) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
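+	// Note (clarifying comment): at this point lane 0 of ymm4 holds the full
+	// (not yet alpha-scaled) dot product of the 1x1 tile. The upper lanes
+	// remain zero because vmaskmovpd zeroes the lanes disabled by mask_1, so
+	// the extra fma lanes contribute nothing.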
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_1x1_TILE(4) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_1x1_TILE_BZ(4) + jmp(.DDONE) // jump to end. + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm10", + "ymm12", "ymm15", "memory" + ) +} diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx2.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx2.c index 7c2fd21e1e..257712dc79 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx2.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,20 +40,20 @@ /* rrr: - -------- ------ -------- - -------- ------ -------- - -------- += ------ ... -------- - -------- ------ -------- - -------- ------ : - -------- ------ : + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : rcr: - -------- | | | | -------- - -------- | | | | -------- - -------- += | | | | ... -------- - -------- | | | | -------- - -------- | | | | : - -------- | | | | : + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... 
-------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : Assumptions: - B is row-stored; @@ -69,12 +69,12 @@ cost of the in-register transpose). crr: - | | | | | | | | ------ -------- - | | | | | | | | ------ -------- - | | | | | | | | += ------ ... -------- - | | | | | | | | ------ -------- - | | | | | | | | ------ : - | | | | | | | | ------ : + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : */ // Prototype reference microkernels. @@ -115,9 +115,9 @@ void bli_dgemmsup_rv_haswell_asm_6x2 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -132,7 +132,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -168,31 +168,31 @@ void bli_dgemmsup_rv_haswell_asm_6x2 prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c label(.DPOSTPFETCH) // done prefetching c - - + + #if 1 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 prefetch(0, mem(rdx, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -200,19 +200,19 @@ void bli_dgemmsup_rv_haswell_asm_6x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) vfmadd231pd(xmm0, xmm3, xmm14) - + // ---------------------------------- iteration 1 #if 0 @@ -226,25 +226,25 @@ void bli_dgemmsup_rv_haswell_asm_6x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) vfmadd231pd(xmm0, xmm3, xmm14) - + // ---------------------------------- iteration 2 - + #if 1 prefetch(0, mem(rdx, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -252,18 +252,18 @@ void bli_dgemmsup_rv_haswell_asm_6x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) vfmadd231pd(xmm0, xmm3, xmm14) - + // ---------------------------------- iteration 3 @@ -278,43 +278,43 @@ void 
bli_dgemmsup_rv_haswell_asm_6x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) vfmadd231pd(xmm0, xmm3, xmm14) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -322,57 +322,57 @@ void bli_dgemmsup_rv_haswell_asm_6x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) vfmadd231pd(xmm0, xmm3, xmm14) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha vmulpd(xmm0, xmm6, xmm6) vmulpd(xmm0, xmm8, xmm8) vmulpd(xmm0, xmm10, xmm10) vmulpd(xmm0, xmm12, xmm12) vmulpd(xmm0, xmm14, xmm14) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -381,42 +381,42 @@ void bli_dgemmsup_rv_haswell_asm_6x2 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) vmovupd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) vmovupd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) vmovupd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) vmovupd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) vmovupd(xmm12, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm14) vmovupd(xmm14, mem(rcx, 0*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -452,40 +452,40 @@ void bli_dgemmsup_rv_haswell_asm_6x2 jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORBZ) // jump to column storage case - - + + label(.DROWSTORBZ) - - + + vmovupd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(xmm12, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(xmm14, mem(rcx, 0*32)) //add(rdi, rcx) @@ -517,13 +517,13 @@ void bli_dgemmsup_rv_haswell_asm_6x2 vmovupd(xmm1, mem(rdx, rsi, 1)) //lea(mem(rdx, rsi, 4), rdx) - - - - + + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -589,9 +589,9 @@ void bli_dgemmsup_rv_haswell_asm_5x2 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -606,7 +606,7 @@ void bli_dgemmsup_rv_haswell_asm_5x2 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -647,21 +647,21 @@ void bli_dgemmsup_rv_haswell_asm_5x2 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 - + #if 1 prefetch(0, mem(rdx, 5*8)) #endif @@ -673,17 +673,17 @@ void bli_dgemmsup_rv_haswell_asm_5x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) - + // ---------------------------------- iteration 1 #if 0 @@ -697,23 +697,23 @@ void bli_dgemmsup_rv_haswell_asm_5x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) - + // ---------------------------------- iteration 2 #if 1 prefetch(0, mem(rdx, r9, 2, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -721,16 +721,16 @@ void bli_dgemmsup_rv_haswell_asm_5x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) - + // ---------------------------------- iteration 3 @@ -745,41 +745,41 @@ void bli_dgemmsup_rv_haswell_asm_5x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) - - - + + + dec(rsi) // i -= 1; 
jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -787,54 +787,54 @@ void bli_dgemmsup_rv_haswell_asm_5x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha vmulpd(xmm0, xmm6, xmm6) vmulpd(xmm0, xmm8, xmm8) vmulpd(xmm0, xmm10, xmm10) vmulpd(xmm0, xmm12, xmm12) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; //lea(mem(rsi, rsi, 2), rax) // r13 = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -843,37 +843,27 @@ void bli_dgemmsup_rv_haswell_asm_5x2 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + lea(mem(rcx, rdi, 1), rax) // load address of c + 1*rs_c; + lea(mem(rcx, rdi, 2), rbx) // load address of c + 2*rs_c; + lea(mem(rbx, rdi, 1), r8) // load address of c + 2*rs_c; + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) - vmovupd(xmm4, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) - vmovupd(xmm6, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) - vmovupd(xmm8, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) - vmovupd(xmm10, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) - vmovupd(xmm12, mem(rcx, 0*32)) - //add(rdi, rcx) - + vfmadd231pd(mem(rax, 0*32), xmm3, xmm6) + vfmadd231pd(mem(rbx, 0*32), xmm3, xmm8) + vfmadd231pd(mem(r8, 0*32), xmm3, xmm10) + vfmadd231pd(mem(rdx, 0*32), xmm3, xmm12) + vmovupd(xmm4, mem(rcx, 0*32)) + vmovupd(xmm6, mem(rax, 0*32)) + vmovupd(xmm8, mem(rbx, 0*32)) + vmovupd(xmm10, mem(r8, 0*32)) + vmovupd(xmm12, mem(rdx, 0*32)) + + jmp(.DDONE) // jump to end. @@ -908,37 +898,37 @@ void bli_dgemmsup_rv_haswell_asm_5x2 jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORBZ) // jump to column storage case - - + + label(.DROWSTORBZ) - - + + vmovupd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(xmm12, mem(rcx, 0*32)) //add(rdi, rcx) @@ -948,7 +938,7 @@ void bli_dgemmsup_rv_haswell_asm_5x2 label(.DCOLSTORBZ) - + // begin I/O on columns 0-1 vunpcklpd(xmm6, xmm4, xmm0) vunpckhpd(xmm6, xmm4, xmm1) @@ -968,13 +958,13 @@ void bli_dgemmsup_rv_haswell_asm_5x2 vmovhpd(xmm0, mem(rdx, rsi, 1)) //lea(mem(rdx, rsi, 4), rdx) - - - - + + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -1040,9 +1030,9 @@ void bli_dgemmsup_rv_haswell_asm_4x2 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -1057,7 +1047,7 @@ void bli_dgemmsup_rv_haswell_asm_4x2 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -1091,31 +1081,31 @@ void bli_dgemmsup_rv_haswell_asm_4x2 prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c label(.DPOSTPFETCH) // done prefetching c - + #if 1 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 prefetch(0, mem(rdx, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -1123,14 +1113,14 @@ void bli_dgemmsup_rv_haswell_asm_4x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + // ---------------------------------- iteration 1 #if 0 @@ -1144,20 +1134,20 @@ void bli_dgemmsup_rv_haswell_asm_4x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + // ---------------------------------- iteration 2 #if 1 prefetch(0, mem(rdx, r9, 2, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -1165,13 +1155,13 @@ void bli_dgemmsup_rv_haswell_asm_4x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + // ---------------------------------- iteration 3 @@ -1186,89 +1176,89 @@ void bli_dgemmsup_rv_haswell_asm_4x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - - - + + + dec(rsi) // i -= 1; 
jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha vmulpd(xmm0, xmm6, xmm6) vmulpd(xmm0, xmm8, xmm8) vmulpd(xmm0, xmm10, xmm10) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -1277,32 +1267,25 @@ void bli_dgemmsup_rv_haswell_asm_4x2 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + lea(mem(rcx, rdi, 1), rax) // load address of c + 2*rs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + lea(mem(rdx, rdi, 1), rbx) // load address of c + 3*rs_c; + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vfmadd231pd(mem(rax, 0*32), xmm3, xmm6) + vfmadd231pd(mem(rdx, 0*32), xmm3, xmm8) + vfmadd231pd(mem(rbx, 0*32), xmm3, xmm10) + vmovupd(xmm4, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) - vmovupd(xmm6, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) - vmovupd(xmm8, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) - vmovupd(xmm10, mem(rcx, 0*32)) - //add(rdi, rcx) - - + vmovupd(xmm6, mem(rax, 0*32)) + vmovupd(xmm8, mem(rdx, 0*32)) + vmovupd(xmm10, mem(rbx, 0*32)) + + jmp(.DDONE) // jump to end. @@ -1328,32 +1311,32 @@ void bli_dgemmsup_rv_haswell_asm_4x2 jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORBZ) // jump to column storage case - - + + label(.DROWSTORBZ) - - + + vmovupd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(xmm10, mem(rcx, 0*32)) //add(rdi, rcx) @@ -1364,7 +1347,7 @@ void bli_dgemmsup_rv_haswell_asm_4x2 label(.DCOLSTORBZ) - + // begin I/O on columns 0-1 vunpcklpd(xmm6, xmm4, xmm0) vunpckhpd(xmm6, xmm4, xmm1) @@ -1377,13 +1360,13 @@ void bli_dgemmsup_rv_haswell_asm_4x2 vmovupd(ymm6, mem(rcx, rsi, 1)) //lea(mem(rcx, rsi, 4), rcx) - - - + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -1449,9 +1432,9 @@ void bli_dgemmsup_rv_haswell_asm_3x2 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -1466,7 +1449,7 @@ void bli_dgemmsup_rv_haswell_asm_3x2 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -1499,31 +1482,31 @@ void bli_dgemmsup_rv_haswell_asm_3x2 prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c label(.DPOSTPFETCH) // done prefetching c - - + + #if 1 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 prefetch(0, mem(rdx, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -1531,12 +1514,12 @@ void bli_dgemmsup_rv_haswell_asm_3x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm8) - - + + // ---------------------------------- iteration 1 #if 0 @@ -1548,20 +1531,20 @@ void bli_dgemmsup_rv_haswell_asm_3x2 vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(xmm0, xmm2, xmm4) - vfmadd231pd(xmm0, xmm3, xmm6) - + vfmadd231pd(xmm0, xmm2, xmm9) + vfmadd231pd(xmm0, xmm3, xmm10) + vbroadcastsd(mem(rax, r8, 2), ymm2) add(r9, rax) // a += cs_a; - vfmadd231pd(xmm0, xmm2, xmm8) - + vfmadd231pd(xmm0, xmm2, xmm11) + // ---------------------------------- iteration 2 #if 1 prefetch(0, mem(rdx, r9, 2, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -1569,11 +1552,11 @@ void bli_dgemmsup_rv_haswell_asm_3x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm8) - + // ---------------------------------- iteration 3 @@ -1586,38 +1569,37 @@ void bli_dgemmsup_rv_haswell_asm_3x2 vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(xmm0, xmm2, xmm4) - vfmadd231pd(xmm0, xmm3, xmm6) - + vfmadd231pd(xmm0, xmm2, xmm9) + vfmadd231pd(xmm0, xmm3, xmm10) + vbroadcastsd(mem(rax, r8, 2), ymm2) add(r9, rax) // a += cs_a; - vfmadd231pd(xmm0, xmm2, xmm8) - - - + vfmadd231pd(xmm0, xmm2, xmm11) + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. 
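+	// Note (clarifying comment): iterations 1 and 3 of the unrolled loop
+	// accumulate into xmm9/xmm10/xmm11 rather than xmm4/xmm6/xmm8 to shorten
+	// the FMA dependency chain; the two accumulator sets are summed below,
+	// before the k_left loop and the epilogue.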
- - - - - - + + vaddpd(xmm9, xmm4, xmm4) + vaddpd(xmm10, xmm6, xmm6) + vaddpd(xmm11, xmm8, xmm8) + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -1625,78 +1607,73 @@ void bli_dgemmsup_rv_haswell_asm_3x2 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm8) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha vmulpd(xmm0, xmm6, xmm6) vmulpd(xmm0, xmm8, xmm8) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - + label(.DROWSTORED) - - - vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) - vmovupd(xmm4, mem(rcx, 0*32)) - add(rdi, rcx) + lea(mem(rcx, rdi, 1), rbx) // load address of c + 1*rs_c; - vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) - vmovupd(xmm6, mem(rcx, 0*32)) - add(rdi, rcx) + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vfmadd231pd(mem(rbx, 0*32), xmm3, xmm6) + vfmadd231pd(mem(rdx, 0*32), xmm3, xmm8) + + vmovupd(xmm4, mem(rcx, 0*32)) + vmovupd(xmm6, mem(rbx, 0*32)) + vmovupd(xmm8, mem(rdx, 0*32)) - vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) - vmovupd(xmm8, mem(rcx, 0*32)) - //add(rdi, rcx) - - jmp(.DDONE) // jump to end. - + label(.DCOLSTORED) @@ -1725,26 +1702,26 @@ void bli_dgemmsup_rv_haswell_asm_3x2 vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) vmovsd(xmm12, mem(rdx )) vmovsd(xmm13, mem(rdx, rsi, 1)) - + //lea(mem(rdx, rsi, 4), rdx) jmp(.DDONE) // jump to end. - - + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) @@ -1755,8 +1732,8 @@ void bli_dgemmsup_rv_haswell_asm_3x2 vmovupd(xmm8, mem(rcx, 0*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -1784,12 +1761,12 @@ void bli_dgemmsup_rv_haswell_asm_3x2 //lea(mem(rdx, rsi, 4), rdx) - - - + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -1856,9 +1833,9 @@ void bli_dgemmsup_rv_haswell_asm_2x2 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. 
mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -1873,7 +1850,7 @@ void bli_dgemmsup_rv_haswell_asm_2x2 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -1905,72 +1882,72 @@ void bli_dgemmsup_rv_haswell_asm_2x2 prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c label(.DPOSTPFETCH) // done prefetching c - - + + #if 1 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 prefetch(0, mem(rdx, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + // ---------------------------------- iteration 1 #if 0 prefetch(0, mem(rdx, r9, 1, 5*8)) #endif - vmovupd(mem(rbx, 0*32), xmm0) + vmovupd(mem(rbx, 0*32), xmm9) add(r10, rbx) // b += rs_b; - - vbroadcastsd(mem(rax ), ymm2) - vbroadcastsd(mem(rax, r8, 1), ymm3) + + vbroadcastsd(mem(rax ), ymm10) + vbroadcastsd(mem(rax, r8, 1), ymm11) add(r9, rax) // a += cs_a; - vfmadd231pd(xmm0, xmm2, xmm4) - vfmadd231pd(xmm0, xmm3, xmm6) - - + vfmadd231pd(xmm9, xmm10, xmm7) + vfmadd231pd(xmm9, xmm11, xmm8) + + // ---------------------------------- iteration 2 #if 1 prefetch(0, mem(rdx, r9, 2, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + // ---------------------------------- iteration 3 @@ -1978,84 +1955,82 @@ void bli_dgemmsup_rv_haswell_asm_2x2 lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; #endif - vmovupd(mem(rbx, 0*32), xmm0) + vmovupd(mem(rbx, 0*32), xmm9) add(r10, rbx) // b += rs_b; - - vbroadcastsd(mem(rax ), ymm2) - vbroadcastsd(mem(rax, r8, 1), ymm3) + + vbroadcastsd(mem(rax ), ymm10) + vbroadcastsd(mem(rax, r8, 1), ymm11) add(r9, rax) // a += cs_a; - vfmadd231pd(xmm0, xmm2, xmm4) - vfmadd231pd(xmm0, xmm3, xmm6) - - - + vfmadd231pd(xmm9, xmm10, xmm7) + vfmadd231pd(xmm9, xmm11, xmm8) + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + vaddpd(xmm7, xmm4, xmm4) + vaddpd(xmm8, xmm6, xmm6) + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. 
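+	// Note (clarifying comment): xmm4 and xmm6 now hold rows 0 and 1 of the
+	// 2x2 product (the xmm7/xmm8 accumulators used by the odd unrolled
+	// iterations were folded in before the k_left loop); they are scaled by
+	// alpha next.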
- - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha vmulpd(xmm0, xmm6, xmm6) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -2064,22 +2039,20 @@ void bli_dgemmsup_rv_haswell_asm_2x2 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + lea(mem(rcx, rdi, 1), rbx) // load address of c + 1*rs_c; + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vfmadd231pd(mem(rbx, 0*32), xmm3, xmm6) + vmovupd(xmm4, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) - vmovupd(xmm6, mem(rcx, 0*32)) - //add(rdi, rcx) - - + vmovupd(xmm6, mem(rbx, 0*32)) + + jmp(.DDONE) // jump to end. @@ -2099,34 +2072,34 @@ void bli_dgemmsup_rv_haswell_asm_2x2 jmp(.DDONE) // jump to end. - - - + + + label(.DBETAZERO) cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - - + + label(.DROWSTORBZ) - - + + vmovupd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(xmm6, mem(rcx, 0*32)) //add(rdi, rcx) - + jmp(.DDONE) // jump to end. label(.DCOLSTORBZ) - + vunpcklpd(xmm6, xmm4, xmm0) vunpckhpd(xmm6, xmm4, xmm1) @@ -2135,13 +2108,13 @@ void bli_dgemmsup_rv_haswell_asm_2x2 vmovupd(xmm1, mem(rcx, rsi, 1)) //lea(mem(rcx, rsi, 4), rcx) - - - - + + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -2169,7 +2142,7 @@ void bli_dgemmsup_rv_haswell_asm_2x2 "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", "ymm0", "ymm2", "ymm3", - "memory" + "ymm10", "ymm11", "memory" ) } @@ -2207,9 +2180,9 @@ void bli_dgemmsup_rv_haswell_asm_1x2 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -2224,7 +2197,7 @@ void bli_dgemmsup_rv_haswell_asm_1x2 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -2255,31 +2228,31 @@ void bli_dgemmsup_rv_haswell_asm_1x2 prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c label(.DPOSTPFETCH) // done prefetching c - - + + #if 1 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
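+	// Note (clarifying comment): the main loop below is unrolled by four.
+	// Even iterations accumulate into xmm4 while the odd ones accumulate into
+	// xmm5 (via the xmm6/xmm7 operands); the two partial sums are added
+	// together once the loop exits.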
- - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 prefetch(0, mem(rdx, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -2287,34 +2260,34 @@ void bli_dgemmsup_rv_haswell_asm_1x2 add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm4) - + // ---------------------------------- iteration 1 #if 0 prefetch(0, mem(rdx, r9, 1, 5*8)) #endif - vmovupd(mem(rbx, 0*32), xmm0) + vmovupd(mem(rbx, 0*32), xmm6) add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax ), ymm7) add(r9, rax) // a += cs_a; - vfmadd231pd(xmm0, xmm2, xmm4) - - + vfmadd231pd(xmm6, xmm7, xmm5) + + // ---------------------------------- iteration 2 #if 1 prefetch(0, mem(rdx, r9, 2, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastsd(mem(rax ), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm4) - + // ---------------------------------- iteration 3 @@ -2322,98 +2295,95 @@ void bli_dgemmsup_rv_haswell_asm_1x2 lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; #endif - vmovupd(mem(rbx, 0*32), xmm0) + vmovupd(mem(rbx, 0*32), xmm6) add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax ), ymm7) add(r9, rax) // a += cs_a; - vfmadd231pd(xmm0, xmm2, xmm4) - - - + vfmadd231pd(xmm6, xmm7, xmm5) + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + vaddpd(xmm5, xmm4, xmm4) + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastsd(mem(rax ), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm4) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) vmovupd(xmm4, mem(rcx, 0*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -2428,48 +2398,48 @@ void bli_dgemmsup_rv_haswell_asm_1x2 vmovlpd(xmm0, mem(rcx )) vmovhpd(xmm0, mem(rcx, rsi, 1)) - + //lea(mem(rcx, rsi, 4), rcx) jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - - + + label(.DROWSTORBZ) - - + + vmovupd(xmm4, mem(rcx, 0*32)) //add(rdi, rcx) jmp(.DDONE) // jump to end. 
- + label(.DCOLSTORBZ) - + // begin I/O on columns 0-1 vmovlpd(xmm4, mem(rcx )) vmovhpd(xmm4, mem(rcx, rsi, 1)) //lea(mem(rcx, rsi, 4), rcx) - - - - + + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -2497,7 +2467,7 @@ void bli_dgemmsup_rv_haswell_asm_1x2 "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", "ymm0", "ymm2", "ymm3", - "memory" + "ymm7", "memory" ) } diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx3.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx3.c new file mode 100644 index 0000000000..3661ddf591 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx3.c @@ -0,0 +1,2078 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + + +#define C_TRANSPOSE_5x3_TILE(R1, R2, R3, R4, R5)\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm3, ymm1, ymm(R4))\ +\ + vbroadcastsd(mem(rbx), ymm3)\ +\ + vfmadd231pd(mem(rcx ), ymm3, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm(R4))\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ +\ + /*Transposing 4x1 tile*/ \ + vmovlpd(mem(rdx ), xmm0, xmm0)\ + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0)\ + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + vfmadd213pd(ymm(R5), ymm3, ymm0)\ + vextractf128(imm(1), ymm0, xmm1)\ + vmovlpd(xmm0, mem(rdx ))\ + vmovhpd(xmm0, mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_5x3_TILE_BZ(R1, R2, R3, R4, R5)\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm3, ymm1, ymm(R4))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ +\ + /*Transposing 1x3 tile*/ \ + vextractf128(imm(1), ymm(R5), xmm1)\ + vmovlpd(xmm(R5), mem(rdx ))\ + vmovhpd(xmm(R5), mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2)) + + +#define C_TRANSPOSE_4x3_TILE(R1, R2, R3, R4)\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ +\ + vbroadcastsd(mem(rbx), ymm3)\ +\ + vfmadd231pd(mem(rcx ), ymm3, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm(R3))\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2)) + +#define C_TRANSPOSE_4x3_TILE_BZ(R1, R2, R3, R4)\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2)) + +#define C_TRANSPOSE_3x3_TILE(R1, R2, R3)\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(10), ymm(R3), ymm2)\ + vunpckhpd(ymm(10), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ +\ + vextractf128(imm(0x1), ymm(R1), xmm12)\ + vextractf128(imm(0x1), ymm(R2), xmm13)\ + vextractf128(imm(0x1), ymm(R3), xmm14)\ +\ + vbroadcastsd(mem(rbx), ymm3)\ +\ + 
vfmadd231pd(mem(rcx ), xmm3, xmm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm(R3))\ + vmovupd(xmm(R1), mem(rcx ))\ + vmovupd(xmm(R2), mem(rcx, rsi, 1))\ + vmovupd(xmm(R3), mem(rcx, rsi, 2))\ +\ + /*Transposing 1x3 tile*/ \ + vfmadd231sd(mem(rdx ), xmm3, xmm12)\ + vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13)\ + vfmadd231sd(mem(rdx, rsi, 2), xmm3, xmm14)\ + vmovsd(xmm12, mem(rdx ))\ + vmovsd(xmm13, mem(rdx, rsi, 1))\ + vmovsd(xmm14, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_3x3_TILE_BZ(R1, R2, R3)\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(10), ymm(R3), ymm2)\ + vunpckhpd(ymm(10), ymm(R3), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ +\ + vextractf128(imm(0x1), ymm(R1), xmm12)\ + vextractf128(imm(0x1), ymm(R2), xmm13)\ + vextractf128(imm(0x1), ymm(R3), xmm14)\ +\ + vmovupd(xmm(R1), mem(rcx ))\ + vmovupd(xmm(R2), mem(rcx, rsi, 1))\ + vmovupd(xmm(R3), mem(rcx, rsi, 2))\ +\ + /*Transposing 1x3 tile*/ \ + vmovlpd(xmm(12), mem(rdx ))\ + vmovlpd(xmm(13), mem(rdx, rsi, 1))\ + vmovlpd(xmm(14), mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_2x3_TILE(R1, R2)\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ +\ + vbroadcastsd(mem(rbx), ymm3)\ + vfmadd231pd(mem(rcx ), xmm3, xmm0)\ + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1)\ + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2)\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, mem(rcx, rsi, 2)) + + +#define C_TRANSPOSE_2x3_TILE_BZ(R1, R2)\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ +\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, mem(rcx, rsi, 2)) + +#define C_TRANSPOSE_1x3_TILE(R1)\ + vmovlpd(mem(rcx ), xmm0, xmm0)\ + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0)\ + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + vbroadcastsd(mem(rbx), ymm3)\ + vfmadd213pd(ymm(R1), ymm3, ymm0)\ +\ + vextractf128(imm(1), ymm0, xmm1)\ + vmovlpd(xmm0, mem(rcx ))\ + vmovhpd(xmm0, mem(rcx, rsi, 1))\ + vmovlpd(xmm1, mem(rcx, rsi, 2)) + +#define C_TRANSPOSE_1x3_TILE_BZ(R1)\ + vextractf128(imm(1), ymm(R1), xmm1)\ + vmovlpd(xmm(R1), mem(rcx ))\ + vmovhpd(xmm(R1), mem(rcx, rsi, 1))\ + vmovlpd(xmm1, mem(rcx, rsi, 2)) + +static const int64_t mask_3[4] = {-1, -1, -1, 0}; + +void bli_dgemmsup_rv_haswell_asm_5x3 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. 
+// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. +// + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + lea(mem(r9, r9, 2), r15) // r15 = 3*cs_a + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 2*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 2*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 2*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
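
(Illustrative aside, not part of the kernel: the mask_3 pattern loaded into ymm15 above is what makes the vmaskmovpd instructions in the loops below touch only the three valid columns of B and C. A minimal, standalone C sketch of the same masked load/store behavior, using AVX intrinsics and made-up buffer names, is shown here; build with -mavx.)

#include <immintrin.h>
#include <stdio.h>

int main(void)
{
    // Same pattern as mask_3: -1 (all bits set) enables a lane, 0 disables it.
    __m256i mask = _mm256_setr_epi64x(-1, -1, -1, 0);

    double b[3] = { 1.0, 2.0, 3.0 };      // only 3 valid elements, like an n=3 edge case
    double c[4] = { 9.0, 9.0, 9.0, 9.0 };

    // Loads b[0..2]; the disabled fourth lane becomes 0.0 and the memory
    // past b[2] is not accessed (this is what vmaskmovpd does in the kernel).
    __m256d v = _mm256_maskload_pd(b, mask);

    // Masked store writes only the first three lanes of c; c[3] is left alone.
    _mm256_maskstore_pd(c, mask, v);

    printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]);   // prints: 1 2 3 9
    return 0;
}
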
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 5*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm14) + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm1, ymm14, ymm8) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm1, ymm3, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm14) + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm1, ymm14, ymm9) + vfmadd231pd(ymm1, ymm2, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm1, ymm3, ymm13) + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm14) + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm1, ymm14, ymm8) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm1, ymm3, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + prefetch(0, mem(rdx, r15, 1, 5*8)) // a_prefetch += 3*cs_a; + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm14) + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm1, ymm14, ymm9) + vfmadd231pd(ymm1, ymm2, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm1, ymm3, ymm13) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + vaddpd(ymm5, ymm4, ymm4) + vaddpd(ymm7, ymm6, ymm6) + vaddpd(ymm9, ymm8, ymm8) + vaddpd(ymm11, ymm10, ymm10) + vaddpd(ymm13, ymm12, ymm12) + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
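
(Illustrative aside: the unrolled loop above alternates between two accumulator banks, ymm4/6/8/10/12 for iterations 0 and 2 and ymm5/7/9/11/13 for iterations 1 and 3, then folds them together with vaddpd before the edge loop, which accumulates only into the first bank. A scalar sketch of that structure, with invented names, follows.)

#include <stddef.h>

// Sum a[0..k-1] the way the kernel structures its k loop: unroll by 4 with
// two independent accumulator "banks", reduce once, then a remainder loop
// that touches only bank 0 (mirroring .DLOOPKITER / vaddpd / .DLOOPKLEFT).
static double sum_like_kernel(const double *a, size_t k)
{
    double acc0 = 0.0, acc1 = 0.0;
    size_t k_iter = k / 4, k_left = k % 4, i = 0;

    for (size_t it = 0; it < k_iter; ++it, i += 4)
    {
        acc0 += a[i + 0];   // iteration 0 -> bank 0
        acc1 += a[i + 1];   // iteration 1 -> bank 1
        acc0 += a[i + 2];   // iteration 2 -> bank 0
        acc1 += a[i + 3];   // iteration 3 -> bank 1
    }

    acc0 += acc1;           // the vaddpd reduction after the main loop

    for (size_t it = 0; it < k_left; ++it, ++i)
        acc0 += a[i];       // edge loop accumulates into bank 0 only

    return acc0;
}
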
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm14) + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm1, ymm14, ymm8) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm3) + vfmadd231pd(ymm1, ymm3, ymm12) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 1), rax) // load address of c + 1*rs_c; + lea(mem(rcx, rdi, 2), rbx) // load address of c + 2*rs_c; + lea(mem(rbx, rdi, 1), r8) // load address of c + 3*rs_c; + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vmaskmovpd(mem(rax, 0*32), ymm15, ymm2) + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm3) + vmaskmovpd(mem(r8, 0*32), ymm15, ymm5) + vmaskmovpd(mem(rdx, 0*32), ymm15, ymm7) + + vfmadd231pd(ymm0, ymm1, ymm4) + vfmadd231pd(ymm2, ymm1, ymm6) + vfmadd231pd(ymm3, ymm1, ymm8) + vfmadd231pd(ymm5, ymm1, ymm10) + vfmadd231pd(ymm7, ymm1, ymm12) + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rax, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rbx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(r8, 0*32)) + vmaskmovpd(ymm12, ymm15, mem(rdx, 0*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + C_TRANSPOSE_5x3_TILE(4, 6, 8, 10, 12) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm10, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm12, ymm15, mem(rcx, 0*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_5x3_TILE_BZ(4, 6, 8, 10, 12) + jmp(.DDONE) // jump to end. 
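
(Illustrative aside: the epilogue above picks one of four store paths. The cmp(imm(8), rdi) test asks whether rs_c*sizeof(double) equals 8, i.e. whether C is column-stored with rs_c == 1, in which case the accumulators are transposed in registers before being written; the vucomisd test selects the beta == 0 variants that never read C. A small C sketch of that dispatch, with hypothetical names, follows.)

// Which of the four epilogues the kernel falls into (names echo the asm labels).
typedef enum { ROWSTORED, COLSTORED, ROWSTORBZ, COLSTORBZ } store_path_t;

static store_path_t choose_store_path(long rs_c, double beta)
{
    int col_stored = (rs_c == 1);   // the asm checks 8*rs_c == 8

    if (beta != 0.0)
        return col_stored ? COLSTORED : ROWSTORED;   // .DCOLSTORED / .DROWSTORED
    else
        return col_stored ? COLSTORBZ : ROWSTORBZ;   // .DCOLSTORBZ / .DROWSTORBZ
}
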
+ + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm13", "ymm14", "ymm15", "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_4x3 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. +// + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 2*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 2*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 2*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vbroadcastsd(mem(rax, r13, 1), ymm13) + vfmadd231pd(ymm1, ymm12, ymm8) + vfmadd231pd(ymm1, ymm13, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vbroadcastsd(mem(rax, r13, 1), ymm13) + vfmadd231pd(ymm1, ymm12, ymm9) + vfmadd231pd(ymm1, ymm13, ymm11) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vbroadcastsd(mem(rax, r13, 1), ymm13) + vfmadd231pd(ymm1, ymm12, ymm8) + vfmadd231pd(ymm1, ymm13, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 2), rdx) // a_prefetch += 2*cs_a; + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 3*cs_a; + prefetch(0, mem(rdx, 4*8)) + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vbroadcastsd(mem(rax, r13, 1), ymm13) + vfmadd231pd(ymm1, ymm12, ymm9) + vfmadd231pd(ymm1, ymm13, ymm11) + + add(r9, 
rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + vaddpd(ymm5, ymm4, ymm4) + vaddpd(ymm7, ymm6, ymm6) + vaddpd(ymm9, ymm8, ymm8) + vaddpd(ymm11, ymm10, ymm10) + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vbroadcastsd(mem(rax, r13, 1), ymm13) + vfmadd231pd(ymm1, ymm12, ymm8) + vfmadd231pd(ymm1, ymm13, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + lea(mem(rcx, rdi, 1), rax) // load address of c + 2*rs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + lea(mem(rdx, rdi, 1), rbx) // load address of c + 3*rs_c; + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vmaskmovpd(mem(rax, 0*32), ymm15, ymm2) + vmaskmovpd(mem(rdx, 0*32), ymm15, ymm3) + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm14) + + vfmadd231pd(ymm0, ymm1, ymm4) + vfmadd231pd(ymm2, ymm1, ymm6) + vfmadd231pd(ymm3, ymm1, ymm8) + vfmadd231pd(ymm14, ymm1, ymm10) + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rax, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rdx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rbx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_4x3_TILE(4, 6, 8, 10) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm10, ymm15, mem(rcx, 0*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_4x3_TILE_BZ(4, 6, 8, 10) + jmp(.DDONE) // jump to end. 
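
(Illustrative aside: the C_TRANSPOSE_* macros used by the column-stored paths are the classic AVX in-register transpose built from vunpcklpd/vunpckhpd plus vinsertf128/vperm2f128; only three of the four resulting columns are kept because these kernels handle an n = 3 remainder. A self-contained intrinsics sketch of the full 4x4 variant, with invented names, is shown here; the macros' vinsertf128 of the low halves is equivalent to the 0x20 permute below.)

#include <immintrin.h>

// Transpose four rows of 4 doubles held in ymm registers into four columns.
static void transpose_4x4_pd(__m256d r0, __m256d r1, __m256d r2, __m256d r3,
                             __m256d *c0, __m256d *c1, __m256d *c2, __m256d *c3)
{
    __m256d t0 = _mm256_unpacklo_pd(r0, r1);      // r0[0] r1[0] r0[2] r1[2]
    __m256d t1 = _mm256_unpackhi_pd(r0, r1);      // r0[1] r1[1] r0[3] r1[3]
    __m256d t2 = _mm256_unpacklo_pd(r2, r3);
    __m256d t3 = _mm256_unpackhi_pd(r2, r3);

    *c0 = _mm256_permute2f128_pd(t0, t2, 0x20);   // r0[0] r1[0] r2[0] r3[0]
    *c1 = _mm256_permute2f128_pd(t1, t3, 0x20);   // r0[1] r1[1] r2[1] r3[1]
    *c2 = _mm256_permute2f128_pd(t0, t2, 0x31);   // r0[2] r1[2] r2[2] r3[2]
    *c3 = _mm256_permute2f128_pd(t1, t3, 0x31);   // r0[3] r1[3] r2[3] r3[3]
}
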
+ + label(.DDONE) + vzeroupper() + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm12", + "ymm11", "ymm13", "ymm14", "ymm15", "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x3 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. +// + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 2*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 2*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 2*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 3*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vfmadd231pd(ymm1, ymm12, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vfmadd231pd(ymm1, ymm12, ymm11) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vfmadd231pd(ymm1, ymm12, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 2), rdx) // a_prefetch += 2*cs_a; + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 3*cs_a; + prefetch(0, mem(rdx, 4*8)) + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vfmadd231pd(ymm1, ymm12, ymm11) + + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + vaddpd(ymm9, ymm4, ymm4) + vaddpd(ymm10, ymm6, ymm6) + vaddpd(ymm11, ymm8, ymm8) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. 
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm1, ymm2, ymm8) + + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + lea(mem(rcx, rdi, 1), rbx) // load address of c + 1*rs_c; + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm2) + vmaskmovpd(mem(rdx, 0*32), ymm15, ymm3) + + vfmadd231pd(ymm0, ymm1, ymm4) + vfmadd231pd(ymm2, ymm1, ymm6) + vfmadd231pd(ymm3, ymm1, ymm8) + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rbx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rdx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_3x3_TILE(4, 6, 8) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmaskmovpd(ymm8, ymm15, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_3x3_TILE_BZ(4, 6, 8) + jmp(.DDONE) // jump to end. 
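
(Illustrative aside: stripped of the unrolling, masking and transposes, each of these Mx3 kernels is expected to compute the usual sup update C := beta*C + alpha*A*B on an m x 3 tile, with the beta == 0 paths never reading C. A plain C reference of that contract, with generic stride parameters, follows.)

// Reference form of the m x 3 computation performed by these kernels.
static void ref_gemmsup_mx3(long m, long k,
                            double alpha, const double *a, long rs_a, long cs_a,
                            const double *b, long rs_b, long cs_b,
                            double beta, double *c, long rs_c, long cs_c)
{
    for (long i = 0; i < m; ++i)
        for (long j = 0; j < 3; ++j)
        {
            double ab = 0.0;
            for (long l = 0; l < k; ++l)
                ab += a[i * rs_a + l * cs_a] * b[l * rs_b + j * cs_b];

            if (beta == 0.0)
                c[i * rs_c + j * cs_c] = alpha * ab;   // the *_STORBZ paths
            else
                c[i * rs_c + j * cs_c] = beta * c[i * rs_c + j * cs_c] + alpha * ab;
        }
}
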
+ + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm15", "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x3 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. +// + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 2*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 2*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 2*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 2*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm9) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm10) + vbroadcastsd(mem(rax, r8, 1), ymm11) + vfmadd231pd(ymm9, ymm10, ymm7) + vfmadd231pd(ymm9, ymm11, ymm8) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 2*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + vfmadd231pd(ymm1, ymm3, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 2), rdx) // a_prefetch += 2*cs_a; + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 3*cs_a; + prefetch(0, mem(rdx, 4*8)) + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm9) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm10) + vbroadcastsd(mem(rax, r8, 1), ymm11) + vfmadd231pd(ymm9, ymm10, ymm7) + vfmadd231pd(ymm9, ymm11, ymm8) + + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + vaddpd(ymm7, ymm4, ymm4) + vaddpd(ymm8, ymm6, ymm6) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
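
(Illustrative aside: the prefetch(0, ...) lines in the main loop above keep fetching A roughly 16 columns ahead of the column currently being consumed, with rdx initialized to a + 16*cs_a and advanced by 4*cs_a per unrolled group. A rough C analogue using the GCC/Clang __builtin_prefetch builtin, with a made-up lookahead constant, is:)

// Walk the columns of A while software-prefetching a fixed distance ahead,
// in the spirit of the prefetch(0, mem(rdx, ...)) instructions above.
static void consume_a_with_prefetch(const double *a, long cs_a, long k)
{
    const long lookahead = 16;   // columns of lookahead; illustrative only

    for (long l = 0; l < k; ++l)
    {
        if (l + lookahead < k)
            __builtin_prefetch(a + (l + lookahead) * cs_a, 0 /*read*/, 3 /*high locality*/);

        // ... the FMA work for column a + l*cs_a would go here ...
        (void)a[l * cs_a];
    }
}
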
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm1, ymm2, ymm6) + + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 1), rdx) // load address of c + 1*rs_c; + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vmaskmovpd(mem(rdx, 0*32), ymm15, ymm2) + + vfmadd231pd(ymm0, ymm1, ymm4) + vfmadd231pd(ymm2, ymm1, ymm6) + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rdx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_2x3_TILE(4, 6) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + add(rdi, rcx) + + vmaskmovpd(ymm6, ymm15, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + label(.DCOLSTORBZ) + + C_TRANSPOSE_2x3_TILE_BZ(4, 6) + jmp(.DDONE) // jump to end. 
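
(Illustrative aside: the separate *_BZ paths above exist because BLAS semantics require that C not be read at all when beta is zero; besides saving loads, this keeps NaN or uninitialized contents of C from contaminating the result through 0.0 * NaN. A sketch of the per-element rule, with a hypothetical helper name, is:)

// Per-element update rule that the two branches implement.
static void update_c_element(double *cij, double alpha_ab, double beta)
{
    if (beta == 0.0)
        *cij = alpha_ab;                   // store-only (.D*STORBZ), C is never read
    else
        *cij = beta * (*cij) + alpha_ab;   // the vfmadd231pd path
}
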
+ + label(.DDONE) + vzeroupper() + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm15", "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x3 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. +// + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 1*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm7) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm8) + vfmadd231pd(ymm7, ymm8, ymm5) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 1*8)) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 2), rdx) // a_prefetch += 2*cs_a; + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 3*cs_a; + prefetch(0, mem(rdx, 4*8)) + lea(mem(rdx, r9, 1), rdx) // a_prefetch += 4*cs_a; +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm7) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm8) + vfmadd231pd(ymm7, ymm8, ymm5) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + vaddpd(ymm5, ymm4, ymm4) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 0*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. 
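
(Illustrative aside: each k iteration of this 1x3 kernel broadcasts one element of A and multiply-accumulates it against the masked 3-element row of B, which is exactly the vbroadcastsd + vfmadd231pd pair above. An intrinsics sketch of a single step, with invented names, follows; build with -mavx2 -mfma.)

#include <immintrin.h>

// One k step of the 1x3 update: c_acc += a_l * b_row[0..2] (lane 3 masked off).
static __m256d step_1x3(__m256d c_acc, double a_l,
                        const double *b_row, __m256i mask)
{
    __m256d bv = _mm256_maskload_pd(b_row, mask);   // like vmaskmovpd with mask_3
    __m256d av = _mm256_set1_pd(a_l);               // like vbroadcastsd
    return _mm256_fmadd_pd(bv, av, c_acc);          // like vfmadd231pd
}
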
+ + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vmaskmovpd(mem(rcx, 0*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_1x3_TILE(4) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmaskmovpd(ymm4, ymm15, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_1x3_TILE_BZ(4) + jmp(.DDONE) // jump to end. + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm10", + "ymm12", "ymm15", "memory" + ) +} diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c index ad43e7ba57..76297df673 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,20 +40,20 @@ /* rrr: - -------- ------ -------- - -------- ------ -------- - -------- += ------ ... -------- - -------- ------ -------- - -------- ------ : - -------- ------ : + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : rcr: - -------- | | | | -------- - -------- | | | | -------- - -------- += | | | | ... -------- - -------- | | | | -------- - -------- | | | | : - -------- | | | | : + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... 
-------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : Assumptions: - B is row-stored; @@ -69,12 +69,12 @@ cost of the in-register transpose). crr: - | | | | | | | | ------ -------- - | | | | | | | | ------ -------- - | | | | | | | | += ------ ... -------- - | | | | | | | | ------ -------- - | | | | | | | | ------ : - | | | | | | | | ------ : + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : */ // Prototype reference microkernels. @@ -612,9 +612,9 @@ void bli_dgemmsup_rv_haswell_asm_5x4 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -629,7 +629,7 @@ void bli_dgemmsup_rv_haswell_asm_5x4 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -672,19 +672,19 @@ void bli_dgemmsup_rv_haswell_asm_5x4 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 @@ -698,17 +698,17 @@ void bli_dgemmsup_rv_haswell_asm_5x4 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) - + // ---------------------------------- iteration 1 #if 0 @@ -720,18 +720,18 @@ void bli_dgemmsup_rv_haswell_asm_5x4 vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm0, ymm3, ymm6) - + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm7) + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm0, ymm3, ymm10) - + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm11) + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm12) - + vfmadd231pd(ymm0, ymm2, ymm13) + // ---------------------------------- iteration 2 @@ -746,16 +746,16 @@ void bli_dgemmsup_rv_haswell_asm_5x4 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) - + // ---------------------------------- iteration 3 @@ -768,43 +768,47 @@ void bli_dgemmsup_rv_haswell_asm_5x4 vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm0, ymm3, ymm6) - + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm7) + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) - vfmadd231pd(ymm0, 
ymm2, ymm8) - vfmadd231pd(ymm0, ymm3, ymm10) - + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm11) + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm12) - - - + vfmadd231pd(ymm0, ymm2, ymm13) + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + vaddpd(ymm5, ymm4, ymm4) + vaddpd(ymm7, ymm6, ymm6) + vaddpd(ymm9, ymm8, ymm8) + vaddpd(ymm11, ymm10, ymm10) + vaddpd(ymm13, ymm12, ymm12) + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -812,54 +816,54 @@ void bli_dgemmsup_rv_haswell_asm_5x4 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm6, ymm6) vmulpd(ymm0, ymm8, ymm8) vmulpd(ymm0, ymm10, ymm10) vmulpd(ymm0, ymm12, ymm12) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -868,37 +872,27 @@ void bli_dgemmsup_rv_haswell_asm_5x4 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + lea(mem(rcx, rdi, 1), rax) // load address of c + 2*rs_c; + lea(mem(rcx, rdi, 2), rbx) // load address of c + 2*rs_c; + lea(mem(rbx, rdi, 1), r8) // load address of c + 2*rs_c; + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vfmadd231pd(mem(rax, 0*32), ymm3, ymm6) + vfmadd231pd(mem(rbx, 0*32), ymm3, ymm8) + vfmadd231pd(mem(r8, 0*32), ymm3, ymm10) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm12) + vmovupd(ymm4, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) - vmovupd(ymm6, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) - vmovupd(ymm8, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) - vmovupd(ymm10, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) - vmovupd(ymm12, mem(rcx, 0*32)) - //add(rdi, rcx) - - + vmovupd(ymm6, mem(rax, 0*32)) + vmovupd(ymm8, mem(rbx, 0*32)) + vmovupd(ymm10, mem(r8, 0*32)) + vmovupd(ymm12, mem(rdx, 0*32)) + jmp(.DDONE) // jump to end. @@ -945,41 +939,41 @@ void bli_dgemmsup_rv_haswell_asm_5x4 jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORBZ) // jump to column storage case - - + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(ymm8, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(ymm10, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(ymm12, mem(rcx, 0*32)) //add(rdi, rcx) - + jmp(.DDONE) // jump to end. @@ -1012,13 +1006,13 @@ void bli_dgemmsup_rv_haswell_asm_5x4 vmovhpd(xmm1, mem(rdx, rax, 1)) //lea(mem(rdx, rsi, 4), rdx) - - - - + + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -1046,8 +1040,8 @@ void bli_dgemmsup_rv_haswell_asm_5x4 "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", - "ymm6", "ymm8", "ymm10", "ymm12", - "memory" + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", + "ymm11", "ymm12", "ymm13", "memory" ) } @@ -1085,9 +1079,9 @@ void bli_dgemmsup_rv_haswell_asm_4x4 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -1102,7 +1096,7 @@ void bli_dgemmsup_rv_haswell_asm_4x4 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -1138,8 +1132,8 @@ void bli_dgemmsup_rv_haswell_asm_4x4 prefetch(0, mem(rcx, rdx, 1, 3*8)) // prefetch c + 3*cs_c label(.DPOSTPFETCH) // done prefetching c - - + + #if 1 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; @@ -1147,22 +1141,22 @@ void bli_dgemmsup_rv_haswell_asm_4x4 - + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
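
(Illustrative aside: the reworked row-stored epilogues in this file compute every row address of C up front, the rax/rbx/r8/rdx loads above, instead of bumping a single pointer with add(rdi, rcx) between stores, so the per-row FMA and store pairs carry no serial dependence on each other. A scalar sketch of that shape for the 5x4 tile, assuming unit column stride as in the row-stored path and using invented names, is:)

// Update five rows of C independently; acc holds the alpha-scaled 5x4 tile.
static void update_rows_5x4(double *c, long rs_c, const double *acc, double beta)
{
    double *row[5] = { c, c + rs_c, c + 2 * rs_c, c + 3 * rs_c, c + 4 * rs_c };

    for (int i = 0; i < 5; ++i)          // each row update is independent
        for (int j = 0; j < 4; ++j)
            row[i][j] = beta * row[i][j] + acc[i * 4 + j];
}
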
- - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 prefetch(0, mem(rdx, 4*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1170,14 +1164,14 @@ void bli_dgemmsup_rv_haswell_asm_4x4 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - - vbroadcastsd(mem(rax, r8, 2), ymm2) - vbroadcastsd(mem(rax, r13, 1), ymm3) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vbroadcastsd(mem(rax, r13, 1), ymm13) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm0, ymm3, ymm10) - - + vfmadd231pd(ymm0, ymm12, ymm8) + vfmadd231pd(ymm0, ymm13, ymm10) + + // ---------------------------------- iteration 1 #if 0 @@ -1186,39 +1180,39 @@ void bli_dgemmsup_rv_haswell_asm_4x4 vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm0, ymm3, ymm6) - - vbroadcastsd(mem(rax, r8, 2), ymm2) - vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vbroadcastsd(mem(rax, r13, 1), ymm13) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm0, ymm3, ymm10) - + vfmadd231pd(ymm0, ymm12, ymm9) + vfmadd231pd(ymm0, ymm13, ymm11) + // ---------------------------------- iteration 2 #if 1 prefetch(0, mem(rdx, r9, 2, 4*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - - vbroadcastsd(mem(rax, r8, 2), ymm2) - vbroadcastsd(mem(rax, r13, 1), ymm3) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vbroadcastsd(mem(rax, r13, 1), ymm13) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm0, ymm3, ymm10) - + vfmadd231pd(ymm0, ymm12, ymm8) + vfmadd231pd(ymm0, ymm13, ymm10) + // ---------------------------------- iteration 3 @@ -1228,128 +1222,123 @@ void bli_dgemmsup_rv_haswell_asm_4x4 vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm0, ymm3, ymm6) - - vbroadcastsd(mem(rax, r8, 2), ymm2) - vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vbroadcastsd(mem(rax, r13, 1), ymm13) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm0, ymm3, ymm10) - - - + vfmadd231pd(ymm0, ymm12, ymm9) + vfmadd231pd(ymm0, ymm13, ymm11) + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + vaddpd(ymm5, ymm4, ymm4) + vaddpd(ymm7, ymm6, ymm6) + vaddpd(ymm9, ymm8, ymm8) + vaddpd(ymm11, ymm10, ymm10) + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
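
(Illustrative aside: the reason this diff spreads the accumulation over two register sets, ymm4/6/8/10 and ymm5/7/9/11, is that a single accumulator chains every FMA behind the previous one, and the FMA latency on Haswell is several cycles; two independent chains let the unrolled iterations overlap and are summed once at the end. The same idea in scalar C, with invented names:)

// Dot product with two independent accumulation chains, folded at the end.
static double dot_two_chains(const double *x, const double *y, long k)
{
    double s0 = 0.0, s1 = 0.0;
    long l = 0;

    for (; l + 1 < k; l += 2)
    {
        s0 += x[l]     * y[l];        // chain 0
        s1 += x[l + 1] * y[l + 1];    // chain 1, independent of chain 0
    }
    if (l < k)
        s0 += x[l] * y[l];            // remainder

    return s0 + s1;                   // fold, like the vaddpd reductions
}
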
- - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - - vbroadcastsd(mem(rax, r8, 2), ymm2) - vbroadcastsd(mem(rax, r13, 1), ymm3) + + vbroadcastsd(mem(rax, r8, 2), ymm12) + vbroadcastsd(mem(rax, r13, 1), ymm13) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm0, ymm3, ymm10) - - + vfmadd231pd(ymm0, ymm12, ymm8) + vfmadd231pd(ymm0, ymm13, ymm10) + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm6, ymm6) vmulpd(ymm0, ymm8, ymm8) vmulpd(ymm0, ymm10, ymm10) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case label(.DROWSTORED) - - + lea(mem(rcx, rdi, 1), rax) // load address of c + 1*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + lea(mem(rdx, rdi, 1), rbx) // load address of c + 3*rs_c; + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vfmadd231pd(mem(rax, 0*32), ymm3, ymm6) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm8) + vfmadd231pd(mem(rbx, 0*32), ymm3, ymm10) + vmovupd(ymm4, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) - vmovupd(ymm6, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) - vmovupd(ymm8, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) - vmovupd(ymm10, mem(rcx, 0*32)) - //add(rdi, rcx) - - + vmovupd(ymm6, mem(rax, 0*32)) + vmovupd(ymm8, mem(rdx, 0*32)) + vmovupd(ymm10, mem(rbx, 0*32)) + + jmp(.DDONE) // jump to end. @@ -1381,33 +1370,33 @@ void bli_dgemmsup_rv_haswell_asm_4x4 jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORBZ) // jump to column storage case - - + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(ymm8, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(ymm10, mem(rcx, 0*32)) //add(rdi, rcx) @@ -1417,7 +1406,7 @@ void bli_dgemmsup_rv_haswell_asm_4x4 label(.DCOLSTORBZ) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) @@ -1435,12 +1424,12 @@ void bli_dgemmsup_rv_haswell_asm_4x4 //lea(mem(rcx, rsi, 4), rcx) - - - + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -1468,8 +1457,8 @@ void bli_dgemmsup_rv_haswell_asm_4x4 "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", - "ymm6", "ymm8", "ymm10", - "memory" + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", + "ymm11", "ymm12", "ymm13", "memory" ) } @@ -1507,9 +1496,9 @@ void bli_dgemmsup_rv_haswell_asm_3x4 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -1524,7 +1513,7 @@ void bli_dgemmsup_rv_haswell_asm_3x4 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -1559,31 +1548,31 @@ void bli_dgemmsup_rv_haswell_asm_3x4 prefetch(0, mem(rcx, rdx, 1, 2*8)) // prefetch c + 3*cs_c label(.DPOSTPFETCH) // done prefetching c - - + + #if 1 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
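+	// As in the 4x4 kernel above, the unrolled loop below accumulates odd
+	// iterations (1 and 3) into a second register set (ymm9/ymm10/ymm11),
+	// which is folded into ymm4/ymm6/ymm8 with vaddpd after the loop.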
- - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 prefetch(0, mem(rdx, 4*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1591,12 +1580,12 @@ void bli_dgemmsup_rv_haswell_asm_3x4 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - - vbroadcastsd(mem(rax, r8, 2), ymm2) + + vbroadcastsd(mem(rax, r8, 2), ymm12) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm8) - - + vfmadd231pd(ymm0, ymm12, ymm8) + + // ---------------------------------- iteration 1 #if 0 @@ -1608,20 +1597,20 @@ void bli_dgemmsup_rv_haswell_asm_3x4 vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm0, ymm3, ymm6) - - vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 2), ymm12) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm8) - + vfmadd231pd(ymm0, ymm12, ymm11) + // ---------------------------------- iteration 2 #if 1 prefetch(0, mem(rdx, r9, 2, 4*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1629,11 +1618,11 @@ void bli_dgemmsup_rv_haswell_asm_3x4 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - - vbroadcastsd(mem(rax, r8, 2), ymm2) + + vbroadcastsd(mem(rax, r8, 2), ymm12) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm8) - + vfmadd231pd(ymm0, ymm12, ymm8) + // ---------------------------------- iteration 3 @@ -1646,38 +1635,37 @@ void bli_dgemmsup_rv_haswell_asm_3x4 vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm0, ymm3, ymm6) - - vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 2), ymm12) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm8) - - - + vfmadd231pd(ymm0, ymm12, ymm11) + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + vaddpd(ymm9, ymm4, ymm4) + vaddpd(ymm10, ymm6, ymm6) + vaddpd(ymm11, ymm8, ymm8) + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1685,78 +1673,73 @@ void bli_dgemmsup_rv_haswell_asm_3x4 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm8) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. 
- - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm6, ymm6) vmulpd(ymm0, ymm8, ymm8) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - + label(.DROWSTORED) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) - vmovupd(ymm4, mem(rcx, 0*32)) - add(rdi, rcx) + lea(mem(rcx, rdi, 1), rbx) // load address of c + 2*rs_c; - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) - vmovupd(ymm6, mem(rcx, 0*32)) - add(rdi, rcx) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vfmadd231pd(mem(rbx, 0*32), ymm3, ymm6) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm8) + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm6, mem(rbx, 0*32)) + vmovupd(ymm8, mem(rdx, 0*32)) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) - vmovupd(ymm8, mem(rcx, 0*32)) - //add(rdi, rcx) - - jmp(.DDONE) // jump to end. - + label(.DCOLSTORED) @@ -1797,26 +1780,26 @@ void bli_dgemmsup_rv_haswell_asm_3x4 vmovsd(xmm13, mem(rdx, rsi, 1)) vmovsd(xmm14, mem(rdx, rsi, 2)) vmovsd(xmm15, mem(rdx, rax, 1)) - + //lea(mem(rdx, rsi, 4), rdx) jmp(.DDONE) // jump to end. - - + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) add(rdi, rcx) @@ -1827,8 +1810,8 @@ void bli_dgemmsup_rv_haswell_asm_3x4 vmovupd(ymm8, mem(rcx, 0*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -1864,12 +1847,12 @@ void bli_dgemmsup_rv_haswell_asm_3x4 //lea(mem(rdx, rsi, 4), rdx) - - - + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -1897,8 +1880,8 @@ void bli_dgemmsup_rv_haswell_asm_3x4 "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", - "ymm6", "ymm8", "ymm10", - "memory" + "ymm6", "ymm8", "ymm9", "ymm10", + "ymm11", "ymm12", "memory" ) } @@ -1936,9 +1919,9 @@ void bli_dgemmsup_rv_haswell_asm_2x4 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -1953,7 +1936,7 @@ void bli_dgemmsup_rv_haswell_asm_2x4 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -1987,31 +1970,31 @@ void bli_dgemmsup_rv_haswell_asm_2x4 prefetch(0, mem(rcx, rdx, 1, 1*8)) // prefetch c + 3*cs_c label(.DPOSTPFETCH) // done prefetching c - - + + #if 1 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. 
je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 prefetch(0, mem(rdx, 4*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -2020,26 +2003,26 @@ void bli_dgemmsup_rv_haswell_asm_2x4 add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - - + + // ---------------------------------- iteration 1 #if 0 prefetch(0, mem(rdx, r9, 1, 4*8)) #endif - vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 0*32), ymm9) add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) - vbroadcastsd(mem(rax, r8, 1), ymm3) + vbroadcastsd(mem(rax ), ymm10) + vbroadcastsd(mem(rax, r8, 1), ymm11) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm0, ymm3, ymm6) - + vfmadd231pd(ymm9, ymm10, ymm7) + vfmadd231pd(ymm9, ymm11, ymm8) + // ---------------------------------- iteration 2 - + #if 1 prefetch(0, mem(rdx, r9, 2, 4*8)) #endif @@ -2052,7 +2035,7 @@ void bli_dgemmsup_rv_haswell_asm_2x4 add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + // ---------------------------------- iteration 3 @@ -2060,40 +2043,38 @@ void bli_dgemmsup_rv_haswell_asm_2x4 lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; #endif - vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 0*32), ymm9) add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) - vbroadcastsd(mem(rax, r8, 1), ymm3) + vbroadcastsd(mem(rax ), ymm10) + vbroadcastsd(mem(rax, r8, 1), ymm11) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm0, ymm3, ymm6) - - - + vfmadd231pd(ymm9, ymm10, ymm7) + vfmadd231pd(ymm9, ymm11, ymm8) + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + vaddpd(ymm7, ymm4, ymm4) + vaddpd(ymm8, ymm6, ymm6) + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -2102,42 +2083,42 @@ void bli_dgemmsup_rv_haswell_asm_2x4 add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm6, ymm6) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -2146,22 +2127,19 @@ void bli_dgemmsup_rv_haswell_asm_2x4 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + lea(mem(rcx, rdi, 1), rdx) // load address of c + 1*cs_c; + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm6) + vmovupd(ymm4, mem(rcx, 0*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) - vmovupd(ymm6, mem(rcx, 0*32)) - //add(rdi, rcx) - - + vmovupd(ymm6, mem(rdx, 0*32)) + jmp(.DDONE) // jump to end. @@ -2187,24 +2165,24 @@ void bli_dgemmsup_rv_haswell_asm_2x4 jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - - + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) //add(rdi, rcx) @@ -2215,7 +2193,7 @@ void bli_dgemmsup_rv_haswell_asm_2x4 label(.DCOLSTORBZ) - + // begin I/O on columns 0-3 vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) @@ -2228,13 +2206,13 @@ void bli_dgemmsup_rv_haswell_asm_2x4 vmovupd(xmm4, mem(rcx, rax, 1)) //lea(mem(rcx, rsi, 4), rcx) - - - - + + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -2262,7 +2240,7 @@ void bli_dgemmsup_rv_haswell_asm_2x4 "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", - "memory" + "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "memory" ) } @@ -2300,9 +2278,9 @@ void bli_dgemmsup_rv_haswell_asm_1x4 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -2317,7 +2295,7 @@ void bli_dgemmsup_rv_haswell_asm_1x4 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -2350,27 +2328,27 @@ void bli_dgemmsup_rv_haswell_asm_1x4 prefetch(0, mem(rcx, rdx, 1, 0*8)) // prefetch c + 3*cs_c label(.DPOSTPFETCH) // done prefetching c - - + + #if 1 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
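+	// In this 1x4 kernel the odd iterations (1 and 3) accumulate into ymm5
+	// through separate load/broadcast temporaries, and ymm5 is folded into
+	// ymm4 with a single vaddpd after the unrolled loop.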
- - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 - + #if 1 prefetch(0, mem(rdx, 4*8)) #endif @@ -2381,35 +2359,35 @@ void bli_dgemmsup_rv_haswell_asm_1x4 vbroadcastsd(mem(rax ), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm4) - - + + // ---------------------------------- iteration 1 #if 0 prefetch(0, mem(rdx, r9, 1, 4*8)) #endif - vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 0*32), ymm1) add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax ), ymm3) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm4) - + vfmadd231pd(ymm1, ymm3, ymm5) + // ---------------------------------- iteration 2 #if 1 prefetch(0, mem(rdx, r9, 2, 4*8)) #endif - - vmovupd(mem(rbx, 0*32), ymm0) + + vmovupd(mem(rbx, 0*32), ymm6) add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax ), ymm7) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm4) - + vfmadd231pd(ymm6, ymm7, ymm4) + // ---------------------------------- iteration 3 @@ -2417,33 +2395,30 @@ void bli_dgemmsup_rv_haswell_asm_1x4 lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; #endif - vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 0*32), ymm8) add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax ), ymm9) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm4) - - - + vfmadd231pd(ymm8, ymm9, ymm5) + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + vaddpd(ymm5, ymm4, ymm4) + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) @@ -2455,41 +2430,41 @@ void bli_dgemmsup_rv_haswell_asm_1x4 vbroadcastsd(mem(rax ), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm4) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; //lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -2497,17 +2472,17 @@ void bli_dgemmsup_rv_haswell_asm_1x4 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vmovupd(ymm4, mem(rcx, 0*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -2529,15 +2504,15 @@ void bli_dgemmsup_rv_haswell_asm_1x4 vmovhpd(xmm0, mem(rcx, rsi, 1)) vmovlpd(xmm1, mem(rcx, rsi, 2)) vmovhpd(xmm1, mem(rcx, rax, 1)) - + //lea(mem(rcx, rsi, 4), rcx) jmp(.DDONE) // jump to end. 
- - - - + + + + label(.DBETAZERO) @@ -2545,10 +2520,10 @@ void bli_dgemmsup_rv_haswell_asm_1x4 jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) //add(rdi, rcx) @@ -2558,7 +2533,7 @@ void bli_dgemmsup_rv_haswell_asm_1x4 label(.DCOLSTORBZ) - + // begin I/O on columns 0-3 vmovupd(ymm4, ymm0) @@ -2569,14 +2544,14 @@ void bli_dgemmsup_rv_haswell_asm_1x4 vmovhpd(xmm1, mem(rcx, rax, 1)) //lea(mem(rcx, rsi, 4), rcx) - - - - - + + + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -2604,7 +2579,7 @@ void bli_dgemmsup_rv_haswell_asm_1x4 "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", - "memory" + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "memory" ) } diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx5.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx5.c new file mode 100644 index 0000000000..b9473fff27 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx5.c @@ -0,0 +1,2499 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +//3, 5, 7, 9, 11, 13, 4, 6, 8, 10, 12, 14 +#define C_TRANSPOSE_5x5_TILE(R1, R2, R3, R4, R5, R6, R7, R8, R9, R10) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + /*Broadcasting Beta into ymm15 vector register*/\ + vbroadcastsd(mem(rbx), ymm15)\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), ymm15, ymm(R4))\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 1x4 tile*/ \ + vmovlpd(mem(rdx ), xmm0, xmm0)\ + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0)\ + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1)\ + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + /*Transposing 4x1 tile*/ \ + vfmadd213pd(ymm(R5), ymm15, ymm0)\ + vextractf128(imm(1), ymm0, xmm1)\ + vmovlpd(xmm0, mem(rdx ))\ + vmovhpd(xmm0, mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2))\ + vmovhpd(xmm1, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + vunpcklpd(ymm(R7), ymm(R6), ymm0)\ + vunpcklpd(ymm(R9), ymm(R8), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R6))\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm(R6))\ + vmovupd(ymm(R6), mem(rcx ))\ +\ + /*Transposing 1x1 tile*/ \ + vmovlpd(mem(rdx ), xmm0, xmm0)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + vfmadd213pd(ymm(R10), ymm15, ymm0)\ + vmovlpd(xmm0, mem(rdx )) + +#define C_TRANSPOSE_5x5_TILE_BZ(R1, R2, R3, R4, R5, R6, R7, R8, R9, R10) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 1x4 tile*/ \ + vextractf128(imm(1), ymm(R5), xmm1)\ + vmovlpd(xmm(R5), mem(rdx ))\ + vmovhpd(xmm(R5), mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2))\ + vmovhpd(xmm1, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 1x4 tile*/ \ + vunpcklpd(ymm(R7), ymm(R6), ymm0)\ + vunpcklpd(ymm(R9), ymm(R8), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R6))\ +\ + vmovupd(ymm(R6), mem(rcx ))\ +\ + /*Transposing 1x1 tile*/ \ + vmovlpd(xmm(R10), mem(rdx )) + + +#define C_TRANSPOSE_4x5_TILE(R1, R2, R3, R4, R5, R6, R7, R8) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + vbroadcastsd(mem(rbx), ymm15)\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm(R1))\ + 
vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), ymm15, ymm(R4))\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x1 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpcklpd(ymm(R8), ymm(R7), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R5))\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm(R5))\ + vmovupd(ymm(R5), mem(rcx )) + +#define C_TRANSPOSE_4x5_TILE_BZ(R1, R2, R3, R4, R5, R6, R7, R8) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x1 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpcklpd(ymm(R8), ymm(R7), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R5))\ +\ + vmovupd(ymm(R5), mem(rcx )) + +//3, 5, 7, 4, 6, 8 +#define C_TRANSPOSE_3x5_TILE(R1, R2, R3, R4, R5, R6) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm10, ymm(R3), ymm2)\ + vunpckhpd(ymm10, ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm10)\ +\ + vextractf128(imm(0x1), ymm(R1), xmm12)\ + vextractf128(imm(0x1), ymm(R2), xmm13)\ + vextractf128(imm(0x1), ymm(R3), xmm14)\ + vextractf128(imm(0x1), ymm10, xmm15)\ +\ + vbroadcastsd(mem(rbx), ymm11)\ +\ + vfmadd231pd(mem(rcx ), xmm11, xmm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), xmm11, xmm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), xmm11, xmm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), xmm11, xmm10)\ + vmovupd(xmm(R1), mem(rcx ))\ + vmovupd(xmm(R2), mem(rcx, rsi, 1))\ + vmovupd(xmm(R3), mem(rcx, rsi, 2))\ + vmovupd(xmm10, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 1x4 tile*/ \ + vfmadd231sd(mem(rdx ), xmm11, xmm12)\ + vfmadd231sd(mem(rdx, rsi, 1), xmm11, xmm13)\ + vfmadd231sd(mem(rdx, rsi, 2), xmm11, xmm14)\ + vfmadd231sd(mem(rdx, rax, 1), xmm11, xmm15)\ + vmovsd(xmm12, mem(rdx ))\ + vmovsd(xmm13, mem(rdx, rsi, 1))\ + vmovsd(xmm14, mem(rdx, rsi, 2))\ + vmovsd(xmm15, mem(rdx, rax, 1))\ + \ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 2x1 tile*/ \ + vunpcklpd(ymm(R5), ymm(R4), ymm0)\ + vunpcklpd(ymm11, ymm(R6), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R4))\ +\ + vextractf128(imm(0x1), ymm(R4), xmm12)\ +\ + vbroadcastsd(mem(rbx), ymm3)\ +\ + vfmadd231pd(mem(rcx ), xmm3, xmm(R4))\ + vmovupd(xmm(R4), mem(rcx ))\ +\ + /*Transposing 1x1 tile*/ \ + vfmadd231sd(mem(rdx ), xmm3, xmm12)\ + vmovsd(xmm12, mem(rdx )) + +#define C_TRANSPOSE_3x5_TILE_BZ(R1, R2, R3, R4, R5, R6) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm10, ymm(R3), ymm2)\ + vunpckhpd(ymm10, ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + 
vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm10)\ +\ + vextractf128(imm(0x1), ymm(R1), xmm12)\ + vextractf128(imm(0x1), ymm(R2), xmm13)\ + vextractf128(imm(0x1), ymm(R3), xmm14)\ + vextractf128(imm(0x1), ymm10, xmm15)\ +\ + vmovupd(xmm(R1), mem(rcx ))\ + vmovupd(xmm(R2), mem(rcx, rsi, 1))\ + vmovupd(xmm(R3), mem(rcx, rsi, 2))\ + vmovupd(xmm10, mem(rcx, rax, 1))\ +\ + /*Transposing 1x4 tile*/ \ + lea(mem(rcx, rsi, 4), rcx)\ + vmovsd(xmm12, mem(rdx ))\ + vmovsd(xmm13, mem(rdx, rsi, 1))\ + vmovsd(xmm14, mem(rdx, rsi, 2))\ + vmovsd(xmm15, mem(rdx, rax, 1))\ + \ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 2x1 tile*/ \ + vunpcklpd(ymm(R5), ymm(R4), ymm0)\ + vunpcklpd(ymm11, ymm(R6), ymm2)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R4))\ +\ + vextractf128(imm(0x1), ymm(R4), xmm12)\ +\ + vmovupd(xmm(R4), mem(rcx ))\ +\ + /*Transposing 1x1 tile*/ \ + vmovsd(xmm12, mem(rdx )) + +//3, 5, 4, 6 +#define C_TRANSPOSE_2x5_TILE(R1, R2, R3, R4) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm7)\ +\ + vbroadcastsd(mem(rbx), ymm3)\ + vfmadd231pd(mem(rcx ), xmm3, xmm0)\ + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1)\ + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2)\ + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm7)\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, mem(rcx, rsi, 2))\ + vmovupd(xmm7, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x1 tile*/ \ + vunpcklpd(ymm(R4), ymm(R3), ymm0)\ +\ + vfmadd231pd(mem(rcx ), xmm3, xmm0)\ + vmovupd(xmm0, mem(rcx )) + +#define C_TRANSPOSE_2x5_TILE_BZ(R1, R2, R3, R4) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm7)\ +\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, mem(rcx, rsi, 2))\ + vmovupd(xmm7, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x1 tile*/ \ + vunpcklpd(ymm(R4), ymm(R3), ymm0)\ +\ + vmovupd(xmm0, mem(rcx )) + +#define C_TRANSPOSE_1x5_TILE(R1, R2) \ + /*Transposing 1x4 tile*/ \ + vmovlpd(mem(rcx ), xmm0, xmm0)\ + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0)\ + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1)\ + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + vbroadcastsd(mem(rbx), ymm15)\ + vfmadd213pd(ymm(R1), ymm15, ymm0)\ +\ + vextractf128(imm(1), ymm0, xmm1)\ + vmovlpd(xmm0, mem(rcx ))\ + vmovhpd(xmm0, mem(rcx, rsi, 1))\ + vmovlpd(xmm1, mem(rcx, rsi, 2))\ + vmovhpd(xmm1, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + vmovlpd(mem(rcx ), xmm0, xmm0)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + vfmadd213pd(ymm(R2), ymm15, ymm0)\ +\ + /*Transposing 1x1 tile*/ \ + vextractf128(imm(1), ymm0, xmm1)\ + vmovlpd(xmm0, mem(rcx )) + +#define C_TRANSPOSE_1x5_TILE_BZ(R1, R2) \ + vextractf128(imm(1), ymm(R1), xmm1)\ + vmovlpd(xmm(R1), mem(rcx ))\ + vmovhpd(xmm(R1), mem(rcx, rsi, 1))\ + vmovlpd(xmm1, mem(rcx, rsi, 2))\ + vmovhpd(xmm1, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + vextractf128(imm(1), ymm(R2), xmm1)\ + vmovlpd(xmm(R2), mem(rcx )) + +static const int64_t mask_1[4] = {-1, 0, 0, 0}; + +void bli_dgemmsup_rv_haswell_asm_5x5 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + 
		double*    restrict beta,
+		double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+		auxinfo_t* restrict data,
+		cntx_t*    restrict cntx
+	)
+{
+	uint64_t k_iter = k0 / 4;
+	uint64_t k_left = k0 % 4;
+
+	uint64_t rs_a   = rs_a0;
+	uint64_t cs_a   = cs_a0;
+	uint64_t rs_b   = rs_b0;
+	uint64_t cs_b   = cs_b0;
+	uint64_t rs_c   = rs_c0;
+	uint64_t cs_c   = cs_c0;
+
+// Sets up the mask for loading the relevant remainder elements in the load direction.
+// An int64_t array of size 4 represents the mask for the 4 elements of an AVX2 vector register.
+//
+//  Low end                   High end
+//  ________________________
+// |     |     |     |     |
+// |  1  |  2  |  3  |  4  |  ----> Source vector
+// |_____|_____|_____|_____|
+//
+//  ________________________
+// |     |     |     |     |
+// | -1  |  0  |  0  |  0  |  ----> Mask vector( mask_1 )
+// |_____|_____|_____|_____|
+//
+//  ________________________
+// |     |     |     |     |
+// |  1  |  0  |  0  |  0  |  ----> Destination vector
+// |_____|_____|_____|_____|
+//
+// Since we have 5 elements to load, the kernel uses one normal load that
+// loads 4 elements into a vector register, and for the 1 remainder element
+// it uses mask_1, which is set to -1, 0, 0, 0, so that only that 1 element
+// is loaded and the other 3 elements are set to 0 in the destination vector.
+//
+	int64_t const *mask_vec = mask_1;
+	// -------------------------------------------------------------------------
+
+	begin_asm()
+
+	vzeroall()                         // zero all xmm/ymm registers.
+	mov(var(mask_vec), rdx)
+	vmovdqu(mem(rdx), ymm15)           // load mask_1 into ymm15
+	mov(var(a), rax)                   // load address of a.
+	mov(var(rs_a), r8)                 // load rs_a
+	mov(var(cs_a), r9)                 // load cs_a
+	lea(mem(, r8, 8), r8)              // rs_a *= sizeof(double)
+	lea(mem(, r9, 8), r9)              // cs_a *= sizeof(double)
+
+	lea(mem(r8, r8, 2), r13)           // r13 = 3*rs_a
+	//lea(mem(r8, r8, 4), r15)           // r15 = 5*rs_a
+
+	mov(var(b), rbx)                   // load address of b.
+	mov(var(rs_b), r10)                // load rs_b
+	//mov(var(cs_b), r11)                // load cs_b
+	lea(mem(, r10, 8), r10)            // rs_b *= sizeof(double)
+	//lea(mem(, r11, 8), r11)            // cs_b *= sizeof(double)
+
+	// NOTE: We cannot pre-load elements of a or b
+	// because it could eventually, in the last
+	// unrolled iter or the cleanup loop, result
+	// in reading beyond the bounds allocated mem
+	// (the likely result: a segmentation fault).
+
+	mov(var(c), rcx)                   // load address of c
+	mov(var(rs_c), rdi)                // load rs_c
+	lea(mem(, rdi, 8), rdi)            // rs_c *= sizeof(double)
+
+
+
+	cmp(imm(8), rdi)                   // set ZF if (8*rs_c) == 8.
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 4*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 4*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 4*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rdx, 1, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rdx, 2, 4*8)) // prefetch c + 4*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + 
vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+	jz(.DCOLSTORED)                    // jump to column storage case
+
+
+
+	label(.DROWSTORED)
+
+	lea(mem(rcx, rdi, 1), rax)         // load address of c + 1*rs_c;
+	lea(mem(rcx, rdi, 2), rbx)         // load address of c + 2*rs_c;
+	lea(mem(rbx, rdi, 1), r8)          // load address of c + 3*rs_c;
+
+	vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3)
+	vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm4)
+
+	vfmadd231pd(mem(rax, 0*32), ymm1, ymm5)
+	vmaskmovpd(mem(rax, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm6)
+
+	vfmadd231pd(mem(rbx, 0*32), ymm1, ymm7)
+	vmaskmovpd(mem(rbx, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm8)
+
+	vfmadd231pd(mem(r8, 0*32), ymm1, ymm9)
+	vmaskmovpd(mem(r8, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm10)
+
+	vfmadd231pd(mem(rdx, 0*32), ymm1, ymm11)
+	vmaskmovpd(mem(rdx, 1*32), ymm15, ymm0)
+	vfmadd231pd(ymm0, ymm1, ymm12)
+
+
+	vmovupd(ymm3, mem(rcx, 0*32))
+	vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32))
+
+	vmovupd(ymm5, mem(rax, 0*32))
+	vmaskmovpd(ymm6, ymm15, mem(rax, 1*32))
+
+	vmovupd(ymm7, mem(rbx, 0*32))
+	vmaskmovpd(ymm8, ymm15, mem(rbx, 1*32))
+
+	vmovupd(ymm9, mem(r8, 0*32))
+	vmaskmovpd(ymm10, ymm15, mem(r8, 1*32))
+
+	vmovupd(ymm11, mem(rdx, 0*32))
+	vmaskmovpd(ymm12, ymm15, mem(rdx, 1*32))
+
+
+	jmp(.DDONE)                        // jump to end.
+
+
+
+	label(.DCOLSTORED)
+	C_TRANSPOSE_5x5_TILE(3, 5, 7, 9, 11, 4, 6, 8, 10, 12)
+
+	jmp(.DDONE)                        // jump to end.
+
+
+
+
+	label(.DBETAZERO)
+
+
+	cmp(imm(8), rdi)                   // set ZF if (8*rs_c) == 8.
+	jz(.DCOLSTORBZ)                    // jump to column storage case
+
+
+
+	label(.DROWSTORBZ)
+
+	vmovupd(ymm3, mem(rcx, 0*32))
+	vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------1
+
+	vmovupd(ymm5, mem(rcx, 0*32))
+	vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------2
+
+	vmovupd(ymm7, mem(rcx, 0*32))
+	vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------3
+
+	vmovupd(ymm9, mem(rcx, 0*32))
+	vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32))
+
+	add(rdi, rcx)
+	//-----------------------4
+	vmovupd(ymm11, mem(rcx, 0*32))
+	vmaskmovpd(ymm12, ymm15, mem(rcx, 1*32))
+
+	jmp(.DDONE)                        // jump to end.
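+	// The vmovupd + vmaskmovpd pairs above handle each 5-wide row of C as one
+	// full 4-element access plus one masked access for the remainder element,
+	// following the mask_1 scheme described at the top of this function. A
+	// minimal standalone C sketch of the same pattern using the masked-move
+	// intrinsics (illustrative only; scale_row5 is a hypothetical helper, not
+	// part of BLIS):
+	//
+	//   #include <immintrin.h>
+	//
+	//   void scale_row5( double* c, double alpha )
+	//   {
+	//       __m256i m1 = _mm256_set_epi64x( 0, 0, 0, -1 );    // mask_1: low lane only
+	//       __m256d va = _mm256_set1_pd( alpha );
+	//       __m256d c0 = _mm256_loadu_pd( c + 0 );            // elements 0..3
+	//       __m256d c1 = _mm256_maskload_pd( c + 4, m1 );     // element 4 only
+	//       _mm256_storeu_pd( c + 0, _mm256_mul_pd( va, c0 ) );
+	//       _mm256_maskstore_pd( c + 4, m1, _mm256_mul_pd( va, c1 ) );
+	//   }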
+ + + + label(.DCOLSTORBZ) + C_TRANSPOSE_5x5_TILE_BZ(3, 5, 7, 9, 11, 4, 6, 8, 10, 12) + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", + "ymm10", "ymm12", "ymm11", "ymm15", + "memory" + ) +} + + +void bli_dgemmsup_rv_haswell_asm_4x5 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | 0 | 0 | 0 | ----> Mask vector( mask_1 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 0 | 0 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 5 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 1 element, +// kernel is using mask_1 which is set to -1, 0, 0, 0 static that the +// 1 element will be loaded and other 3 elements will be set to 0 in destination vector. +// + int64_t const *mask_vec = mask_1; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 4*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 4*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rdx, 1, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rdx, 2, 3*8)) // prefetch c + 4*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + 
lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 1), rax) // load address of c + 1*rs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + lea(mem(rdx, rdi, 1), rbx) // load address of c + 3*rs_c; + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vfmadd231pd(mem(rax, 0*32), ymm1, ymm5) + vmaskmovpd(mem(rax, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm6) + + vfmadd231pd(mem(rdx, 0*32), ymm1, ymm7) + vmaskmovpd(mem(rdx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm8) + + vfmadd231pd(mem(rbx, 0*32), ymm1, ymm9) + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm10) + + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + vmovupd(ymm5, mem(rax, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rax, 1*32)) + + vmovupd(ymm7, mem(rdx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rdx, 1*32)) + + vmovupd(ymm9, mem(rbx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rbx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + C_TRANSPOSE_4x5_TILE(3, 5, 7, 9, 4, 6, 8, 10) + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------2 + + vmovupd(ymm7, mem(rcx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------3 + + vmovupd(ymm9, mem(rcx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32)) + //-----------------------4 + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + C_TRANSPOSE_4x5_TILE_BZ(3, 5, 7, 9, 4, 6, 8, 10) + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", + "ymm10", "ymm12", "ymm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x5 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. 
+// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | 0 | 0 | 0 | ----> Mask vector( mask_1 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 0 | 0 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 5 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 1 element, +// kernel is using mask_1 which is set to -1, 0, 0, 0 static that the +// 1 element will be loaded and other 3 elements will be set to 0 in destination vector. +// + int64_t const *mask_vec = mask_1; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 4*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 4*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rdx, 1, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rdx, 2, 2*8)) // prefetch c + 4*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
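The mask comment above describes exactly what the vmovupd + vmaskmovpd pair in the loops below does: one full 4-element load of the B row, plus a masked load that brings in only the fifth element and zeroes the other lanes of the destination. A minimal C sketch of the same idea using AVX2 intrinsics follows; the helper name is hypothetical, and only the mask layout (-1, 0, 0, 0) is taken from the comment above.

    #include <immintrin.h>

    /* Illustrative only: read a 5-wide row of doubles the way the kernel does,
       using one unmasked load plus a maskload for the trailing element. */
    static inline void load_row_of_5(const double *b, __m256d *lo, __m256d *hi)
    {
        /* All bits set (-1) in lane 0 enables that lane; 0 disables lanes 1-3. */
        const __m256i mask_1 = _mm256_set_epi64x(0, 0, 0, -1);

        *lo = _mm256_loadu_pd(b);                 /* elements 0..3 (vmovupd)     */
        *hi = _mm256_maskload_pd(b + 4, mask_1);  /* element 4 only (vmaskmovpd) */
    }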
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + vaddpd(ymm9, ymm3, ymm3) + vaddpd(ymm10, ymm4, ymm4) + vaddpd(ymm11, ymm5, ymm5) + vaddpd(ymm12, ymm6, ymm6) + vaddpd(ymm13, ymm7, ymm7) + vaddpd(ymm14, ymm8, ymm8) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
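The unrolled loop above alternates between two accumulator sets: iterations 0 and 2 feed ymm3-ymm8, iterations 1 and 3 feed ymm9-ymm14, and the two sets are folded together by the vaddpd group before the remainder loop. Keeping two running sums shortens the dependency chain on any one register. A scalar sketch of the same pattern (illustrative only, not the kernel's code):

    /* Two partial sums hide FMA latency; they are merged at the end. */
    double dot_two_accumulators(const double *x, const double *y, int k)
    {
        double acc0 = 0.0, acc1 = 0.0;    /* like ymm3..8 and ymm9..14 */
        int i;
        for (i = 0; i + 1 < k; i += 2)
        {
            acc0 += x[i]     * y[i];
            acc1 += x[i + 1] * y[i + 1];
        }
        acc0 += acc1;                     /* the vaddpd merge          */
        for (; i < k; ++i)                /* like .DLOOPKLEFT          */
            acc0 += x[i] * y[i];
        return acc0;
    }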
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 1), rbx) // load address of c + 1*rs_c; + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vfmadd231pd(mem(rbx, 0*32), ymm1, ymm5) + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm6) + + vfmadd231pd(mem(rdx, 0*32), ymm1, ymm7) + vmaskmovpd(mem(rdx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm8) + + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + vmovupd(ymm5, mem(rbx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rbx, 1*32)) + + vmovupd(ymm7, mem(rdx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rdx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + C_TRANSPOSE_3x5_TILE(3, 5, 7, 4, 6, 8) + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm7, mem(rcx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. 
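The row-stored paths above implement the usual sup-kernel update C := beta*C + alpha*A*B on a 3x5 tile, with the fifth column of C read and written through the same mask. A naive scalar reference of what the 3x5 kernel computes is shown below (illustrative only; the real kernel additionally skips reading C when beta is zero):

    /* Reference semantics for the 3x5 microkernel update (not optimized). */
    void ref_dgemmsup_3x5(int k, double alpha,
                          const double *a, int rs_a, int cs_a,
                          const double *b, int rs_b, int cs_b,
                          double beta, double *c, int rs_c, int cs_c)
    {
        for (int i = 0; i < 3; ++i)
            for (int j = 0; j < 5; ++j)
            {
                double ab = 0.0;
                for (int l = 0; l < k; ++l)
                    ab += a[i * rs_a + l * cs_a] * b[l * rs_b + j * cs_b];
                c[i * rs_c + j * cs_c] = beta * c[i * rs_c + j * cs_c] + alpha * ab;
            }
    }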
+ + + + label(.DCOLSTORBZ) + C_TRANSPOSE_3x5_TILE_BZ(3, 5, 7, 4, 6, 8) + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", + "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x5 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | 0 | 0 | 0 | ----> Mask vector( mask_1 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 0 | 0 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 5 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 1 element, +// kernel is using mask_1 which is set to -1, 0, 0, 0 static that the +// 1 element will be loaded and other 3 elements will be set to 0 in destination vector. +// + int64_t const *mask_vec = mask_1; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 4*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rdx, 1, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rdx, 2, 1*8)) // prefetch c + 4*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm11) + vfmadd231pd(ymm0, ymm11, ymm5) + vfmadd231pd(ymm1, ymm11, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r8, 1), ymm11) + vfmadd231pd(ymm0, ymm11, ymm9) + vfmadd231pd(ymm1, ymm11, ymm10) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm11) + vfmadd231pd(ymm0, ymm11, ymm5) + vfmadd231pd(ymm1, ymm11, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r8, 1), ymm11) + vfmadd231pd(ymm0, ymm11, ymm9) + vfmadd231pd(ymm1, ymm11, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. 
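As in the other kernels in this file, the k dimension is split into k_iter = k0/4 trips through the unrolled .DLOOPKITER body plus k_left = k0%4 trips through .DLOOPKLEFT. The control flow corresponds roughly to the following C skeleton (illustrative only):

    #include <stdint.h>

    /* Shape of the k loop: 4x-unrolled main loop plus single-step remainder. */
    static void k_loop_shape(uint64_t k0)
    {
        uint64_t k_iter = k0 / 4;               /* .DLOOPKITER trip count */
        uint64_t k_left = k0 % 4;               /* .DLOOPKLEFT trip count */

        for (uint64_t i = 0; i < k_iter; ++i)
        {
            /* iterations 0..3: four broadcast/FMA rank-1 updates */
        }
        for (uint64_t i = 0; i < k_left; ++i)
        {
            /* one broadcast/FMA rank-1 update */
        }
    }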
+ + vaddpd(ymm7, ymm3, ymm3) + vaddpd(ymm8, ymm4, ymm4) + vaddpd(ymm9, ymm5, ymm5) + vaddpd(ymm10, ymm6, ymm6) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm11) + vfmadd231pd(ymm0, ymm11, ymm5) + vfmadd231pd(ymm1, ymm11, ymm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 1), rdx) // load address of c + 1*rs_c; + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vfmadd231pd(mem(rdx, 0*32), ymm1, ymm5) + vmaskmovpd(mem(rdx, 1*32), ymm15, ymm2) + vfmadd231pd(ymm2, ymm1, ymm6) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + vmovupd(ymm5, mem(rdx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rdx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + C_TRANSPOSE_2x5_TILE(3, 5, 4, 6) + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. 
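The vxorpd/vucomisd pair in the epilogue above compares beta with zero so that the .DBETAZERO paths can store alpha*A*B without ever reading C. In C terms the store phase behaves roughly as follows (illustrative; in the kernel, alpha has already been applied before this branch):

    /* Sketch of the beta handling in the store epilogue for one n-wide row. */
    static void store_c_row(double *c_row, const double *alpha_ab, int n, double beta)
    {
        if (beta == 0.0)
        {
            /* .DROWSTORBZ / .DCOLSTORBZ: overwrite C without loading it. */
            for (int j = 0; j < n; ++j)
                c_row[j] = alpha_ab[j];
        }
        else
        {
            /* .DROWSTORED / .DCOLSTORED: load C, fold in beta, store.    */
            for (int j = 0; j < n; ++j)
                c_row[j] = beta * c_row[j] + alpha_ab[j];
        }
    }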
+ + + + label(.DCOLSTORBZ) + C_TRANSPOSE_2x5_TILE_BZ(3, 5, 4, 6) + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", + "ymm10", "ymm11", "ymm12", "ymm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x5 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | 0 | 0 | 0 | ----> Mask vector( mask_1 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 0 | 0 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 5 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 1 element, +// kernel is using mask_1 which is set to -1, 0, 0, 0 static that the +// 1 element will be loaded and other 3 elements will be set to 0 in destination vector. +// + int64_t const *mask_vec = mask_1; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rcx, rdx, 1, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rdx, 2, 0*8)) // prefetch c + 4*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm8) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm9) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm7) + vfmadd231pd(ymm8, ymm7, ymm5) + vfmadd231pd(ymm9, ymm7, ymm6) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm8) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm9) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm7) + vfmadd231pd(ymm8, ymm7, ymm5) + vfmadd231pd(ymm9, ymm7, ymm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + vaddpd(ymm5, ymm3, ymm3) + vaddpd(ymm6, ymm4, ymm4) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
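Each pass through the loops above performs one rank-1 update of the accumulators: a single element of A is broadcast (vbroadcastsd) and multiplied into the 5-wide row of B, 4 lanes coming from the plain load and 1 lane from the masked load. For this 1x5 kernel, one k step is equivalent to the scalar form below (illustrative only):

    /* One k-step of the 1x5 kernel, written out in scalar form. */
    static inline void rank1_step_1x5(const double *a_il, const double *b_row,
                                      double acc[5])
    {
        double a_val = *a_il;           /* vbroadcastsd(mem(rax), ymm2)       */
        for (int j = 0; j < 5; ++j)     /* vfmadd231pd into ymm3/ymm4 lanes   */
            acc[j] += a_val * b_row[j];
    }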
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 1 element as per mask_1 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + C_TRANSPOSE_1x5_TILE(3, 4) + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + C_TRANSPOSE_1x5_TILE_BZ(3, 4) + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm12", + "ymm15", + "memory" + ) +} diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx6.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx6.c index 9f80ef2f0d..e2bdeba8da 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx6.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx6.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,20 +40,20 @@ /* rrr: - -------- ------ -------- - -------- ------ -------- - -------- += ------ ... -------- - -------- ------ -------- - -------- ------ : - -------- ------ : + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : rcr: - -------- | | | | -------- - -------- | | | | -------- - -------- += | | | | ... -------- - -------- | | | | -------- - -------- | | | | : - -------- | | | | : + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : Assumptions: - B is row-stored; @@ -69,12 +69,12 @@ cost of the in-register transpose). crr: - | | | | | | | | ------ -------- - | | | | | | | | ------ -------- - | | | | | | | | += ------ ... -------- - | | | | | | | | ------ -------- - | | | | | | | | ------ : - | | | | | | | | ------ : + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : */ // Prototype reference microkernels. @@ -115,15 +115,15 @@ void bli_dgemmsup_rv_haswell_asm_6x6 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) - + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a @@ -180,18 +180,18 @@ void bli_dgemmsup_rv_haswell_asm_6x6 lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 @@ -208,14 +208,14 @@ void bli_dgemmsup_rv_haswell_asm_6x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -224,7 +224,7 @@ void bli_dgemmsup_rv_haswell_asm_6x6 vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + // ---------------------------------- iteration 1 #if 0 @@ -241,14 +241,14 @@ void bli_dgemmsup_rv_haswell_asm_6x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -256,8 +256,8 @@ void bli_dgemmsup_rv_haswell_asm_6x6 vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - + + // ---------------------------------- iteration 2 #if 1 @@ -274,14 +274,14 @@ void bli_dgemmsup_rv_haswell_asm_6x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -289,7 +289,7 @@ void bli_dgemmsup_rv_haswell_asm_6x6 vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + // ---------------------------------- iteration 3 @@ -307,14 +307,14 @@ void bli_dgemmsup_rv_haswell_asm_6x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -322,50 +322,50 @@ void bli_dgemmsup_rv_haswell_asm_6x6 vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
- - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), xmm1) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -373,22 +373,22 @@ void bli_dgemmsup_rv_haswell_asm_6x6 vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(xmm0, xmm5, xmm5) vmulpd(ymm0, ymm6, ymm6) @@ -401,24 +401,24 @@ void bli_dgemmsup_rv_haswell_asm_6x6 vmulpd(xmm0, xmm13, xmm13) vmulpd(ymm0, ymm14, ymm14) vmulpd(xmm0, xmm15, xmm15) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -427,60 +427,60 @@ void bli_dgemmsup_rv_haswell_asm_6x6 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vmovupd(ymm4, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) vmovupd(xmm5, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) vmovupd(ymm6, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) vmovupd(xmm7, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) vmovupd(ymm8, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) vmovupd(xmm9, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) vmovupd(ymm10, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm11) vmovupd(xmm11, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) vmovupd(ymm12, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm13) vmovupd(xmm13, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) vmovupd(ymm14, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm15) vmovupd(xmm15, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -555,51 +555,51 @@ void bli_dgemmsup_rv_haswell_asm_6x6 jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(xmm5, mem(rcx, 1*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) vmovupd(xmm7, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm8, mem(rcx, 0*32)) vmovupd(xmm9, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm10, mem(rcx, 0*32)) vmovupd(xmm11, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm12, mem(rcx, 0*32)) vmovupd(xmm13, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm14, mem(rcx, 0*32)) vmovupd(xmm15, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -656,12 +656,12 @@ void bli_dgemmsup_rv_haswell_asm_6x6 //lea(mem(rdx, rsi, 4), rdx) - - - + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -729,15 +729,15 @@ void bli_dgemmsup_rv_haswell_asm_5x6 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) - + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a @@ -793,18 +793,18 @@ void bli_dgemmsup_rv_haswell_asm_5x6 lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 @@ -821,20 +821,20 @@ void bli_dgemmsup_rv_haswell_asm_5x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) - + // ---------------------------------- iteration 1 #if 0 @@ -851,20 +851,20 @@ void bli_dgemmsup_rv_haswell_asm_5x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) - - + + // ---------------------------------- iteration 2 #if 1 @@ -881,19 +881,19 @@ void bli_dgemmsup_rv_haswell_asm_5x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) - + // ---------------------------------- iteration 3 @@ -911,82 +911,82 @@ void bli_dgemmsup_rv_haswell_asm_5x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, 
ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), xmm1) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(xmm0, xmm5, xmm5) vmulpd(ymm0, ymm6, ymm6) @@ -997,24 +997,24 @@ void bli_dgemmsup_rv_haswell_asm_5x6 vmulpd(xmm0, xmm11, xmm11) vmulpd(ymm0, ymm12, ymm12) vmulpd(xmm0, xmm13, xmm13) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -1023,52 +1023,46 @@ void bli_dgemmsup_rv_haswell_asm_5x6 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) - vmovupd(ymm4, mem(rcx, 0*32)) + lea(mem(rcx, rdi, 1), rax) // load address of c + 1*rs_c; + lea(mem(rcx, rdi, 2), rbx) // load address of c + 2*rs_c; + lea(mem(rbx, rdi, 1), r8) // load address of c + 3*rs_c; + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + + vfmadd231pd(mem(rax, 0*32), ymm3, ymm6) + vfmadd231pd(mem(rax, 1*32), xmm3, xmm7) + + vfmadd231pd(mem(rbx, 0*32), ymm3, ymm8) + vfmadd231pd(mem(rbx, 1*32), xmm3, xmm9) + + vfmadd231pd(mem(r8, 0*32), ymm3, ymm10) + vfmadd231pd(mem(r8, 1*32), xmm3, xmm11) + + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm12) + vfmadd231pd(mem(rdx, 1*32), xmm3, xmm13) + + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(xmm5, mem(rcx, 1*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) - vmovupd(ymm6, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) - vmovupd(xmm7, mem(rcx, 1*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) - vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm6, mem(rax, 0*32)) + vmovupd(xmm7, mem(rax, 1*32)) - vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) - vmovupd(xmm9, mem(rcx, 1*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) - vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm8, mem(rbx, 0*32)) + vmovupd(xmm9, mem(rbx, 1*32)) - vfmadd231pd(mem(rcx, 1*32), xmm3, xmm11) - vmovupd(xmm11, mem(rcx, 1*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) - vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm10, mem(r8, 0*32)) + vmovupd(xmm11, mem(r8, 1*32)) + + vmovupd(ymm12, mem(rdx, 0*32)) + vmovupd(xmm13, mem(rdx, 1*32)) - vfmadd231pd(mem(rcx, 1*32), xmm3, xmm13) - vmovupd(xmm13, mem(rcx, 1*32)) - //add(rdi, rcx) - - jmp(.DDONE) // jump to end. @@ -1141,46 +1135,46 @@ void bli_dgemmsup_rv_haswell_asm_5x6 jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(xmm5, mem(rcx, 1*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) vmovupd(xmm7, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm8, mem(rcx, 0*32)) vmovupd(xmm9, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm10, mem(rcx, 0*32)) vmovupd(xmm11, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm12, mem(rcx, 0*32)) vmovupd(xmm13, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -1234,12 +1228,12 @@ void bli_dgemmsup_rv_haswell_asm_5x6 //lea(mem(rdx, rsi, 4), rdx) - - - + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -1306,15 +1300,15 @@ void bli_dgemmsup_rv_haswell_asm_4x6 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) - + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a @@ -1370,17 +1364,17 @@ void bli_dgemmsup_rv_haswell_asm_4x6 #endif - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 @@ -1397,7 +1391,7 @@ void bli_dgemmsup_rv_haswell_asm_4x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; @@ -1406,7 +1400,7 @@ void bli_dgemmsup_rv_haswell_asm_4x6 vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + // ---------------------------------- iteration 1 #if 0 @@ -1423,7 +1417,7 @@ void bli_dgemmsup_rv_haswell_asm_4x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; @@ -1431,8 +1425,8 @@ void bli_dgemmsup_rv_haswell_asm_4x6 vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - - + + // ---------------------------------- iteration 2 #if 1 @@ -1449,7 +1443,7 @@ void bli_dgemmsup_rv_haswell_asm_4x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; @@ -1457,7 +1451,7 @@ void bli_dgemmsup_rv_haswell_asm_4x6 vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + // ---------------------------------- iteration 3 @@ -1475,7 +1469,7 @@ void bli_dgemmsup_rv_haswell_asm_4x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; @@ -1483,43 +1477,43 @@ void bli_dgemmsup_rv_haswell_asm_4x6 vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), xmm1) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; @@ -1527,22 +1521,22 @@ void bli_dgemmsup_rv_haswell_asm_4x6 vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. 
- - - + + + label(.DPOSTACCUM) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(xmm0, xmm5, xmm5) vmulpd(ymm0, ymm6, ymm6) @@ -1551,24 +1545,24 @@ void bli_dgemmsup_rv_haswell_asm_4x6 vmulpd(xmm0, xmm9, xmm9) vmulpd(ymm0, ymm10, ymm10) vmulpd(xmm0, xmm11, xmm11) - - - - - - - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; - //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; - lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -1577,44 +1571,40 @@ void bli_dgemmsup_rv_haswell_asm_4x6 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) - vmovupd(ymm4, mem(rcx, 0*32)) + lea(mem(rcx, rdi, 1), rax) // load address of c + 1*rs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + lea(mem(rdx, rdi, 1), rbx) // load address of c + 3*rs_c; + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + + vfmadd231pd(mem(rax, 0*32), ymm3, ymm6) + vfmadd231pd(mem(rax, 1*32), xmm3, xmm7) + + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm8) + vfmadd231pd(mem(rdx, 1*32), xmm3, xmm9) + + vfmadd231pd(mem(rbx, 0*32), ymm3, ymm10) + vfmadd231pd(mem(rbx, 1*32), xmm3, xmm11) + + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(xmm5, mem(rcx, 1*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) - vmovupd(ymm6, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) - vmovupd(xmm7, mem(rcx, 1*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) - vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm6, mem(rax, 0*32)) + vmovupd(xmm7, mem(rax, 1*32)) - vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) - vmovupd(xmm9, mem(rcx, 1*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) - vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm8, mem(rdx, 0*32)) + vmovupd(xmm9, mem(rdx, 1*32)) + + vmovupd(ymm10, mem(rbx, 0*32)) + vmovupd(xmm11, mem(rbx, 1*32)) - vfmadd231pd(mem(rcx, 1*32), xmm3, xmm11) - vmovupd(xmm11, mem(rcx, 1*32)) - //add(rdi, rcx) - - jmp(.DDONE) // jump to end. @@ -1663,41 +1653,41 @@ void bli_dgemmsup_rv_haswell_asm_4x6 jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(xmm5, mem(rcx, 1*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) vmovupd(xmm7, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm8, mem(rcx, 0*32)) vmovupd(xmm9, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm10, mem(rcx, 0*32)) vmovupd(xmm11, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. 
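The rewritten .DROWSTORED blocks in this file compute the addresses of the remaining C rows up front, issue the whole group of beta FMAs, and only then issue the stores, instead of the old per-row load/FMA/store/advance-rcx sequence. This appears intended to break the serial dependence on the single row pointer and to separate the C loads from the C stores; schematically (illustrative only):

    /* Sketch of the reordered row-stored C update for an m x n tile
       (acc already holds alpha*A*B). */
    static void update_c_rows(double *c, int rs_c, int m, int n,
                              double beta, const double *acc)
    {
        double tmp[6][8];                        /* covers the tiles in this file */

        /* Load every C row and fold in beta (the grouped vfmadd231pd block). */
        for (int i = 0; i < m; ++i)
            for (int j = 0; j < n; ++j)
                tmp[i][j] = beta * c[i * rs_c + j] + acc[i * n + j];

        /* Then store every updated row (the grouped vmovupd block).          */
        for (int i = 0; i < m; ++i)
            for (int j = 0; j < n; ++j)
                c[i * rs_c + j] = tmp[i][j];
    }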
@@ -1734,9 +1724,9 @@ void bli_dgemmsup_rv_haswell_asm_4x6 //lea(mem(rcx, rsi, 4), rcx) - - - + + + label(.DDONE) @@ -1806,9 +1796,9 @@ void bli_dgemmsup_rv_haswell_asm_3x6 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -1823,7 +1813,7 @@ void bli_dgemmsup_rv_haswell_asm_3x6 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -1861,31 +1851,31 @@ void bli_dgemmsup_rv_haswell_asm_3x6 prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c label(.DPOSTPFETCH) // done prefetching c - - + + #if 1 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 prefetch(0, mem(rdx, 4*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), xmm1) add(r10, rbx) // b += rs_b; @@ -1896,13 +1886,13 @@ void bli_dgemmsup_rv_haswell_asm_3x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) - - + + // ---------------------------------- iteration 1 #if 0 @@ -1915,19 +1905,19 @@ void bli_dgemmsup_rv_haswell_asm_3x6 vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) - + vfmadd231pd(ymm0, ymm2, ymm10) + vfmadd231pd(ymm1, ymm2, ymm11) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + vbroadcastsd(mem(rax, r8, 2), ymm2) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) - + vfmadd231pd(ymm0, ymm2, ymm14) + vfmadd231pd(ymm1, ymm2, ymm15) + // ---------------------------------- iteration 2 - + #if 1 prefetch(0, mem(rdx, r9, 2, 4*8)) #endif @@ -1942,12 +1932,12 @@ void bli_dgemmsup_rv_haswell_asm_3x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) - + // ---------------------------------- iteration 3 @@ -1961,41 +1951,44 @@ void bli_dgemmsup_rv_haswell_asm_3x6 vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) - + vfmadd231pd(ymm0, ymm2, ymm10) + vfmadd231pd(ymm1, ymm2, ymm11) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + vbroadcastsd(mem(rax, r8, 2), ymm2) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) - - - + vfmadd231pd(ymm0, ymm2, ymm14) + vfmadd231pd(ymm1, ymm2, ymm15) + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. 
- - - - - - + + vaddpd(ymm10, ymm4, ymm4) + vaddpd(ymm11, ymm5, ymm5) + vaddpd(ymm12, ymm6, ymm6) + vaddpd(ymm13, ymm7, ymm7) + vaddpd(ymm14, ymm8, ymm8) + vaddpd(ymm15, ymm9, ymm9) + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), xmm1) add(r10, rbx) // b += rs_b; @@ -2006,91 +1999,83 @@ void bli_dgemmsup_rv_haswell_asm_3x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(xmm0, xmm5, xmm5) vmulpd(ymm0, ymm6, ymm6) vmulpd(xmm0, xmm7, xmm7) vmulpd(ymm0, ymm8, ymm8) vmulpd(xmm0, xmm9, xmm9) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - label(.DROWSTORED) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) - vmovupd(ymm4, mem(rcx, 0*32)) - - vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) - vmovupd(xmm5, mem(rcx, 1*32)) - add(rdi, rcx) + label(.DROWSTORED) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) - vmovupd(ymm6, mem(rcx, 0*32)) + lea(mem(rcx, rdi, 1), rbx) // load address of c + 1*rs_c; - vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) - vmovupd(xmm7, mem(rcx, 1*32)) - add(rdi, rcx) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vfmadd231pd(mem(rbx, 0*32), ymm3, ymm6) + vfmadd231pd(mem(rbx, 1*32), xmm3, xmm7) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm8) + vfmadd231pd(mem(rdx, 1*32), xmm3, xmm9) + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + vmovupd(ymm6, mem(rbx, 0*32)) + vmovupd(xmm7, mem(rbx, 1*32)) + vmovupd(ymm8, mem(rdx, 0*32)) + vmovupd(xmm9, mem(rdx, 1*32)) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) - vmovupd(ymm8, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) - vmovupd(xmm9, mem(rcx, 1*32)) - //add(rdi, rcx) - - jmp(.DDONE) // jump to end. - + label(.DCOLSTORED) @@ -2131,7 +2116,7 @@ void bli_dgemmsup_rv_haswell_asm_3x6 vmovsd(xmm13, mem(rdx, rsi, 1)) vmovsd(xmm14, mem(rdx, rsi, 2)) vmovsd(xmm15, mem(rdx, rax, 1)) - + lea(mem(rdx, rsi, 4), rdx) // begin I/O on columns 4-5 @@ -2162,26 +2147,26 @@ void bli_dgemmsup_rv_haswell_asm_3x6 vfmadd231sd(mem(rdx, rsi, 1), xmm3, xmm13) vmovsd(xmm12, mem(rdx )) vmovsd(xmm13, mem(rdx, rsi, 1)) - + //lea(mem(rdx, rsi, 4), rdx) jmp(.DDONE) // jump to end. - - + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(xmm5, mem(rcx, 1*32)) add(rdi, rcx) @@ -2193,8 +2178,8 @@ void bli_dgemmsup_rv_haswell_asm_3x6 vmovupd(ymm8, mem(rcx, 0*32)) vmovupd(xmm9, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -2251,12 +2236,12 @@ void bli_dgemmsup_rv_haswell_asm_3x6 //lea(mem(rdx, rsi, 4), rdx) - - - + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -2285,7 +2270,7 @@ void bli_dgemmsup_rv_haswell_asm_3x6 "xmm12", "xmm13", "xmm14", "xmm15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", - "memory" + "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -2323,15 +2308,15 @@ void bli_dgemmsup_rv_haswell_asm_2x6 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) - + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a @@ -2385,17 +2370,17 @@ void bli_dgemmsup_rv_haswell_asm_2x6 #endif - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 @@ -2414,7 +2399,7 @@ void bli_dgemmsup_rv_haswell_asm_2x6 vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + // ---------------------------------- iteration 1 #if 0 @@ -2425,15 +2410,15 @@ void bli_dgemmsup_rv_haswell_asm_2x6 vmovupd(mem(rbx, 1*32), xmm1) add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) - vbroadcastsd(mem(rax, r8, 1), ymm3) + vbroadcastsd(mem(rax ), ymm12) + vbroadcastsd(mem(rax, r8, 1), ymm13) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) - - + vfmadd231pd(ymm0, ymm12, ymm8) + vfmadd231pd(ymm1, ymm12, ymm9) + vfmadd231pd(ymm0, ymm13, ymm10) + vfmadd231pd(ymm1, ymm13, ymm11) + + // ---------------------------------- iteration 2 #if 1 @@ -2451,7 +2436,7 @@ void bli_dgemmsup_rv_haswell_asm_2x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + // ---------------------------------- iteration 3 @@ -2463,43 +2448,44 @@ void bli_dgemmsup_rv_haswell_asm_2x6 vmovupd(mem(rbx, 1*32), xmm1) add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) - vbroadcastsd(mem(rax, r8, 1), ymm3) + vbroadcastsd(mem(rax ), ymm12) + vbroadcastsd(mem(rax, r8, 1), ymm13) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) - - - + vfmadd231pd(ymm0, ymm12, ymm8) + vfmadd231pd(ymm1, ymm12, ymm9) + vfmadd231pd(ymm0, ymm13, ymm10) + vfmadd231pd(ymm1, ymm13, ymm11) + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + vaddpd(ymm8, ymm4, ymm4) + vaddpd(ymm9, ymm5, ymm5) + vaddpd(ymm10, ymm6, ymm6) + vaddpd(ymm11, ymm7, ymm7) + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
- - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), xmm1) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) add(r9, rax) // a += cs_a; @@ -2507,44 +2493,44 @@ void bli_dgemmsup_rv_haswell_asm_2x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(xmm0, xmm5, xmm5) vmulpd(ymm0, ymm6, ymm6) vmulpd(xmm0, xmm7, xmm7) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -2553,28 +2539,24 @@ void bli_dgemmsup_rv_haswell_asm_2x6 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) - vmovupd(ymm4, mem(rcx, 0*32)) + lea(mem(rcx, rdi, 1), rdx) // load address of c + 1*rs_c; + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm6) + vfmadd231pd(mem(rdx, 1*32), xmm3, xmm7) + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(xmm5, mem(rcx, 1*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) - vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm6, mem(rdx, 0*32)) + vmovupd(xmm7, mem(rdx, 1*32)) + - vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) - vmovupd(xmm7, mem(rcx, 1*32)) - //add(rdi, rcx) - - jmp(.DDONE) // jump to end. @@ -2623,31 +2605,31 @@ void bli_dgemmsup_rv_haswell_asm_2x6 jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(xmm5, mem(rcx, 1*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) vmovupd(xmm7, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -2684,9 +2666,9 @@ void bli_dgemmsup_rv_haswell_asm_2x6 //lea(mem(rcx, rsi, 4), rcx) - - - + + + label(.DDONE) @@ -2718,7 +2700,7 @@ void bli_dgemmsup_rv_haswell_asm_2x6 "xmm12", "xmm13", "xmm14", "xmm15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", - "memory" + "ymm12", "ymm13", "memory" ) } @@ -2756,15 +2738,15 @@ void bli_dgemmsup_rv_haswell_asm_1x6 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. 
mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) - + //lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a @@ -2817,17 +2799,17 @@ void bli_dgemmsup_rv_haswell_asm_1x6 #endif - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 @@ -2843,7 +2825,7 @@ void bli_dgemmsup_rv_haswell_asm_1x6 vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) - + // ---------------------------------- iteration 1 #if 0 @@ -2854,12 +2836,12 @@ void bli_dgemmsup_rv_haswell_asm_1x6 vmovupd(mem(rbx, 1*32), xmm1) add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax ), ymm3) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - - + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + // ---------------------------------- iteration 2 #if 1 @@ -2874,7 +2856,7 @@ void bli_dgemmsup_rv_haswell_asm_1x6 add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) - + // ---------------------------------- iteration 3 @@ -2886,80 +2868,79 @@ void bli_dgemmsup_rv_haswell_asm_1x6 vmovupd(mem(rbx, 1*32), xmm1) add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax ), ymm3) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - - - + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + vaddpd(ymm6, ymm4, ymm4) + vaddpd(ymm7, ymm5, ymm5) + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), xmm1) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(xmm0, xmm5, xmm5) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -2968,20 +2949,20 @@ void bli_dgemmsup_rv_haswell_asm_1x6 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vmovupd(ymm4, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) vmovupd(xmm5, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -3018,26 +2999,26 @@ void bli_dgemmsup_rv_haswell_asm_1x6 jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(xmm5, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -3063,9 +3044,9 @@ void bli_dgemmsup_rv_haswell_asm_1x6 //lea(mem(rcx, rsi, 4), rcx) - - - + + + label(.DDONE) @@ -3096,7 +3077,7 @@ void bli_dgemmsup_rv_haswell_asm_1x6 "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", - "memory" + "ymm6", "ymm7", "memory" ) } diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx7.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx7.c new file mode 100644 index 0000000000..be22b32b41 --- /dev/null +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx7.c @@ -0,0 +1,2584 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +//3, 5, 7, 9, 11, 13, 4, 6, 8, 10, 12, 14 +#define C_TRANSPOSE_5x7_TILE(R1, R2, R3, R4, R5, R6, R7, R8, R9, R10) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + /*Broadcasting Beta into ymm15 vector register*/\ + vbroadcastsd(mem(rbx), ymm15)\ +\ + /*Scaling C matrix by Beta and adding it to fma result.*/ \ + /*R1, R2, R3, R4 holds final result*/ \ + vfmadd231pd(mem(rcx ), ymm15, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), ymm15, ymm(R4))\ + /*Storing it back to C matrix.*/ \ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + /*Moving to operate on last 1 row of 5 rows.*/ \ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 1x4 tile*/ \ + vmovlpd(mem(rdx ), xmm0, xmm0)\ + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0)\ + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1)\ + vmovhpd(mem(rdx, rax, 1), xmm1, xmm1)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + vfmadd213pd(ymm(R5), ymm15, ymm0)\ + vextractf128(imm(1), ymm0, xmm1)\ + vmovlpd(xmm0, mem(rdx ))\ + vmovhpd(xmm0, mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2))\ + vmovhpd(xmm1, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R7), ymm(R6), ymm0)\ + vunpckhpd(ymm(R7), ymm(R6), ymm1)\ + vunpcklpd(ymm(R9), ymm(R8), ymm2)\ + vunpckhpd(ymm(R9), ymm(R8), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R6))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R7))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R8))\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm(R6))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R7))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R8))\ + vmovupd(ymm(R6), mem(rcx ))\ + vmovupd(ymm(R7), mem(rcx, rsi, 1))\ + vmovupd(ymm(R8), mem(rcx, rsi, 2))\ +\ + vmovlpd(mem(rdx ), xmm0, xmm0)\ + vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0)\ + vmovlpd(mem(rdx, rsi, 2), xmm1, xmm1)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + /*Transposing 1x3 tile*/ \ + vfmadd213pd(ymm(R10), ymm15, ymm0)\ + vextractf128(imm(1), ymm0, xmm1)\ + vmovlpd(xmm0, mem(rdx ))\ + vmovhpd(xmm0, mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_5x7_TILE_BZ(R1, R2, R3, R4, R5, R6, R7, R8, R9, R10) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 1x4 tile*/ \ + vextractf128(imm(1), ymm(R5), xmm1)\ + vmovlpd(xmm(R5), mem(rdx ))\ + vmovhpd(xmm(R5), mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2))\ + vmovhpd(xmm1, mem(rdx, rax, 1))\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R7), 
ymm(R6), ymm0)\ + vunpckhpd(ymm(R7), ymm(R6), ymm1)\ + vunpcklpd(ymm(R9), ymm(R8), ymm2)\ + vunpckhpd(ymm(R9), ymm(R8), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R6))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R7))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R8))\ +\ + vmovupd(ymm(R6), mem(rcx ))\ + vmovupd(ymm(R7), mem(rcx, rsi, 1))\ + vmovupd(ymm(R8), mem(rcx, rsi, 2))\ +\ + /*Transposing 1x3 tile*/ \ + vextractf128(imm(1), ymm(R10), xmm1)\ + vmovlpd(xmm(R10), mem(rdx ))\ + vmovhpd(xmm(R10), mem(rdx, rsi, 1))\ + vmovlpd(xmm1, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_4x7_TILE(R1, R2, R3, R4, R5, R6, R7, R8) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + vbroadcastsd(mem(rbx), ymm15)\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), ymm15, ymm(R4))\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vunpcklpd(ymm(R8), ymm(R7), ymm2)\ + vunpckhpd(ymm(R8), ymm(R7), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R5))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R6))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R7))\ +\ + vfmadd231pd(mem(rcx ), ymm15, ymm(R5))\ + vfmadd231pd(mem(rcx, rsi, 1), ymm15, ymm(R6))\ + vfmadd231pd(mem(rcx, rsi, 2), ymm15, ymm(R7))\ + vmovupd(ymm(R5), mem(rcx ))\ + vmovupd(ymm(R6), mem(rcx, rsi, 1))\ + vmovupd(ymm(R7), mem(rcx, rsi, 2)) + +#define C_TRANSPOSE_4x7_TILE_BZ(R1, R2, R3, R4, R5, R6, R7, R8) \ + /*Transposing 4x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm(R4), ymm(R3), ymm2)\ + vunpckhpd(ymm(R4), ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm(R4))\ +\ + vmovupd(ymm(R1), mem(rcx ))\ + vmovupd(ymm(R2), mem(rcx, rsi, 1))\ + vmovupd(ymm(R3), mem(rcx, rsi, 2))\ + vmovupd(ymm(R4), mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 4x3 tile*/ \ + vunpcklpd(ymm(R6), ymm(R5), ymm0)\ + vunpckhpd(ymm(R6), ymm(R5), ymm1)\ + vunpcklpd(ymm(R8), ymm(R7), ymm2)\ + vunpckhpd(ymm(R8), ymm(R7), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R5))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R6))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R7))\ +\ + vmovupd(ymm(R5), mem(rcx ))\ + vmovupd(ymm(R6), mem(rcx, rsi, 1))\ + vmovupd(ymm(R7), mem(rcx, rsi, 2)) + +//3, 5, 7, 4, 6, 8 +#define C_TRANSPOSE_3x7_TILE(R1, R2, R3, R4, R5, R6) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm10, ymm(R3), ymm2)\ + vunpckhpd(ymm10, ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm10)\ +\ + /*Transposing 1x4 tile*/ \ + vextractf128(imm(0x1), 
ymm(R1), xmm12)\ + vextractf128(imm(0x1), ymm(R2), xmm13)\ + vextractf128(imm(0x1), ymm(R3), xmm14)\ + vextractf128(imm(0x1), ymm10, xmm15)\ +\ + vbroadcastsd(mem(rbx), ymm11)\ +\ + vfmadd231pd(mem(rcx ), xmm11, xmm(R1))\ + vfmadd231pd(mem(rcx, rsi, 1), xmm11, xmm(R2))\ + vfmadd231pd(mem(rcx, rsi, 2), xmm11, xmm(R3))\ + vfmadd231pd(mem(rcx, rax, 1), xmm11, xmm10)\ + vmovupd(xmm(R1), mem(rcx ))\ + vmovupd(xmm(R2), mem(rcx, rsi, 1))\ + vmovupd(xmm(R3), mem(rcx, rsi, 2))\ + vmovupd(xmm10, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + vfmadd231sd(mem(rdx ), xmm11, xmm12)\ + vfmadd231sd(mem(rdx, rsi, 1), xmm11, xmm13)\ + vfmadd231sd(mem(rdx, rsi, 2), xmm11, xmm14)\ + vfmadd231sd(mem(rdx, rax, 1), xmm11, xmm15)\ + vmovsd(xmm12, mem(rdx ))\ + vmovsd(xmm13, mem(rdx, rsi, 1))\ + vmovsd(xmm14, mem(rdx, rsi, 2))\ + vmovsd(xmm15, mem(rdx, rax, 1))\ + \ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R5), ymm(R4), ymm0)\ + vunpckhpd(ymm(R5), ymm(R4), ymm1)\ + vunpcklpd(ymm11, ymm(R6), ymm2)\ + vunpckhpd(ymm11, ymm(R6), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R4))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R5))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R6))\ +\ + /*Transposing 1x3 tile*/ \ + vextractf128(imm(0x1), ymm(R4), xmm12)\ + vextractf128(imm(0x1), ymm(R5), xmm13)\ + vextractf128(imm(0x1), ymm(R6), xmm14)\ +\ + vfmadd231pd(mem(rcx ), xmm11, xmm(R4))\ + vfmadd231pd(mem(rcx, rsi, 1), xmm11, xmm(R5))\ + vfmadd231pd(mem(rcx, rsi, 2), xmm11, xmm(R6))\ + vmovupd(xmm(R4), mem(rcx ))\ + vmovupd(xmm(R5), mem(rcx, rsi, 1))\ + vmovupd(xmm(R6), mem(rcx, rsi, 2))\ +\ + vfmadd231sd(mem(rdx ), xmm11, xmm12)\ + vfmadd231sd(mem(rdx, rsi, 1), xmm11, xmm13)\ + vfmadd231sd(mem(rdx, rsi, 2), xmm11, xmm14)\ + vmovsd(xmm12, mem(rdx ))\ + vmovsd(xmm13, mem(rdx, rsi, 1))\ + vmovsd(xmm14, mem(rdx, rsi, 2)) + +#define C_TRANSPOSE_3x7_TILE_BZ(R1, R2, R3, R4, R5, R6) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vunpcklpd(ymm10, ymm(R3), ymm2)\ + vunpckhpd(ymm10, ymm(R3), ymm15)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R1))\ + vinsertf128(imm(0x1), xmm15, ymm1, ymm(R2))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R3))\ + vperm2f128(imm(0x31), ymm15, ymm1, ymm10)\ +\ + /*Transposing 1x4 tile*/ \ + vextractf128(imm(0x1), ymm(R1), xmm12)\ + vextractf128(imm(0x1), ymm(R2), xmm13)\ + vextractf128(imm(0x1), ymm(R3), xmm14)\ + vextractf128(imm(0x1), ymm10, xmm15)\ +\ + vmovupd(xmm(R1), mem(rcx ))\ + vmovupd(xmm(R2), mem(rcx, rsi, 1))\ + vmovupd(xmm(R3), mem(rcx, rsi, 2))\ + vmovupd(xmm10, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ + vmovsd(xmm12, mem(rdx ))\ + vmovsd(xmm13, mem(rdx, rsi, 1))\ + vmovsd(xmm14, mem(rdx, rsi, 2))\ + vmovsd(xmm15, mem(rdx, rax, 1))\ + \ + lea(mem(rdx, rsi, 4), rdx)\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R5), ymm(R4), ymm0)\ + vunpckhpd(ymm(R5), ymm(R4), ymm1)\ + vunpcklpd(ymm11, ymm(R6), ymm2)\ + vunpckhpd(ymm11, ymm(R6), ymm3)\ + vinsertf128(imm(0x1), xmm2, ymm0, ymm(R4))\ + vinsertf128(imm(0x1), xmm3, ymm1, ymm(R5))\ + vperm2f128(imm(0x31), ymm2, ymm0, ymm(R6))\ +\ + /*Transposing 1x3 tile*/ \ + vextractf128(imm(0x1), ymm(R4), xmm12)\ + vextractf128(imm(0x1), ymm(R5), xmm13)\ + vextractf128(imm(0x1), ymm(R6), xmm14)\ +\ + vmovupd(xmm(R4), mem(rcx ))\ + vmovupd(xmm(R5), mem(rcx, rsi, 1))\ + vmovupd(xmm(R6), mem(rcx, rsi, 2))\ +\ + vmovsd(xmm12, mem(rdx ))\ + vmovsd(xmm13, mem(rdx, rsi, 1))\ + vmovsd(xmm14, mem(rdx, rsi, 2)) + +//3, 5, 4, 6 +#define C_TRANSPOSE_2x7_TILE(R1, R2, R3, 
R4) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm7)\ +\ + vbroadcastsd(mem(rbx), ymm3)\ + vfmadd231pd(mem(rcx ), xmm3, xmm0)\ + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1)\ + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2)\ + vfmadd231pd(mem(rcx, rax, 1), xmm3, xmm7)\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, mem(rcx, rsi, 2))\ + vmovupd(xmm7, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R4), ymm(R3), ymm0)\ + vunpckhpd(ymm(R4), ymm(R3), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ +\ + vfmadd231pd(mem(rcx ), xmm3, xmm0)\ + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1)\ + vfmadd231pd(mem(rcx, rsi, 2), xmm3, xmm2)\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, mem(rcx, rsi, 2)) + + +#define C_TRANSPOSE_2x7_TILE_BZ(R1, R2, R3, R4) \ + /*Transposing 2x4 tile*/ \ + vunpcklpd(ymm(R2), ymm(R1), ymm0)\ + vunpckhpd(ymm(R2), ymm(R1), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ + vextractf128(imm(0x1), ymm1, xmm7)\ +\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, mem(rcx, rsi, 2))\ + vmovupd(xmm7, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + /*Transposing 2x3 tile*/ \ + vunpcklpd(ymm(R4), ymm(R3), ymm0)\ + vunpckhpd(ymm(R4), ymm(R3), ymm1)\ + vextractf128(imm(0x1), ymm0, xmm2)\ +\ + vmovupd(xmm0, mem(rcx ))\ + vmovupd(xmm1, mem(rcx, rsi, 1))\ + vmovupd(xmm2, mem(rcx, rsi, 2)) + + +#define C_TRANSPOSE_1x7_TILE(R1, R2) \ + vmovlpd(mem(rcx ), xmm0, xmm0)\ + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0)\ + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1)\ + vmovhpd(mem(rcx, rax, 1), xmm1, xmm1)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + vbroadcastsd(mem(rbx), ymm15)\ + vfmadd213pd(ymm(R1), ymm15, ymm0)\ +\ + vextractf128(imm(1), ymm0, xmm1)\ + vmovlpd(xmm0, mem(rcx ))\ + vmovhpd(xmm0, mem(rcx, rsi, 1))\ + vmovlpd(xmm1, mem(rcx, rsi, 2))\ + vmovhpd(xmm1, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ +\ + vmovlpd(mem(rcx ), xmm0, xmm0)\ + vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0)\ + vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1)\ + vperm2f128(imm(0x20), ymm1, ymm0, ymm0)\ +\ + vfmadd213pd(ymm(R2), ymm15, ymm0)\ +\ + vextractf128(imm(1), ymm0, xmm1)\ + vmovlpd(xmm0, mem(rcx ))\ + vmovhpd(xmm0, mem(rcx, rsi, 1))\ + vmovlpd(xmm1, mem(rcx, rsi, 2)) + + +#define C_TRANSPOSE_1x7_TILE_BZ(R1, R2) \ + vextractf128(imm(1), ymm(R1), xmm1)\ + vmovlpd(xmm(R1), mem(rcx ))\ + vmovhpd(xmm(R1), mem(rcx, rsi, 1))\ + vmovlpd(xmm1, mem(rcx, rsi, 2))\ + vmovhpd(xmm1, mem(rcx, rax, 1))\ +\ + lea(mem(rcx, rsi, 4), rcx)\ + vextractf128(imm(1), ymm(R2), xmm1)\ + vmovlpd(xmm(R2), mem(rcx ))\ + vmovhpd(xmm(R2), mem(rcx, rsi, 1))\ + vmovlpd(xmm1, mem(rcx, rsi, 2)) + +static const int64_t mask_3[4] = {-1, -1, -1, 0}; + +void bli_dgemmsup_rv_haswell_asm_5x7 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// 
int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 7 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 3 elements, +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 6*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 6*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 6*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 6*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 6*8)) // prefetch c + 4*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 4*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 4*8)) // prefetch c + 6*cs_c + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
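+	/*
+	   Illustrative sketch only (not part of the kernel; a hypothetical
+	   pointer b_row is assumed): the masked 7-element load of a B row
+	   used throughout this kernel is roughly equivalent to the AVX2
+	   intrinsics
+
+	       __m256i m  = _mm256_setr_epi64x(-1, -1, -1, 0);  // mask_3
+	       __m256d lo = _mm256_loadu_pd(b_row + 0);         // columns 0..3
+	       __m256d hi = _mm256_maskload_pd(b_row + 4, m);   // columns 4..6, lane 3 zeroed
+
+	   ymm15 below plays the role of m, so vmaskmovpd(mem(rbx, 1*32),
+	   ymm15, ymm1) reads only the last three columns and suppresses the
+	   access to the unused fourth lane.
+	*/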
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. 
+ je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 1), rax) // load address of c + 1*rs_c; + lea(mem(rcx, rdi, 2), rbx) // load address of c + 2*rs_c; + lea(mem(rbx, rdi, 1), r8) // load address of c + 3*rs_c; + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vfmadd231pd(mem(rax, 0*32), ymm1, ymm5) + vmaskmovpd(mem(rax, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm6) + + vfmadd231pd(mem(rbx, 0*32), ymm1, ymm7) + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm8) + + vfmadd231pd(mem(r8, 0*32), ymm1, ymm9) + vmaskmovpd(mem(r8, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm10) + + vfmadd231pd(mem(rdx, 0*32), ymm1, ymm11) + vmaskmovpd(mem(rdx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm12) + + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + vmovupd(ymm5, mem(rax, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rax, 1*32)) + + vmovupd(ymm7, mem(rbx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rbx, 1*32)) + + vmovupd(ymm9, mem(r8, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(r8, 1*32)) + + vmovupd(ymm11, mem(rdx, 0*32)) + vmaskmovpd(ymm12, ymm15, mem(rdx, 1*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_5x7_TILE(3, 5, 7, 9, 11, 4, 6, 8, 10, 12) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------2 + + vmovupd(ymm7, mem(rcx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------3 + + vmovupd(ymm9, mem(rcx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------4 + + vmovupd(ymm11, mem(rcx, 0*32)) + vmaskmovpd(ymm12, ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_5x7_TILE_BZ(3, 5, 7, 9, 11, 4, 6, 8, 10, 12) + jmp(.DDONE) // jump to end. + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", + "ymm10", "ymm11", "ymm12", "ymm15", + "memory" + ) +} + + +void bli_dgemmsup_rv_haswell_asm_4x7 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 7 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 3 elements, +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. 
+ mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 6*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 6*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 6*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 6*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 6*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
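+	/*
+	   Each iteration of the loop below is one rank-1 update of the 4x7
+	   tile of C: the four A elements at rax, rax + r8, rax + 2*r8 and
+	   rax + r13 (rows 0..3) are broadcast and multiplied into the 4+3
+	   elements loaded from the current B row, accumulating into
+	   ymm3..ymm10 (one ymm pair per row of C: columns 0..3 and the
+	   masked columns 4..6).
+	*/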
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
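+	/*
+	   The edge loop below applies the k0 % 4 leftover rank-1 updates one
+	   at a time, using the same vmovupd + vmaskmovpd load pattern as the
+	   main loop.
+	*/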
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + vbroadcastsd(mem(rax, r13, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 1), rax) // load address of c + 1*rs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + lea(mem(rdx, rdi, 1), rbx) // load address of c + 3*rs_c; + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vfmadd231pd(mem(rax, 0*32), ymm1, ymm5) + vmaskmovpd(mem(rax, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm6) + + vfmadd231pd(mem(rdx, 0*32), ymm1, ymm7) + vmaskmovpd(mem(rdx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm8) + + vfmadd231pd(mem(rbx, 0*32), ymm1, ymm9) + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm10) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + vmovupd(ymm5, mem(rax, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rax, 1*32)) + + vmovupd(ymm7, mem(rdx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rdx, 1*32)) + + vmovupd(ymm9, mem(rbx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rbx, 1*32)) + //-----------------------4 + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_4x7_TILE(3, 5, 7, 9, 4, 6, 8, 10) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------2 + + vmovupd(ymm7, mem(rcx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------3 + vmovupd(ymm9, mem(rcx, 0*32)) + vmaskmovpd(ymm10, ymm15, mem(rcx, 1*32)) + //-----------------------4 + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_4x7_TILE_BZ(3, 5, 7, 9, 4, 6, 8, 10) + jmp(.DDONE) // jump to end. + + label(.DDONE) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", + "ymm10", "ymm11", "ymm12", "ymm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_3x7 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 7 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 3 elements, +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. 
+ mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 6*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 6*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 6*8)) // prefetch c + 2*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 2*8)) // prefetch c + 6*cs_c + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
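+	/*
+	   In the loop below, iterations 0 and 2 accumulate into ymm3..ymm8
+	   while iterations 1 and 3 accumulate into ymm9..ymm14, presumably to
+	   shorten the FMA dependency chains; the two sets are summed with
+	   vaddpd once the loop exits. Each B row is read with one vmovupd
+	   (columns 0..3) plus one vmaskmovpd through the ymm15 mask
+	   (columns 4..6).
+	*/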
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm9) + vfmadd231pd(ymm1, ymm2, ymm10) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm11) + vfmadd231pd(ymm1, ymm2, ymm12) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm13) + vfmadd231pd(ymm1, ymm2, ymm14) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + vaddpd(ymm9, ymm3, ymm3) + vaddpd(ymm10, ymm4, ymm4) + vaddpd(ymm11, ymm5, ymm5) + vaddpd(ymm12, ymm6, ymm6) + vaddpd(ymm13, ymm7, ymm7) + vaddpd(ymm14, ymm8, ymm8) + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm5) + vfmadd231pd(ymm1, ymm2, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm7) + vfmadd231pd(ymm1, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 1), rbx) // load address of c + 1*rs_c; + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vfmadd231pd(mem(rbx, 0*32), ymm1, ymm5) + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm2) + vfmadd231pd(ymm2, ymm1, ymm6) + + vfmadd231pd(mem(rdx, 0*32), ymm1, ymm7) + vmaskmovpd(mem(rdx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm8) + + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + vmovupd(ymm5, mem(rbx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rbx, 1*32)) + + vmovupd(ymm7, mem(rdx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rdx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_3x7_TILE(3, 5, 7, 4, 6, 8) + jmp(.DDONE) // jump to end. + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------2 + vmovupd(ymm7, mem(rcx, 0*32)) + vmaskmovpd(ymm8, ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_3x7_TILE_BZ(3, 5, 7, 4, 6, 8) + jmp(.DDONE) // jump to end. 
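For reference, the row-stored update above folds the load of the first four columns of C into the memory operand of vfmadd231pd, loads the remaining three columns with vmaskmovpd, scales by beta (ymm1), and writes back with vmovupd plus vmaskmovpd; the .DBETAZERO branch stores the alpha-scaled results without reading C at all. A minimal intrinsics sketch of one such row update follows; the helper name update_row_of_7 is hypothetical and the mask constant mirrors the kernel's mask_3.

#include <immintrin.h>

// Illustrative sketch (not kernel code): C(i,0:6) = alpha*ab + beta*C(i,0:6)
// using the same 4 + 3 masked split as the store path above. When beta == 0,
// C is written without being read, like the .DBETAZERO branch.
static inline void update_row_of_7( double* c, __m256d ab_lo, __m256d ab_hi,
                                    double alpha, double beta )
{
    const __m256i mask_3 = _mm256_set_epi64x( 0, -1, -1, -1 );
    const __m256d va     = _mm256_set1_pd( alpha );
    const __m256d vb     = _mm256_set1_pd( beta );

    __m256d lo = _mm256_mul_pd( va, ab_lo );
    __m256d hi = _mm256_mul_pd( va, ab_hi );

    if ( beta != 0.0 )
    {
        lo = _mm256_fmadd_pd( _mm256_loadu_pd( c + 0 ), vb, lo );
        hi = _mm256_fmadd_pd( _mm256_maskload_pd( c + 4, mask_3 ), vb, hi );
    }

    _mm256_storeu_pd( c + 0, lo );
    _mm256_maskstore_pd( c + 4, mask_3, hi );
}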
+ + label(.DDONE) + vzeroupper() + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", + "ymm10", "ymm11", "ymm12", "ymm13", + "ymm14", "ymm15", "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_2x7 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 7 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 3 elements, +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 6*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 6*8)) // prefetch c + 1*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 3*8)) // prefetch c + 6*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm14) + vfmadd231pd(ymm0, ymm14, ymm5) + vfmadd231pd(ymm1, ymm14, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm10) + vfmadd231pd(ymm1, ymm2, ymm11) + + vbroadcastsd(mem(rax, r8, 1), ymm14) + vfmadd231pd(ymm0, ymm14, ymm12) + vfmadd231pd(ymm1, ymm14, ymm13) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm14) + vfmadd231pd(ymm0, ymm14, ymm5) + vfmadd231pd(ymm1, ymm14, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm10) + vfmadd231pd(ymm1, ymm2, ymm11) + + vbroadcastsd(mem(rax, r8, 1), ymm14) + vfmadd231pd(ymm0, ymm14, ymm12) + vfmadd231pd(ymm1, ymm14, ymm13) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. 
+ + vaddpd(ymm10, ymm3, ymm3) + vaddpd(ymm11, ymm4, ymm4) + vaddpd(ymm12, ymm5, ymm5) + vaddpd(ymm13, ymm6, ymm6) + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + vbroadcastsd(mem(rax, r8, 1), ymm14) + vfmadd231pd(ymm0, ymm14, ymm5) + vfmadd231pd(ymm1, ymm14, ymm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + lea(mem(rcx, rdi, 1), rdx) // load address of c + 1*rs_c; + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vfmadd231pd(mem(rdx, 0*32), ymm1, ymm5) + vmaskmovpd(mem(rdx, 1*32), ymm15, ymm2) + vfmadd231pd(ymm2, ymm1, ymm6) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + vmovupd(ymm5, mem(rdx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rdx, 1*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_2x7_TILE(3, 5, 4, 6) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + add(rdi, rcx) + //-----------------------1 + + vmovupd(ymm5, mem(rcx, 0*32)) + vmaskmovpd(ymm6, ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_2x7_TILE_BZ(3, 5, 4, 6) + jmp(.DDONE) // jump to end. 
+ + label(.DDONE) + vzeroupper() + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm5", "ymm6", "ymm8", "ymm10","ymm11", + "ymm12", "ymm13", "ymm14", "ymm15", "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_1x7 + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + +// Sets up the mask for loading relevant remainder elements in load direction +// int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. +// +// Low end High end +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 4 | ----> Source vector +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) +// |_____|_____|_____|_____| +// +// ________________________ +// | | | | | +// | 1 | 2 | 3 | 0 | ----> Destination vector +// |_____|_____|_____|_____| +// +// Since we have 7 elements to load, kernel will use one normal load +// that loads 4 elements into vector register and for remainder 3 elements, +// kernel is using mask_3 which is set to -1, -1, -1, 0 so that the +// 3 elements will be loaded and 4th element will be set to 0 in destination vector. + int64_t const *mask_vec = mask_3; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm15) //load + mov(var(a), rax) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), rcx) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(rcx, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 6*8)) // prefetch c + 0*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; + prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 0*8)) // prefetch c + 6*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + lea(mem(rax, r9, 8), rdx) // + lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 1 + prefetch(0, mem(rdx, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, r9, 1, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm8) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm9) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm10) + vfmadd231pd(ymm8, ymm10, ymm6) + vfmadd231pd(ymm9, ymm10, ymm7) + + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + +#if 1 + prefetch(0, mem(rdx, r9, 2, 4*8)) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + +#if 1 + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm8) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm9) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm10) + vfmadd231pd(ymm8, ymm10, ymm6) + vfmadd231pd(ymm9, ymm10, ymm7) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + vaddpd(ymm6, ymm3, ymm3) + vaddpd(ymm7, ymm4, ymm4) + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + +#if 0 + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) +#endif + + //Loads 4 element + vmovupd(mem(rbx, 0*32), ymm0) + //Loads 3 elements as per mask_3 mask vector + vmaskmovpd(mem(rbx, 1*32), ymm15, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm3) + vfmadd231pd(ymm1, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm1) // load beta and duplicate + + vmulpd(ymm0, ymm3, ymm3) // scale by alpha + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm1, ymm3) + vmaskmovpd(mem(rcx, 1*32), ymm15, ymm0) + vfmadd231pd(ymm0, ymm1, ymm4) + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + C_TRANSPOSE_1x7_TILE(3, 4) + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm3, mem(rcx, 0*32)) + vmaskmovpd(ymm4, ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + C_TRANSPOSE_1x7_TILE_BZ(3, 4) + jmp(.DDONE) // jump to end. + + label(.DDONE) + vzeroupper() + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [n0] "m" (n0), + [rs_c] "m" (rs_c), + [mask_vec] "m" (mask_vec), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", + "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", + "ymm12", "ymm15", "memory" + ) +} diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx8.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx8.c index 2a04011f37..8f0981aadb 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx8.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,20 +40,20 @@ /* rrr: - -------- ------ -------- - -------- ------ -------- - -------- += ------ ... -------- - -------- ------ -------- - -------- ------ : - -------- ------ : + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : rcr: - -------- | | | | -------- - -------- | | | | -------- - -------- += | | | | ... -------- - -------- | | | | -------- - -------- | | | | : - -------- | | | | : + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : Assumptions: - B is row-stored; @@ -69,12 +69,12 @@ cost of the in-register transpose). crr: - | | | | | | | | ------ -------- - | | | | | | | | ------ -------- - | | | | | | | | += ------ ... -------- - | | | | | | | | ------ -------- - | | | | | | | | ------ : - | | | | | | | | ------ : + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : */ // Prototype reference microkernels. @@ -178,7 +178,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8 // Advance C and A pointers by the mrs and nrs we just // used, and decrement m_left. cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } + } } // Advance C and B pointers by the mrs and nrs we just used, and @@ -208,9 +208,9 @@ void bli_dgemmsup_rv_haswell_asm_6x8 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -225,7 +225,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -275,25 +275,25 @@ void bli_dgemmsup_rv_haswell_asm_6x8 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 prefetch(0, mem(rdx, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; @@ -304,14 +304,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -320,7 +320,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8 vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + // ---------------------------------- iteration 1 #if 0 @@ -337,14 +337,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -352,14 +352,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - + + // ---------------------------------- iteration 2 #if 1 prefetch(0, mem(rdx, r9, 2, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; @@ -370,14 +370,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -385,7 +385,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + // ---------------------------------- iteration 3 @@ -403,14 +403,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -418,50 +418,50 @@ void bli_dgemmsup_rv_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
- - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -469,22 +469,22 @@ void bli_dgemmsup_rv_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) @@ -497,24 +497,24 @@ void bli_dgemmsup_rv_haswell_asm_6x8 vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm0, ymm14, ymm14) vmulpd(ymm0, ymm15, ymm15) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -523,60 +523,60 @@ void bli_dgemmsup_rv_haswell_asm_6x8 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vmovupd(ymm4, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) vmovupd(ymm6, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) vmovupd(ymm8, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) vmovupd(ymm9, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) vmovupd(ymm10, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) vmovupd(ymm11, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) vmovupd(ymm12, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) vmovupd(ymm13, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) vmovupd(ymm14, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) vmovupd(ymm15, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -663,51 +663,51 @@ void bli_dgemmsup_rv_haswell_asm_6x8 jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm8, mem(rcx, 0*32)) vmovupd(ymm9, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm10, mem(rcx, 0*32)) vmovupd(ymm11, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm12, mem(rcx, 0*32)) vmovupd(ymm13, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm14, mem(rcx, 0*32)) vmovupd(ymm15, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -772,12 +772,12 @@ void bli_dgemmsup_rv_haswell_asm_6x8 //lea(mem(rdx, rsi, 4), rdx) - - - + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -845,15 +845,15 @@ void bli_dgemmsup_rv_haswell_asm_5x8 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) - + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a //lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a @@ -912,18 +912,18 @@ void bli_dgemmsup_rv_haswell_asm_5x8 lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 @@ -940,20 +940,20 @@ void bli_dgemmsup_rv_haswell_asm_5x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) - + // ---------------------------------- iteration 1 #if 0 @@ -970,26 +970,26 @@ void bli_dgemmsup_rv_haswell_asm_5x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) - - + + // ---------------------------------- iteration 2 #if 1 prefetch(0, mem(rdx, r9, 2, 4*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; @@ -1000,19 +1000,19 @@ void bli_dgemmsup_rv_haswell_asm_5x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) - + // ---------------------------------- iteration 3 @@ -1030,37 +1030,37 @@ void bli_dgemmsup_rv_haswell_asm_5x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + 
vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 @@ -1071,41 +1071,41 @@ void bli_dgemmsup_rv_haswell_asm_5x8 vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) @@ -1116,24 +1116,24 @@ void bli_dgemmsup_rv_haswell_asm_5x8 vmulpd(ymm0, ymm11, ymm11) vmulpd(ymm0, ymm12, ymm12) vmulpd(ymm0, ymm13, ymm13) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -1142,52 +1142,37 @@ void bli_dgemmsup_rv_haswell_asm_5x8 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORED) // jump to column storage case - - - label(.DROWSTORED) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) - vmovupd(ymm4, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) - vmovupd(ymm5, mem(rcx, 1*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) - vmovupd(ymm6, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) - vmovupd(ymm7, mem(rcx, 1*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) - vmovupd(ymm8, mem(rcx, 0*32)) + label(.DROWSTORED) - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) - vmovupd(ymm9, mem(rcx, 1*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) - vmovupd(ymm10, mem(rcx, 0*32)) + lea(mem(rcx, rdi, 1), rax) // load address of c + 1*rs_c; + lea(mem(rcx, rdi, 2), rbx) // load address of c + 2*rs_c; + lea(mem(rbx, rdi, 1), r8) // load address of c + 3*rs_c; - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) - vmovupd(ymm11, mem(rcx, 1*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) - vmovupd(ymm12, mem(rcx, 0*32)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vfmadd231pd(mem(rax, 0*32), ymm3, ymm6) + vfmadd231pd(mem(rax, 1*32), ymm3, ymm7) + vfmadd231pd(mem(rbx, 0*32), ymm3, ymm8) + vfmadd231pd(mem(rbx, 1*32), ymm3, ymm9) + vfmadd231pd(mem(r8, 0*32), ymm3, ymm10) + vfmadd231pd(mem(r8, 1*32), ymm3, ymm11) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm12) + vfmadd231pd(mem(rdx, 1*32), ymm3, ymm13) + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + vmovupd(ymm6, mem(rax, 0*32)) + vmovupd(ymm7, mem(rax, 1*32)) + vmovupd(ymm8, mem(rbx, 0*32)) + vmovupd(ymm9, mem(rbx, 1*32)) + vmovupd(ymm10, mem(r8, 0*32)) + vmovupd(ymm11, mem(r8, 1*32)) + vmovupd(ymm12, mem(rdx, 0*32)) + vmovupd(ymm13, mem(rdx, 1*32)) - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) - vmovupd(ymm13, mem(rcx, 1*32)) - //add(rdi, rcx) - - jmp(.DDONE) // jump to end. @@ -1272,46 +1257,46 @@ void bli_dgemmsup_rv_haswell_asm_5x8 jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm8, mem(rcx, 0*32)) vmovupd(ymm9, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm10, mem(rcx, 0*32)) vmovupd(ymm11, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm12, mem(rcx, 0*32)) vmovupd(ymm13, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -1370,9 +1355,9 @@ void bli_dgemmsup_rv_haswell_asm_5x8 //lea(mem(rdx, rsi, 4), rdx) - - - + + + label(.DDONE) @@ -1442,9 +1427,9 @@ void bli_dgemmsup_rv_haswell_asm_4x8 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. 
mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -1459,7 +1444,7 @@ void bli_dgemmsup_rv_haswell_asm_4x8 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -1501,31 +1486,31 @@ void bli_dgemmsup_rv_haswell_asm_4x8 prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 7*cs_c label(.DPOSTPFETCH) // done prefetching c - + #if 1 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 prefetch(0, mem(rdx, 4*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; @@ -1536,7 +1521,7 @@ void bli_dgemmsup_rv_haswell_asm_4x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; @@ -1544,8 +1529,8 @@ void bli_dgemmsup_rv_haswell_asm_4x8 vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - - + + // ---------------------------------- iteration 1 #if 0 @@ -1562,7 +1547,7 @@ void bli_dgemmsup_rv_haswell_asm_4x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; @@ -1570,10 +1555,10 @@ void bli_dgemmsup_rv_haswell_asm_4x8 vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + // ---------------------------------- iteration 2 - + #if 1 prefetch(0, mem(rdx, r9, 2, 4*8)) #endif @@ -1588,7 +1573,7 @@ void bli_dgemmsup_rv_haswell_asm_4x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; @@ -1596,7 +1581,7 @@ void bli_dgemmsup_rv_haswell_asm_4x8 vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + // ---------------------------------- iteration 3 @@ -1614,7 +1599,7 @@ void bli_dgemmsup_rv_haswell_asm_4x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; @@ -1622,27 +1607,27 @@ void bli_dgemmsup_rv_haswell_asm_4x8 vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
- - + + label(.DLOOPKLEFT) // EDGE LOOP - + #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) @@ -1658,7 +1643,7 @@ void bli_dgemmsup_rv_haswell_asm_4x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; @@ -1666,22 +1651,22 @@ void bli_dgemmsup_rv_haswell_asm_4x8 vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) @@ -1690,72 +1675,62 @@ void bli_dgemmsup_rv_haswell_asm_4x8 vmulpd(ymm0, ymm9, ymm9) vmulpd(ymm0, ymm10, ymm10) vmulpd(ymm0, ymm11, ymm11) - - - - - - - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; - //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; - lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - label(.DROWSTORED) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) - vmovupd(ymm4, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) - vmovupd(ymm5, mem(rcx, 1*32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) - vmovupd(ymm6, mem(rcx, 0*32)) - - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) - vmovupd(ymm7, mem(rcx, 1*32)) - add(rdi, rcx) + label(.DROWSTORED) + lea(mem(rcx, rdi, 1), rax) // load address of c + 1*rs_c; + lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; + lea(mem(rdx, rdi, 1), rbx) // load address of c + 3*rs_c; - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) - vmovupd(ymm8, mem(rcx, 0*32)) + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vfmadd231pd(mem(rax, 0*32), ymm3, ymm6) + vfmadd231pd(mem(rax, 1*32), ymm3, ymm7) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm8) + vfmadd231pd(mem(rdx, 1*32), ymm3, ymm9) + vfmadd231pd(mem(rbx, 0*32), ymm3, ymm10) + vfmadd231pd(mem(rbx, 1*32), ymm3, ymm11) - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) - vmovupd(ymm9, mem(rcx, 1*32)) - add(rdi, rcx) + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + vmovupd(ymm6, mem(rax, 0*32)) + vmovupd(ymm7, mem(rax, 1*32)) + vmovupd(ymm8, mem(rdx, 0*32)) + vmovupd(ymm9, mem(rdx, 1*32)) + vmovupd(ymm10, mem(rbx, 0*32)) + vmovupd(ymm11, mem(rbx, 1*32)) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) - vmovupd(ymm10, mem(rcx, 0*32)) - - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) - vmovupd(ymm11, mem(rcx, 1*32)) - //add(rdi, rcx) - - jmp(.DDONE) // jump to end. - + label(.DCOLSTORED) @@ -1810,19 +1785,19 @@ void bli_dgemmsup_rv_haswell_asm_4x8 jmp(.DDONE) // jump to end. 
- - + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) @@ -1838,8 +1813,8 @@ void bli_dgemmsup_rv_haswell_asm_4x8 vmovupd(ymm10, mem(rcx, 0*32)) vmovupd(ymm11, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -1880,12 +1855,12 @@ void bli_dgemmsup_rv_haswell_asm_4x8 //lea(mem(rcx, rsi, 4), rcx) - - - + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -1952,9 +1927,9 @@ void bli_dgemmsup_rv_haswell_asm_3x8 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -1969,7 +1944,7 @@ void bli_dgemmsup_rv_haswell_asm_3x8 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -2010,27 +1985,27 @@ void bli_dgemmsup_rv_haswell_asm_3x8 prefetch(0, mem(rdx, rsi, 2, 2*8)) // prefetch c + 7*cs_c label(.DPOSTPFETCH) // done prefetching c - - + + #if 1 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 - + #if 1 prefetch(0, mem(rdx, 4*8)) #endif @@ -2045,13 +2020,13 @@ void bli_dgemmsup_rv_haswell_asm_3x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) - - + + // ---------------------------------- iteration 1 #if 0 @@ -2064,19 +2039,19 @@ void bli_dgemmsup_rv_haswell_asm_3x8 vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) - + vfmadd231pd(ymm0, ymm2, ymm10) + vfmadd231pd(ymm1, ymm2, ymm11) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + vbroadcastsd(mem(rax, r8, 2), ymm2) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) - + vfmadd231pd(ymm0, ymm2, ymm14) + vfmadd231pd(ymm1, ymm2, ymm15) + // ---------------------------------- iteration 2 - + #if 1 prefetch(0, mem(rdx, r9, 2, 4*8)) #endif @@ -2091,12 +2066,12 @@ void bli_dgemmsup_rv_haswell_asm_3x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) - + // ---------------------------------- iteration 3 @@ -2110,36 +2085,40 @@ void bli_dgemmsup_rv_haswell_asm_3x8 vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) - + vfmadd231pd(ymm0, ymm2, ymm10) + vfmadd231pd(ymm1, ymm2, ymm11) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm1, ymm3, ymm13) + vbroadcastsd(mem(rax, r8, 2), 
ymm2) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) - - - + vfmadd231pd(ymm0, ymm2, ymm14) + vfmadd231pd(ymm1, ymm2, ymm15) + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + vaddpd(ymm10, ymm4, ymm4) + vaddpd(ymm11, ymm5, ymm5) + vaddpd(ymm12, ymm6, ymm6) + vaddpd(ymm13, ymm7, ymm7) + vaddpd(ymm14, ymm8, ymm8) + vaddpd(ymm15, ymm9, ymm9) + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) @@ -2155,91 +2134,82 @@ void bli_dgemmsup_rv_haswell_asm_3x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) vmulpd(ymm0, ymm7, ymm7) vmulpd(ymm0, ymm8, ymm8) vmulpd(ymm0, ymm9, ymm9) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 2), rdx) // load address of c + 2*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - label(.DROWSTORED) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) - vmovupd(ymm4, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) - vmovupd(ymm5, mem(rcx, 1*32)) - add(rdi, rcx) + label(.DROWSTORED) + lea(mem(rcx, rdi, 1), rbx) // load address of c + 1*rs_c; - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) - vmovupd(ymm6, mem(rcx, 0*32)) - - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) - vmovupd(ymm7, mem(rcx, 1*32)) - add(rdi, rcx) - + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vfmadd231pd(mem(rbx, 0*32), ymm3, ymm6) + vfmadd231pd(mem(rbx, 1*32), ymm3, ymm7) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm8) + vfmadd231pd(mem(rdx, 1*32), ymm3, ymm9) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) - vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + vmovupd(ymm6, mem(rbx, 0*32)) + vmovupd(ymm7, mem(rbx, 1*32)) + vmovupd(ymm8, mem(rdx, 0*32)) + vmovupd(ymm9, mem(rdx, 1*32)) - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) - vmovupd(ymm9, mem(rcx, 1*32)) - //add(rdi, rcx) - - jmp(.DDONE) // jump to end. 
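The rewritten .DROWSTORED path above (shown here for the 3x8 kernel) forms each row address of C with lea up front and then issues all of the beta FMAs followed by all of the stores, instead of the old fma/store/add(rdi, rcx) sequence that serialized every row on a single pointer; the main loop likewise moves iterations 1 and 3 onto ymm10-ymm15 and folds them back with vaddpd, as in the Mx7 kernels above. A minimal intrinsics sketch of the batched store follows; the helper name store_3x8_rows is hypothetical, rs_c is in elements, and the accumulators are assumed already scaled by alpha.

#include <immintrin.h>

// Illustrative sketch (not kernel code): compute the row pointers first, then
// issue all beta FMAs, then all stores, mirroring the lea-based .DROWSTORED
// block above. Compile with -mavx2 -mfma.
static inline void store_3x8_rows( double* c, long rs_c, __m256d vbeta,
                                   __m256d r0lo, __m256d r0hi,
                                   __m256d r1lo, __m256d r1hi,
                                   __m256d r2lo, __m256d r2hi )
{
    double* c0 = c;              // counterpart of rcx
    double* c1 = c + 1 * rs_c;   // counterpart of rbx
    double* c2 = c + 2 * rs_c;   // counterpart of rdx

    r0lo = _mm256_fmadd_pd( _mm256_loadu_pd( c0 + 0 ), vbeta, r0lo );
    r0hi = _mm256_fmadd_pd( _mm256_loadu_pd( c0 + 4 ), vbeta, r0hi );
    r1lo = _mm256_fmadd_pd( _mm256_loadu_pd( c1 + 0 ), vbeta, r1lo );
    r1hi = _mm256_fmadd_pd( _mm256_loadu_pd( c1 + 4 ), vbeta, r1hi );
    r2lo = _mm256_fmadd_pd( _mm256_loadu_pd( c2 + 0 ), vbeta, r2lo );
    r2hi = _mm256_fmadd_pd( _mm256_loadu_pd( c2 + 4 ), vbeta, r2hi );

    _mm256_storeu_pd( c0 + 0, r0lo );
    _mm256_storeu_pd( c0 + 4, r0hi );
    _mm256_storeu_pd( c1 + 0, r1lo );
    _mm256_storeu_pd( c1 + 4, r1hi );
    _mm256_storeu_pd( c2 + 0, r2lo );
    _mm256_storeu_pd( c2 + 4, r2hi );
}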
- + label(.DCOLSTORED) @@ -2280,7 +2250,7 @@ void bli_dgemmsup_rv_haswell_asm_3x8 vmovsd(xmm13, mem(rdx, rsi, 1)) vmovsd(xmm14, mem(rdx, rsi, 2)) vmovsd(xmm15, mem(rdx, rax, 1)) - + lea(mem(rdx, rsi, 4), rdx) // begin I/O on columns 4-7 @@ -2319,26 +2289,26 @@ void bli_dgemmsup_rv_haswell_asm_3x8 vmovsd(xmm13, mem(rdx, rsi, 1)) vmovsd(xmm14, mem(rdx, rsi, 2)) vmovsd(xmm15, mem(rdx, rax, 1)) - + //lea(mem(rdx, rsi, 4), rdx) jmp(.DDONE) // jump to end. - - + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) @@ -2350,8 +2320,8 @@ void bli_dgemmsup_rv_haswell_asm_3x8 vmovupd(ymm8, mem(rcx, 0*32)) vmovupd(ymm9, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -2416,12 +2386,12 @@ void bli_dgemmsup_rv_haswell_asm_3x8 //lea(mem(rdx, rsi, 4), rdx) - - - + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -2450,7 +2420,7 @@ void bli_dgemmsup_rv_haswell_asm_3x8 "xmm12", "xmm13", "xmm14", "xmm15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", - "memory" + "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -2488,9 +2458,9 @@ void bli_dgemmsup_rv_haswell_asm_2x8 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -2505,7 +2475,7 @@ void bli_dgemmsup_rv_haswell_asm_2x8 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -2545,27 +2515,27 @@ void bli_dgemmsup_rv_haswell_asm_2x8 prefetch(0, mem(rdx, rsi, 2, 1*8)) // prefetch c + 7*cs_c label(.DPOSTPFETCH) // done prefetching c - - + + #if 1 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 - + #if 1 prefetch(0, mem(rdx, 4*8)) #endif @@ -2581,29 +2551,29 @@ void bli_dgemmsup_rv_haswell_asm_2x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - - + + // ---------------------------------- iteration 1 #if 0 prefetch(0, mem(rdx, r9, 1, 4*8)) #endif - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), ymm1) + vmovupd(mem(rbx, 0*32), ymm8) + vmovupd(mem(rbx, 1*32), ymm9) add(r10, rbx) // b += rs_b; vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) - + vfmadd231pd(ymm8, ymm2, ymm10) + vfmadd231pd(ymm9, ymm2, ymm11) + vfmadd231pd(ymm8, ymm3, ymm12) + vfmadd231pd(ymm9, ymm3, ymm13) + // ---------------------------------- iteration 2 - + #if 1 prefetch(0, mem(rdx, r9, 2, 4*8)) #endif @@ -2619,7 +2589,7 @@ void bli_dgemmsup_rv_haswell_asm_2x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + // ---------------------------------- iteration 3 @@ -2627,38 +2597,40 @@ void bli_dgemmsup_rv_haswell_asm_2x8 lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; #endif - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), ymm1) + vmovupd(mem(rbx, 0*32), ymm8) + vmovupd(mem(rbx, 1*32), ymm9) add(r10, rbx) // b += rs_b; vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) - - - + vfmadd231pd(ymm8, ymm2, ymm10) + vfmadd231pd(ymm9, ymm2, ymm11) + vfmadd231pd(ymm8, ymm3, ymm12) + vfmadd231pd(ymm9, ymm3, ymm13) + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + vaddpd(ymm10, ymm4, ymm4) + vaddpd(ymm11, ymm5, ymm5) + vaddpd(ymm12, ymm6, ymm6) + vaddpd(ymm13, ymm7, ymm7) + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) @@ -2675,44 +2647,44 @@ void bli_dgemmsup_rv_haswell_asm_2x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) vmulpd(ymm0, ymm7, ymm7) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -2721,28 +2693,24 @@ void bli_dgemmsup_rv_haswell_asm_2x8 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) - vmovupd(ymm4, mem(rcx, 0*32)) + lea(mem(rcx, rdi, 1), rdx) // load address of c + 1*rs_c; + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) - vmovupd(ymm5, mem(rcx, 1*32)) - add(rdi, rcx) + vfmadd231pd(mem(rdx, 0*32), ymm3, ymm6) + vfmadd231pd(mem(rdx, 1*32), ymm3, ymm7) + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + vmovupd(ymm6, mem(rdx, 0*32)) + vmovupd(ymm7, mem(rdx, 1*32)) - vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) - vmovupd(ymm6, mem(rcx, 0*32)) - vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) - vmovupd(ymm7, mem(rcx, 1*32)) - //add(rdi, rcx) - - jmp(.DDONE) // jump to end. @@ -2787,19 +2755,19 @@ void bli_dgemmsup_rv_haswell_asm_2x8 jmp(.DDONE) // jump to end. - - + + label(.DBETAZERO) cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - - + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) @@ -2807,8 +2775,8 @@ void bli_dgemmsup_rv_haswell_asm_2x8 vmovupd(ymm6, mem(rcx, 0*32)) vmovupd(ymm7, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -2841,12 +2809,12 @@ void bli_dgemmsup_rv_haswell_asm_2x8 //lea(mem(rcx, rsi, 4), rcx) - - - + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -2874,7 +2842,7 @@ void bli_dgemmsup_rv_haswell_asm_2x8 "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", - "memory" + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "memory" ) } @@ -2912,9 +2880,9 @@ void bli_dgemmsup_rv_haswell_asm_1x8 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -2929,7 +2897,7 @@ void bli_dgemmsup_rv_haswell_asm_1x8 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -2968,27 +2936,27 @@ void bli_dgemmsup_rv_haswell_asm_1x8 prefetch(0, mem(rdx, rsi, 2, 0*8)) // prefetch c + 7*cs_c label(.DPOSTPFETCH) // done prefetching c - + #if 1 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 - + #if 1 prefetch(0, mem(rdx, 4*8)) #endif @@ -2999,28 +2967,29 @@ void bli_dgemmsup_rv_haswell_asm_1x8 vbroadcastsd(mem(rax ), ymm2) add(r9, rax) // a += cs_a; + vbroadcastsd(mem(rax ), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) - - + + // ---------------------------------- iteration 1 #if 0 prefetch(0, mem(rdx, r9, 1, 4*8)) #endif - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), ymm1) + vmovupd(mem(rbx, 0*32), ymm8) + vmovupd(mem(rbx, 1*32), ymm9) add(r10, rbx) // b += rs_b; - - vbroadcastsd(mem(rax ), ymm2) - add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - + + vfmadd231pd(ymm8, ymm3, ymm6) + vfmadd231pd(ymm9, ymm3, ymm7) + // ---------------------------------- iteration 2 - + #if 1 prefetch(0, mem(rdx, r9, 2, 4*8)) #endif @@ -3031,9 +3000,11 @@ void bli_dgemmsup_rv_haswell_asm_1x8 vbroadcastsd(mem(rax ), ymm2) add(r9, rax) // a += cs_a; + vbroadcastsd(mem(rax ), ymm3) + add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) - + // ---------------------------------- iteration 3 @@ -3041,35 +3012,34 @@ void bli_dgemmsup_rv_haswell_asm_1x8 lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; #endif - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), ymm1) + vmovupd(mem(rbx, 0*32), ymm8) + vmovupd(mem(rbx, 1*32), ymm9) add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) - add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - - - + + vfmadd231pd(ymm8, ymm3, ymm6) + vfmadd231pd(ymm9, ymm3, ymm7) + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + vaddpd(ymm6, ymm4, ymm4) + vaddpd(ymm7, ymm5, ymm5) + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) @@ -3078,18 +3048,18 @@ void bli_dgemmsup_rv_haswell_asm_1x8 vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) @@ -3098,27 +3068,27 @@ void bli_dgemmsup_rv_haswell_asm_1x8 mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; //lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -3126,20 +3096,20 @@ void bli_dgemmsup_rv_haswell_asm_1x8 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vmovupd(ymm4, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) vmovupd(ymm5, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -3160,7 +3130,7 @@ void bli_dgemmsup_rv_haswell_asm_1x8 vmovhpd(xmm0, mem(rcx, rsi, 1)) vmovlpd(xmm1, mem(rcx, rsi, 2)) vmovhpd(xmm1, mem(rcx, rax, 1)) - + lea(mem(rcx, rsi, 4), rcx) // begin I/O on columns 4-7 @@ -3183,26 +3153,26 @@ void bli_dgemmsup_rv_haswell_asm_1x8 jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - - + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(ymm5, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -3230,14 +3200,14 @@ void bli_dgemmsup_rv_haswell_asm_1x8 vmovhpd(xmm1, mem(rcx, rax, 1)) //lea(mem(rcx, rsi, 4), rcx) - - - - + + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -3265,7 +3235,7 @@ void bli_dgemmsup_rv_haswell_asm_1x8 "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", - "memory" + "ymm6", "ymm7", "ymm8", "ymm9", "memory" ) } diff --git a/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rd_haswell_asm_d6x8.c b/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rd_haswell_asm_d6x8.c index 8aa5f94f76..2a518c794a 100644 --- a/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rd_haswell_asm_d6x8.c +++ b/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rd_haswell_asm_d6x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -695,7 +695,9 @@ void bli_dgemmsup_rd_haswell_asm_6x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -1188,7 +1190,8 @@ void bli_dgemmsup_rd_haswell_asm_2x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm7", "ymm8", + "ymm10", "ymm11", "ymm13", "ymm14", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); @@ -1586,7 +1589,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm7", "ymm10", "ymm13", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -2117,7 +2120,9 @@ void bli_dgemmsup_rd_haswell_asm_6x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -2564,7 +2569,8 @@ void bli_dgemmsup_rd_haswell_asm_2x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm7", "ymm8", + "ymm10", "ymm11", "ymm13", "ymm14", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); @@ -2927,7 +2933,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm7", "ymm10", "ymm13", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); @@ -3480,7 +3486,8 @@ void bli_dgemmsup_rd_haswell_asm_6x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", + "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) consider_edge_cases: @@ -3914,7 +3921,8 @@ void bli_dgemmsup_rd_haswell_asm_3x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", + "ymm9", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -4270,7 +4278,7 @@ void bli_dgemmsup_rd_haswell_asm_2x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -4585,7 +4593,7 @@ void bli_dgemmsup_rd_haswell_asm_1x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); diff --git a/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rv_haswell_asm_d6x8.c b/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rv_haswell_asm_d6x8.c index 4e37f6d1b6..d8e8fb148a 100644 --- a/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rv_haswell_asm_d6x8.c +++ 
b/kernels/haswell/3/sup/d6x8/old/bli_gemmsup_rv_haswell_asm_d6x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -809,7 +809,9 @@ void bli_dgemmsup_rv_haswell_asm_6x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -1437,7 +1439,8 @@ void bli_dgemmsup_rv_haswell_asm_5x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -1927,7 +1930,8 @@ void bli_dgemmsup_rv_haswell_asm_4x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -2444,7 +2448,8 @@ void bli_dgemmsup_rv_haswell_asm_3x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); @@ -2848,7 +2853,7 @@ void bli_dgemmsup_rv_haswell_asm_2x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -3216,7 +3221,7 @@ void bli_dgemmsup_rv_haswell_asm_1x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -3823,7 +3828,9 @@ void bli_dgemmsup_rv_haswell_asm_6x6 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -4449,7 +4456,8 @@ void bli_dgemmsup_rv_haswell_asm_5x6 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); @@ -4933,7 +4941,8 @@ void bli_dgemmsup_rv_haswell_asm_4x6 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -5447,7 +5456,8 @@ void bli_dgemmsup_rv_haswell_asm_3x6 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", 
"xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -5863,7 +5873,8 @@ void bli_dgemmsup_rv_haswell_asm_2x6 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -6233,7 +6244,7 @@ void bli_dgemmsup_rv_haswell_asm_1x6 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -6710,7 +6721,8 @@ void bli_dgemmsup_rv_haswell_asm_6x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "ymm8", "ymm10", + "ymm12", "ymm14", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -7205,7 +7217,8 @@ void bli_dgemmsup_rv_haswell_asm_5x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "ymm8", "ymm10", + "ymm12", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -7606,7 +7619,7 @@ void bli_dgemmsup_rv_haswell_asm_4x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "ymm8", "ymm10", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -8014,7 +8027,7 @@ void bli_dgemmsup_rv_haswell_asm_3x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "ymm8", "ymm10", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -8357,7 +8370,7 @@ void bli_dgemmsup_rv_haswell_asm_2x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -8677,7 +8690,7 @@ void bli_dgemmsup_rv_haswell_asm_1x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -9132,7 +9145,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -9582,7 +9595,7 @@ void bli_dgemmsup_rv_haswell_asm_5x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -9971,7 +9984,7 @@ void bli_dgemmsup_rv_haswell_asm_4x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -10377,7 +10390,7 @@ void bli_dgemmsup_rv_haswell_asm_3x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", 
"xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "ymm8", "ymm10", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -10707,7 +10720,7 @@ void bli_dgemmsup_rv_haswell_asm_2x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -11014,7 +11027,7 @@ void bli_dgemmsup_rv_haswell_asm_1x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "memory" ) AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } diff --git a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8.c b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8.c index c5addd9cf2..b48bf3cab6 100644 --- a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8.c +++ b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -646,7 +646,9 @@ void bli_dgemmsup_rd_haswell_asm_6x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -1130,7 +1132,9 @@ void bli_dgemmsup_rd_haswell_asm_3x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -1571,7 +1575,8 @@ void bli_dgemmsup_rd_haswell_asm_2x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm7", "ymm8", + "ymm10", "ymm11", "ymm13", "ymm14", "memory" ) } @@ -1960,7 +1965,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm7", "ymm10", "ymm13", "memory" ) } @@ -2454,7 +2459,9 @@ void bli_dgemmsup_rd_haswell_asm_6x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -2910,7 +2917,9 @@ void bli_dgemmsup_rd_haswell_asm_3x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -3314,7 +3323,8 @@ void bli_dgemmsup_rd_haswell_asm_2x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm7", "ymm8", + "ymm10", "ymm11", "ymm13", "ymm14", "memory" ) } @@ -3675,7 +3685,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4 "xmm4", "xmm5", "xmm6", 
"xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm7", "ymm10", "ymm13", "memory" ) } @@ -4184,7 +4194,8 @@ void bli_dgemmsup_rd_haswell_asm_6x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", + "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -4576,7 +4587,8 @@ void bli_dgemmsup_rd_haswell_asm_3x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", + "ymm9", "memory" ) } @@ -4929,7 +4941,7 @@ void bli_dgemmsup_rd_haswell_asm_2x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "memory" ) } @@ -5243,7 +5255,7 @@ void bli_dgemmsup_rd_haswell_asm_1x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "memory" ) } diff --git a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c index 55ae6d0f91..def75c5e47 100644 --- a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -695,7 +695,9 @@ void bli_dgemmsup_rd_haswell_asm_6x8m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -1242,7 +1244,9 @@ void bli_dgemmsup_rd_haswell_asm_3x8m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -1695,7 +1699,8 @@ void bli_dgemmsup_rd_haswell_asm_2x8m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm7", "ymm8", + "ymm10", "ymm11", "ymm13", "ymm14", "memory" ) } @@ -2090,7 +2095,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm7", "ymm10", "ymm13", "memory" ) } @@ -2620,7 +2625,9 @@ void bli_dgemmsup_rd_haswell_asm_6x4m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -3120,7 +3127,9 @@ void bli_dgemmsup_rd_haswell_asm_3x4m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", 
"xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -3527,7 +3536,8 @@ void bli_dgemmsup_rd_haswell_asm_2x4m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm7", "ymm8", + "ymm10", "ymm11", "ymm13", "ymm14", "memory" ) } @@ -3887,7 +3897,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm7", "ymm10", "ymm13", "memory" ) } @@ -4437,7 +4447,8 @@ void bli_dgemmsup_rd_haswell_asm_6x2m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", + "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) consider_edge_cases: @@ -4870,7 +4881,8 @@ void bli_dgemmsup_rd_haswell_asm_3x2m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", + "ymm9", "memory" ) } @@ -5224,7 +5236,7 @@ void bli_dgemmsup_rd_haswell_asm_2x2m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "memory" ) } @@ -5537,7 +5549,7 @@ void bli_dgemmsup_rd_haswell_asm_1x2m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "memory" ) } diff --git a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.newji b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.newji index c1cb372142..df9e27d5af 100644 --- a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.newji +++ b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.newji @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.worksij b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.worksij index fd1c2ae657..f6b11e3971 100644 --- a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.worksij +++ b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8m.c.worksij @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8n.c b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8n.c index a23764f8d4..d738d46dfb 100644 --- a/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8n.c +++ b/kernels/haswell/3/sup/old/bli_gemmsup_rd_haswell_asm_d6x8n.c @@ -5,7 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -713,7 +713,9 @@ void bli_dgemmsup_rd_haswell_asm_6x8n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -1312,7 +1314,9 @@ void bli_dgemmsup_rd_haswell_asm_3x8n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -1857,7 +1861,8 @@ void bli_dgemmsup_rd_haswell_asm_2x8n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm10", "ymm11", "ymm13", "ymm14", "memory" ) consider_edge_cases: @@ -2347,7 +2352,7 @@ void bli_dgemmsup_rd_haswell_asm_1x8n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm7", "ymm10", "ymm13", "memory" ) consider_edge_cases: @@ -2934,7 +2939,9 @@ void bli_dgemmsup_rd_haswell_asm_6x4n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -3444,7 +3451,9 @@ void bli_dgemmsup_rd_haswell_asm_3x4n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -3860,7 +3869,8 @@ void bli_dgemmsup_rd_haswell_asm_2x4n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm7", "ymm8", + "ymm10", "ymm11", "ymm13", "ymm14", "memory" ) } @@ -4229,7 +4239,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm7", "ymm10", "ymm13", "memory" ) } @@ -4751,7 +4761,8 @@ void bli_dgemmsup_rd_haswell_asm_6x2n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", + "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -5145,7 +5156,8 @@ void bli_dgemmsup_rd_haswell_asm_3x2n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", + "ymm9", "memory" ) } @@ -5508,7 +5520,7 @@ void bli_dgemmsup_rd_haswell_asm_2x2n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + 
"ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "memory" ) } @@ -5830,7 +5842,7 @@ void bli_dgemmsup_rd_haswell_asm_1x2n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "memory" ) } diff --git a/kernels/haswell/3/sup/s6x16/CMakeLists.txt b/kernels/haswell/3/sup/s6x16/CMakeLists.txt deleted file mode 100644 index 0be5cd76e8..0000000000 --- a/kernels/haswell/3/sup/s6x16/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -##Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved.## - -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_r_haswell_ref_sMx1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_sMx1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_sMx12.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_sMx16.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_sMx2.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_sMx4.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_haswell_asm_sMx8.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_sMx12.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_sMx16.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_sMx2.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_sMx4.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_sMx6.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_haswell_asm_sMx8.c - ) - - diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_r_haswell_ref_sMx1.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_r_haswell_ref_sMx1.c index dad5458b9a..00d6581db8 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_r_haswell_ref_sMx1.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_r_haswell_ref_sMx1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx1.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx1.c index fe6d124d32..99cb45be52 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx1.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx12.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx12.c index b7b0b46a1b..7313e0d7e3 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx12.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx12.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx16.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx16.c index 9819671c7d..fe33964b4f 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx16.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx16.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx2.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx2.c index 190eb9d1d7..80c13b65dd 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx2.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx4.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx4.c index d167bc08fb..cfb9ac180b 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx4.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx4.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx8.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx8.c index 498002da90..8c20271da8 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx8.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rd_haswell_asm_sMx8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx12.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx12.c index dd2c392e9c..8d3539667c 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx12.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx12.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx16.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx16.c index f6443e8b50..6b9ff603ca 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx16.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx16.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c index 3d90e6e4f3..9ce3c08ca1 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx4.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx4.c index 512fd60525..201731759d 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx4.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx4.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx6.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx6.c index 1d80111ea8..5c56c9ef59 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx6.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx6.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx8.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx8.c index 43210cdc5a..f0af711645 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx8.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/CMakeLists.txt b/kernels/haswell/CMakeLists.txt deleted file mode 100644 index 2a161a1685..0000000000 --- a/kernels/haswell/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -##Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved.## - -add_subdirectory(3) -add_subdirectory(1m) - diff --git a/kernels/haswell/bli_kernels_haswell.h b/kernels/haswell/bli_kernels_haswell.h index d841d715f3..a7900af519 100644 --- a/kernels/haswell/bli_kernels_haswell.h +++ b/kernels/haswell/bli_kernels_haswell.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -219,6 +219,12 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x8 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x8 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x8 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x7 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x7 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x7 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x7 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x7 ) + GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x6 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x6 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x6 ) @@ -226,6 +232,12 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x6 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x6 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x6 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x5 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x5 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x5 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x5 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x5 ) + GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x4 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x4 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x4 ) @@ -233,6 +245,12 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x4 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x4 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x4 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x3 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x3 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x3 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x3 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x3 ) + GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x2 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x2 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x2 ) @@ -240,6 +258,11 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x2 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x2 ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x2 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_5x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_4x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_3x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_2x1 ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_1x1 ) // gemmsup_rv (mkernel in m dim) 
GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m ) diff --git a/kernels/knl/1m/bli_dpackm_knl_asm_24x8.c b/kernels/knl/1m/bli_dpackm_knl_asm_24x8.c index 91fe1989f0..cd4c3aef61 100644 --- a/kernels/knl/1m/bli_dpackm_knl_asm_24x8.c +++ b/kernels/knl/1m/bli_dpackm_knl_asm_24x8.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -304,7 +305,8 @@ void bli_dpackm_knl_asm_8xk "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "rax", "rbx", "rcx", "rdx", "rdi", "rsi", - "r8", "r9", "r10", "r11", "r12", "r13", "r14", "memory" + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "k0", "k1", + "ymm0", "ymm3", "memory" ) } @@ -608,7 +610,8 @@ void bli_dpackm_knl_asm_24xk "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "rax", "rbx", "rcx", "rdi", "rsi", - "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "memory" + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "k0", "k1", "k2", "k3", "ymm0", "ymm1", "ymm2", "ymm3", "memory" ) } diff --git a/kernels/knl/1m/bli_spackm_knl_asm_24x16.c b/kernels/knl/1m/bli_spackm_knl_asm_24x16.c index 8c4bdfe6be..571e166cd4 100644 --- a/kernels/knl/1m/bli_spackm_knl_asm_24x16.c +++ b/kernels/knl/1m/bli_spackm_knl_asm_24x16.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -322,7 +323,11 @@ void bli_spackm_knl_asm_16xk "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "rax", "rbx", "rcx", "rdx", "rdi", "rsi", - "r8", "r9", "r10", "r11", "r12", "r13", "r14", "memory" + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "k0", "k1", + "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", + "xmm12", "xmm13", "xmm15", "ymm0", "ymm1", "ymm2", "ymm3", + "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm13", "ymm15", "memory" ) } @@ -625,7 +630,11 @@ void bli_spackm_knl_asm_24xk "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "rax", "rbx", "rcx", "rdx", "rdi", "rsi", - "r8", "r9", "r10", "r11", "r12", "r13", "r14", "memory" + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "k0", "k1", + "k2", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", + "xmm7", "xmm12", "xmm13", "xmm15", "ymm0", "ymm1", "ymm2", + "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", + "ymm11", "ymm12", "ymm13", "ymm15", "memory" ) } diff --git a/kernels/knl/3/bli_dgemm_knl_asm_24x8.c b/kernels/knl/3/bli_dgemm_knl_asm_24x8.c index b794e7c059..82e5a25435 100644 --- a/kernels/knl/3/bli_dgemm_knl_asm_24x8.c +++ b/kernels/knl/3/bli_dgemm_knl_asm_24x8.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -698,7 +699,8 @@ void bli_dgemm_knl_asm_24x8 "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", - "zmm30", "zmm31", "memory" + "zmm30", "zmm31", "k0", "k1", "k2", "xmm1", "ymm2", "ymm3", + "ymm5", "memory" ) #ifdef LOOPMON diff --git a/kernels/knl/3/bli_sgemm_knl_asm_24x16.c b/kernels/knl/3/bli_sgemm_knl_asm_24x16.c index 6d485b5308..b1ed2abf74 100644 --- a/kernels/knl/3/bli_sgemm_knl_asm_24x16.c +++ b/kernels/knl/3/bli_sgemm_knl_asm_24x16.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -695,7 +696,7 @@ void bli_sgemm_knl_asm_24x16 "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", - "zmm30", "zmm31", "memory" + "zmm30", "zmm31", "k0", "k1", "k2", "xmm1", "ymm3", "ymm5", "memory" ) #ifdef LOOPMON diff --git a/kernels/power9/3/bli_gemm_power9_asm_d12x6.c b/kernels/power9/3/bli_gemm_power9_asm_d12x6.c index 187182a095..ec09f8e380 100644 --- a/kernels/power9/3/bli_gemm_power9_asm_d12x6.c +++ b/kernels/power9/3/bli_gemm_power9_asm_d12x6.c @@ -95,8 +95,8 @@ void bli_dgemm_power9_asm_12x6 " \n\t" DPREFETCH " \n\t" - "cmpwi %%r0, %%r11, 0 \n\t" // if k_iter == 0, - "beq %%r0, DCONSIDERKLEFT \n\t" // then jmp to k_left + "cmpwi %%r11, 0 \n\t" // if k_iter == 0, + "beq DCONSIDERKLEFT \n\t" // then jmp to k_left "mtctr %%r11 \n\t" // else, do k_iter loop " \n\t" "DLOOPKITER: \n\t" // k_iter loop @@ -107,8 +107,8 @@ void bli_dgemm_power9_asm_12x6 " \n\t" "DCONSIDERKLEFT: \n\t" " \n\t" - "cmpwi %%r0, %%r12, 0 \n\t" // if k_left == 0, - "beq %%r0, DPOSTACCUM \n\t" // then jmp to post accum + "cmpwi %%r12, 0 \n\t" // if k_left == 0, + "beq DPOSTACCUM \n\t" // then jmp to post accum "mtctr %%r12 \n\t" // else, do k_left loop " \n\t" "DLOOPKLEFT: \n\t" // k_left loop @@ -121,10 +121,10 @@ void bli_dgemm_power9_asm_12x6 " \n\t" DSCALE_ALPHA " \n\t" - "cmpdi %%r0, %%r26, 0 \n\t" // if beta == 0, - "beq %%r0, DBETAZERO \n\t" // then jmp to BZ + "cmpdi %%r26, 0 \n\t" // if beta == 0, + "beq DBETAZERO \n\t" // then jmp to BZ " \n\t" - "cmpwi %%r0, %%r9, 8 \n\t" // if rs_c == 8 + "cmpwi %%r9, 8 \n\t" // if rs_c == 8 "beq DCOLSTOREDBNZ \n\t" // then jmp to col store " \n\t" "DGENSTOREDBNZ: \n\t" // BNZ gen stored case @@ -143,7 +143,7 @@ void bli_dgemm_power9_asm_12x6 " \n\t" "DBETAZERO: \n\t" // BZ case " \n\t" - "cmpwi %%r0, %%r9, 8 \n\t" // if rs_c == 8, + "cmpwi %%r9, 8 \n\t" // if rs_c == 8, "beq DCOLSTORED \n\t" // C is col stored " \n\t" "DGENSTORED: \n\t" // BZ gen stored case diff --git a/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c b/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c index a56ef16e5e..63ac331a60 100644 --- a/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c +++ b/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -1022,7 +1023,9 @@ void bli_sgemm_sandybridge_asm_8x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -1697,7 +1700,9 @@ void bli_dgemm_sandybridge_asm_8x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -2658,7 +2663,9 @@ void bli_cgemm_sandybridge_asm_8x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -3508,7 +3515,9 @@ void bli_zgemm_sandybridge_asm_4x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } diff --git a/kernels/skx/3/CMakeLists.txt b/kernels/skx/3/CMakeLists.txt deleted file mode 100644 index e4125f1b60..0000000000 --- a/kernels/skx/3/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -##Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved.## - -add_library(skx_3 - OBJECT - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_skx_asm_16x14.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_sgemm_skx_asm_32x12_l2.c - ) -target_compile_options(skx_3 PRIVATE /arch:AVX2 /arch:AVX512) -if(BUILD_SHARED_LIBS) - target_compile_definitions(skx_3 PUBLIC -DBLIS_IS_BUILDING_LIBRARY) -endif() diff --git a/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c b/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c index 3a20cd8618..5735a5911a 100644 --- a/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c +++ b/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -539,7 +540,11 @@ void bli_dgemm_skx_asm_16x12_l2( [offsetPtr] "m" (offsetPtr) : // register clobber list "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", - "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", + "r13", "r14", "r15", "k0", "k1", "k2", "xmm1", "xmm7", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "ymm16", "ymm17", "ymm18", "ymm19", "ymm20", "ymm21", + "ymm22", "ymm23", "ymm24", "ymm25", "ymm26", "ymm27", "ymm28", + "ymm29", "ymm30", "ymm31", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", diff --git a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c index c0ada1eb66..038920b834 100644 --- a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c +++ b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc.All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -454,7 +454,12 @@ void bli_dgemm_skx_asm_16x14( [offsetPtr] "m" (offsetPtr) : // register clobber list "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", - "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", + "r13", "r14", "r15", "k0", "k1", "k2", "k3", "k4", "xmm1", + "xmm2", "ymm2", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", + "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "ymm16", + "ymm17", "ymm18", "ymm19", "ymm20", "ymm21", "ymm22", "ymm23", + "ymm24", "ymm25", "ymm26", "ymm27", "ymm28", "ymm29", "ymm30", + "ymm31", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", diff --git a/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c b/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c index 40af496140..572045832d 100644 --- a/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c +++ b/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -564,10 +565,14 @@ void bli_sgemm_skx_asm_32x12_l2( [offsetPtr] "m" (offsetPtr) : // register clobber list "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", - "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", - "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", - "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", - "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", - "zmm30", "zmm31", "memory" + "r13", "r14", "r15", "k0", "k1", "k2", "k3", "k4", "xmm1", "xmm7", + "ymm1", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm13", "ymm14", "ymm15", "ymm16", "ymm17", "ymm18", + "ymm19", "ymm20", "ymm21", "ymm22", "ymm23", "ymm24", "ymm25", + "ymm26", "ymm27", "ymm28", "ymm29", "ymm30", "ymm31", "zmm0", "zmm1", + "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", + "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", "zmm16", + "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", + "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory" ) } diff --git a/kernels/skx/CMakeLists.txt b/kernels/skx/CMakeLists.txt deleted file mode 100644 index a9ba638da8..0000000000 --- a/kernels/skx/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -##Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved.## -remove_definitions(/arch:AVX2) - -add_subdirectory(3) \ No newline at end of file diff --git a/kernels/zen/1/CMakeLists.txt b/kernels/zen/1/CMakeLists.txt deleted file mode 100644 index 87db4ac1c7..0000000000 --- a/kernels/zen/1/CMakeLists.txt +++ /dev/null @@ -1,24 +0,0 @@ -##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## - -add_library(zen_1 - OBJECT - ${CMAKE_CURRENT_SOURCE_DIR}/bli_amaxv_zen_int.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpbyv_zen_int.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpbyv_zen_int10.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyv_zen_int.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyv_zen_int10.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_copyv_zen_int.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotv_zen_int.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotv_zen_int10.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotxv_zen_int.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_scalv_zen_int.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_scalv_zen_int10.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_setv_zen_int.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_swapv_zen_int8.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_norm2_zen_int.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_scal2v_zen_int.c - ) -target_compile_options(zen_1 PRIVATE /arch:AVX2) -if(BUILD_SHARED_LIBS) - target_compile_definitions(zen_1 PUBLIC -DBLIS_IS_BUILDING_LIBRARY) -endif() diff --git a/kernels/zen/1/bli_amaxv_zen_int.c b/kernels/zen/1/bli_amaxv_zen_int.c index 3adb524799..5c9e7af81b 100644 --- a/kernels/zen/1/bli_amaxv_zen_int.c +++ b/kernels/zen/1/bli_amaxv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2016 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without diff --git a/kernels/zen/1/bli_axpbyv_zen_int.c b/kernels/zen/1/bli_axpbyv_zen_int.c index c92d44ad3e..23748ab992 100644 --- a/kernels/zen/1/bli_axpbyv_zen_int.c +++ b/kernels/zen/1/bli_axpbyv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -734,416 +734,593 @@ void bli_zaxpbyv_zen_int ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) - const dim_t n_elem_per_reg = 4; // number of elements per register - dim_t i; // iterator + dim_t i = 0; // iterator + // Local pointers to x and y vectors double* restrict x0; double* restrict y0; + // Variables to store real and imaginary components of alpha and beta double alphaR, alphaI, betaR, betaI; - __m256d alphaRv; - __m256d alphaIv; - __m256d betaRv; - __m256d betaIv; - __m256d xv[4]; - __m256d yv[4]; - __m256d iv[4]; // intermediate registers - + // Local variable to store the conjugate type conj_t conjx_use = conjx; - - /* if the vector dimension is zero, or if alpha & beta are zero, - return early. */ - if ( bli_zero_dim1( n ) || - ( PASTEMAC( c, eq0 )( *alpha ) && PASTEMAC( c, eq0 )( *beta ) ) ) + + /* If the vector dimension is zero, return early. */ + if ( bli_zero_dim1( n ) ) { AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) return; } - // initialize local pointers - x0 = ( double* ) x; - y0 = ( double* ) y; + // Initializing the local pointers + x0 = ( double* ) x; + y0 = ( double* ) y; alphaR = alpha->real; alphaI = alpha->imag; betaR = beta->real; betaI = beta->imag; - if ( incx == 1 && incy == 1 ) - { - //---------- Scalar algorithm BLIS_NO_CONJUGATE ------------- - // y = beta*y + alpha*x - // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR + ixI ) - // y = bR.yR + ibR.yI + ibI.yR - ibIyI + aR.xR + iaR.xI + iaI.xR - aI.xI - // y = ( bR.yR - bI.yI + aR.xR - aI.xI ) + - // i ( bR.yI + bI.yR + aR.xI + aI.xR ) + // Vectors to store real and imaginary components of beta + __m256d betaRv, betaIv; - // SIMD Algorithm BLIS_NO_CONJUGATE - // yv = yR1 yI1 yR2 yI2 - // yv' = yI1 yR1 yI2 yR2 - // xv = xR1 xI1 xR2 xI2 - // xv' = xI1 xR1 xI2 xR2 - // arv = aR aR aR aR - // aiv = -aI aI -aI aI - // brv = bR bR bR bR - // biv = -bI bI -bI bI - // - // step 1: iv = brv * iv - // step 2: shuffle yv -> yv' - // step 3: FMA yv = biv * yv' + iv - // step 4: iv = arv * xv - // step 5: shuffle xv -> xv' - // step 6: FMA yv = aiv * xv' + iv + // Broadcasting real and imaginary components of beta onto the registers + betaRv = _mm256_broadcast_sd( &betaR ); + betaIv = _mm256_broadcast_sd( &betaI ); - //---------- Scalar algorithm BLIS_CONJUGATE ------------- - // y = beta*y + alpha*conj(x) - // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR - ixI ) - // y = bR.yR + ibR.yI + ibI.yR - bI.yI + aR.xR - iaR.xI + iaI.xR + aI.xI - // y = ( bR.yR - bI.yI + aR.xR + aI.xI ) + - // i ( bR.yI + bI.yR - aR.xI + aI.xR ) + // Initializing a variable to classify the type of the computation + bool is_alpha_zero = bli_zeq0( *alpha ); - // SIMD Algorithm BLIS_CONJUGATE - // yv = yR1 yI1 yR2 yI2 - // yv' = yI1 yR1 yI2 yR2 - // xv = xR1 xI1 xR2 xI2 - // xv' = xI1 xR1 xI2 xR2 - // arv = aR -aR aR -aR - // aiv 
= aI aI aI aI - // brv = bR bR bR bR - // biv = -bI bI -bI bI - // - // step 1: iv = brv * iv - // step 2: shuffle yv -> yv' - // step 3: FMA yv = biv * yv' + iv - // step 4: iv = arv * xv - // step 5: shuffle xv -> xv' - // step 6: FMA yv = aiv * xv' + iv + // In case of unit strides for x and y vectors + if ( incx == 1 && incy == 1 ) + { + // Number of double precision elements in a YMM register + const dim_t n_elem_per_reg = 4; - // broadcast alpha & beta to all elements of respective vector registers - if ( !bli_is_conj( conjx ) ) - { - // alphaRv = aR aR aR aR - // alphaIv = -aI aI -aI aI - // betaRv = bR bR bR bR - // betaIv = -bI bI -bI bI - alphaRv = _mm256_broadcast_sd( &alphaR ); - alphaIv = _mm256_set_pd( alphaI, -alphaI, alphaI, -alphaI ); - betaRv = _mm256_broadcast_sd( &betaR ); - betaIv = _mm256_set_pd( betaI, -betaI, betaI, -betaI ); - } - else - { - // alphaRv = aR -aR aR -aR - // alphaIv = aI aI aI aI - // betaRv = bR bR bR bR - // betaIv = -bI bI -bI bI - alphaRv = _mm256_set_pd( -alphaR, alphaR, -alphaR, alphaR ); - alphaIv = _mm256_broadcast_sd( &alphaI ); - betaRv = _mm256_broadcast_sd( &betaR ); - betaIv = _mm256_set_pd( betaI, -betaI, betaI, -betaI ); - } + // Scratch registers + __m256d xv[4]; + __m256d yv[4]; + __m256d iv[4]; - // Processing 8 elements per loop, 8 FMAs - for ( i = 0; ( i + 7 ) < n; i += 8 ) + // In case of alpha being 0, we just need to scale y by beta + if( is_alpha_zero ) { - // xv = xR1 xI1 xR2 xI2 - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - - // yv = yR1 yI1 yR2 yI2 - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); - - // iv = betaRv * yv - // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... - iv[0] = _mm256_mul_pd( betaRv, yv[0] ); - iv[1] = _mm256_mul_pd( betaRv, yv[1] ); - iv[2] = _mm256_mul_pd( betaRv, yv[2] ); - iv[3] = _mm256_mul_pd( betaRv, yv[3] ); - - // yv' = yI1 yR1 yI2 yR2 - yv[0] = _mm256_permute_pd( yv[0], 5); - yv[1] = _mm256_permute_pd( yv[1], 5); - yv[2] = _mm256_permute_pd( yv[2], 5); - yv[3] = _mm256_permute_pd( yv[3], 5); - - // yv = betaIv * yv' + iv - // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... - yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); - yv[1] = _mm256_fmadd_pd( betaIv, yv[1], iv[1] ); - yv[2] = _mm256_fmadd_pd( betaIv, yv[2], iv[2] ); - yv[3] = _mm256_fmadd_pd( betaIv, yv[3], iv[3] ); - - // iv = alphaRv * xv - // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... - iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); - iv[1] = _mm256_mul_pd( alphaRv, xv[1] ); - iv[2] = _mm256_mul_pd( alphaRv, xv[2] ); - iv[3] = _mm256_mul_pd( alphaRv, xv[3] ); - - // xv' = xI1 xR1 xI2 xR2 - xv[0] = _mm256_permute_pd( xv[0], 5); - xv[1] = _mm256_permute_pd( xv[1], 5); - xv[2] = _mm256_permute_pd( xv[2], 5); - xv[3] = _mm256_permute_pd( xv[3], 5); - - // yv = alphaIv * xv + yv - // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... 
- iv[0] = _mm256_fmadd_pd( alphaIv, xv[0], iv[0] ); - iv[1] = _mm256_fmadd_pd( alphaIv, xv[1], iv[1] ); - iv[2] = _mm256_fmadd_pd( alphaIv, xv[2], iv[2] ); - iv[3] = _mm256_fmadd_pd( alphaIv, xv[3], iv[3] ); + // Processing 8 elements per loop, 8 FMAs + for ( i = 0; ( i + 7 ) < n; i += 8 ) + { + // Load the y vector, 8 elements in total + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 ); + yv[1] = _mm256_loadu_pd( y0 + 1 * n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2 * n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3 * n_elem_per_reg ); + + // Permute the loaded vectors for the required compute + // xv = yI1 yR1 yI2 yR2 + xv[0] = _mm256_permute_pd( yv[0], 5 ); + xv[1] = _mm256_permute_pd( yv[1], 5 ); + xv[2] = _mm256_permute_pd( yv[2], 5 ); + xv[3] = _mm256_permute_pd( yv[3], 5 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = yI1 yR1 yI2 yR2 + iv[0] = _mm256_mul_pd( betaIv, xv[0] ); + iv[1] = _mm256_mul_pd( betaIv, xv[1] ); + iv[2] = _mm256_mul_pd( betaIv, xv[2] ); + iv[3] = _mm256_mul_pd( betaIv, xv[3] ); + + // Using fmaddsub to scale with real component of beta and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmaddsub_pd( betaRv, yv[0], iv[0] ); + yv[1] = _mm256_fmaddsub_pd( betaRv, yv[1], iv[1] ); + yv[2] = _mm256_fmaddsub_pd( betaRv, yv[2], iv[2] ); + yv[3] = _mm256_fmaddsub_pd( betaRv, yv[3], iv[3] ); + + // Storing the result to memory + _mm256_storeu_pd( ( y0 ), yv[0] ); + _mm256_storeu_pd( ( y0 + 1 * n_elem_per_reg ), yv[1] ); + _mm256_storeu_pd( ( y0 + 2 * n_elem_per_reg ), yv[2] ); + _mm256_storeu_pd( ( y0 + 3 * n_elem_per_reg ), yv[3] ); + + // Adjusting the pointers for the next iteration + y0 += 4 * n_elem_per_reg; + x0 += 4 * n_elem_per_reg; + } - yv[0] = _mm256_add_pd( yv[0], iv[0] ); - yv[1] = _mm256_add_pd( yv[1], iv[1] ); - yv[2] = _mm256_add_pd( yv[2], iv[2] ); - yv[3] = _mm256_add_pd( yv[3], iv[3] ); + // Processing 6 elements per loop, 6 FMAs + for ( ; ( i + 5 ) < n; i += 6 ) + { + // Load the y vector, 6 elements in total + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 ); + yv[1] = _mm256_loadu_pd( y0 + 1 * n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2 * n_elem_per_reg ); + + // Permute the loaded vectors for the required compute + // xv = yI1 yR1 yI2 yR2 + xv[0] = _mm256_permute_pd( yv[0], 5 ); + xv[1] = _mm256_permute_pd( yv[1], 5 ); + xv[2] = _mm256_permute_pd( yv[2], 5 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = yI1 yR1 yI2 yR2 + iv[0] = _mm256_mul_pd( betaIv, xv[0] ); + iv[1] = _mm256_mul_pd( betaIv, xv[1] ); + iv[2] = _mm256_mul_pd( betaIv, xv[2] ); + + // Using fmaddsub to scale with real component of beta + // and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... 
+ yv[0] = _mm256_fmaddsub_pd( betaRv, yv[0], iv[0] ); + yv[1] = _mm256_fmaddsub_pd( betaRv, yv[1], iv[1] ); + yv[2] = _mm256_fmaddsub_pd( betaRv, yv[2], iv[2] ); + + // Storing the result to memory + _mm256_storeu_pd( ( y0 ), yv[0] ); + _mm256_storeu_pd( ( y0 + 1 * n_elem_per_reg ), yv[1] ); + _mm256_storeu_pd( ( y0 + 2 * n_elem_per_reg ), yv[2] ); + + // Adjusting the pointers for the next iteration + y0 += 3 * n_elem_per_reg; + x0 += 3 * n_elem_per_reg; + } - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); - _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] ); - _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2] ); - _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), yv[3] ); + // Processing 4 elements per loop, 4 FMAs + for ( ; ( i + 3 ) < n; i += 4 ) + { + // Load the y vector, 4 elements in total + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 ); + yv[1] = _mm256_loadu_pd( y0 + 1 * n_elem_per_reg ); + + // Permute the loaded vectors for the required compute + // xv = yI1 yR1 yI2 yR2 + xv[0] = _mm256_permute_pd( yv[0], 5 ); + xv[1] = _mm256_permute_pd( yv[1], 5 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = yI1.bI, yR1.bI, yI2.bI, yR2.bI + iv[0] = _mm256_mul_pd( betaIv, xv[0] ); + iv[1] = _mm256_mul_pd( betaIv, xv[1] ); + + // Using fmaddsub to scale with real component of beta + // and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmaddsub_pd( betaRv, yv[0], iv[0] ); + yv[1] = _mm256_fmaddsub_pd( betaRv, yv[1], iv[1] ); + + // Storing the result to memory + _mm256_storeu_pd( ( y0 ), yv[0] ); + _mm256_storeu_pd( ( y0 + 1 * n_elem_per_reg ), yv[1] ); + + // Adjusting the pointers for the next iteration + y0 += 2 * n_elem_per_reg; + x0 += 2 * n_elem_per_reg; + } - y0 += 4*n_elem_per_reg; - x0 += 4*n_elem_per_reg; + // Processing 2 elements per loop, 3 FMAs + for ( ; ( i + 1 ) < n; i += 2 ) + { + // Load the y vector, 2 elements in total + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 ); + + // Permute the loaded vectors for the required compute + // xv = yI1 yR1 yI2 yR2 + xv[0] = _mm256_permute_pd( yv[0], 5 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = yI1 yR1 yI2 yR2 + iv[0] = _mm256_mul_pd( betaIv, xv[0] ); + + // Using fmaddsub to scale with real component of beta + // and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmaddsub_pd( betaRv, yv[0], iv[0] ); + + // Storing the result to memory + _mm256_storeu_pd( ( y0 ), yv[0] ); + + // Adjusting the pointers for the next iteration + y0 += 1 * n_elem_per_reg; + x0 += 1 * n_elem_per_reg; + } } - // Processing 6 elements per loop, 6 FMAs - for ( ; ( i + 5 ) < n; i += 6 ) + else { - // xv = xR1 xI1 xR2 xI2 - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - - // yv = yR1 yI1 yR2 yI2 - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - - // iv = betaRv * yv - // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... 
- iv[0] = _mm256_mul_pd( betaRv, yv[0] ); - iv[1] = _mm256_mul_pd( betaRv, yv[1] ); - iv[2] = _mm256_mul_pd( betaRv, yv[2] ); + // Scratch registers for storing real and imaginary components of alpha + __m256d alphaRv, alphaIv; - // yv' = yI1 yR1 yI2 yR2 - yv[0] = _mm256_permute_pd( yv[0], 5); - yv[1] = _mm256_permute_pd( yv[1], 5); - yv[2] = _mm256_permute_pd( yv[2], 5); + iv[0] = _mm256_setzero_pd(); - // yv = betaIv * yv' + iv - // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... - yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); - yv[1] = _mm256_fmadd_pd( betaIv, yv[1], iv[1] ); - yv[2] = _mm256_fmadd_pd( betaIv, yv[2], iv[2] ); + alphaRv = _mm256_broadcast_sd( &alphaR ); + alphaIv = _mm256_broadcast_sd( &alphaI ); - // iv = alphaRv * xv - // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... - iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); - iv[1] = _mm256_mul_pd( alphaRv, xv[1] ); - iv[2] = _mm256_mul_pd( alphaRv, xv[2] ); + // The changes on alphaRv and alphaIv are as follows : + // If conjugate is required: + // alphaRv = aR -aR aR -aR + // Else : + // alphaIv = -aI aI -aI aI + if( bli_is_conj( conjx_use ) ) + { + alphaRv = _mm256_fmsubadd_pd( iv[0], iv[0], alphaRv ); + } + else + { + alphaIv = _mm256_addsub_pd( iv[0], alphaIv ); + } - // xv' = xI1 xR1 xI2 xR2 - xv[0] = _mm256_permute_pd( xv[0], 5); - xv[1] = _mm256_permute_pd( xv[1], 5); - xv[2] = _mm256_permute_pd( xv[2], 5); + // Processing 8 elements per loop, 8 FMAs + for ( i = 0; ( i + 7 ) < n; i += 8 ) + { + // Load the y vector, 6 elements in total + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 ); + yv[1] = _mm256_loadu_pd( y0 + 1 * n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2 * n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3 * n_elem_per_reg ); + + // Load the x vector, 6 elements in total + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 ); + xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg ); + + // Permute the vectors from y for the required compute + // iv = yI1 yR1 yI2 yR2 + iv[0] = _mm256_permute_pd( yv[0], 5 ); + iv[1] = _mm256_permute_pd( yv[1], 5 ); + iv[2] = _mm256_permute_pd( yv[2], 5 ); + iv[3] = _mm256_permute_pd( yv[3], 5 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ... + iv[0] = _mm256_mul_pd( betaIv, iv[0] ); + iv[1] = _mm256_mul_pd( betaIv, iv[1] ); + iv[2] = _mm256_mul_pd( betaIv, iv[2] ); + iv[3] = _mm256_mul_pd( betaIv, iv[3] ); + + // Using fmaddsub to scale with real component of beta + // and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmaddsub_pd( betaRv, yv[0], iv[0] ); + yv[1] = _mm256_fmaddsub_pd( betaRv, yv[1], iv[1] ); + yv[2] = _mm256_fmaddsub_pd( betaRv, yv[2], iv[2] ); + yv[3] = _mm256_fmaddsub_pd( betaRv, yv[3], iv[3] ); + + // Permute the loaded vectors from x for the required compute + // xv' = xI1 xR1 xI2 xR2 + iv[0] = _mm256_permute_pd( xv[0], 5 ); + iv[1] = _mm256_permute_pd( xv[1], 5 ); + iv[2] = _mm256_permute_pd( xv[2], 5 ); + iv[3] = _mm256_permute_pd( xv[3], 5 ); + + // yv = alphaRv * xv + yv + // = yR1.bR - yR1.bI + xR1.aR, yI1.bR + yI1.bI + xI1.aR, ... 
+ yv[0] = _mm256_fmadd_pd( alphaRv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( alphaRv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_pd( alphaRv, xv[2], yv[2] ); + yv[3] = _mm256_fmadd_pd( alphaRv, xv[3], yv[3] ); + + // yv = alphaIv * iv + yv + // = yR1.bR - yR1.bI - xI1.aI, yI1.bR + yI1.bI + xR1.aI, ... + yv[0] = _mm256_fmadd_pd( alphaIv, iv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( alphaIv, iv[1], yv[1] ); + yv[2] = _mm256_fmadd_pd( alphaIv, iv[2], yv[2] ); + yv[3] = _mm256_fmadd_pd( alphaIv, iv[3], yv[3] ); + + // Storing the result to memory + _mm256_storeu_pd( ( y0 ), yv[0] ); + _mm256_storeu_pd( ( y0 + 1 * n_elem_per_reg ), yv[1] ); + _mm256_storeu_pd( ( y0 + 2 * n_elem_per_reg ), yv[2] ); + _mm256_storeu_pd( ( y0 + 3 * n_elem_per_reg ), yv[3] ); + + // Adjusting the pointers for the next iteration + y0 += 4 * n_elem_per_reg; + x0 += 4 * n_elem_per_reg; + } - // yv = alphaIv * xv + yv - // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... - iv[0] = _mm256_fmadd_pd( alphaIv, xv[0], iv[0] ); - iv[1] = _mm256_fmadd_pd( alphaIv, xv[1], iv[1] ); - iv[2] = _mm256_fmadd_pd( alphaIv, xv[2], iv[2] ); + // Processing 6 elements per loop, 6 FMAs + for ( ; ( i + 5 ) < n; i += 6 ) + { + // Load the y vector, 6 elements in total + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 ); + yv[1] = _mm256_loadu_pd( y0 + 1 * n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2 * n_elem_per_reg ); + + // Load the x vector, 6 elements in total + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 ); + xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg ); + + // Permute the vectors from y for the required compute + // iv = yI1 yR1 yI2 yR2 + iv[0] = _mm256_permute_pd( yv[0], 5 ); + iv[1] = _mm256_permute_pd( yv[1], 5 ); + iv[2] = _mm256_permute_pd( yv[2], 5 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ...` + iv[0] = _mm256_mul_pd( betaIv, iv[0] ); + iv[1] = _mm256_mul_pd( betaIv, iv[1] ); + iv[2] = _mm256_mul_pd( betaIv, iv[2] ); + + // Using fmaddsub to scale with real component of beta + // and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmaddsub_pd( betaRv, yv[0], iv[0] ); + yv[1] = _mm256_fmaddsub_pd( betaRv, yv[1], iv[1] ); + yv[2] = _mm256_fmaddsub_pd( betaRv, yv[2], iv[2] ); + + // Permute the loaded vectors from x for the required compute + // xv' = xI1 xR1 xI2 xR2 + iv[0] = _mm256_permute_pd( xv[0], 5 ); + iv[1] = _mm256_permute_pd( xv[1], 5 ); + iv[2] = _mm256_permute_pd( xv[2], 5 ); + + // yv = alphaRv * xv + yv + // = yR1.bR - yR1.bI + xR1.aR, yI1.bR + yI1.bI + xI1.aR, ... + yv[0] = _mm256_fmadd_pd( alphaRv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( alphaRv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_pd( alphaRv, xv[2], yv[2] ); + + // yv = alphaIv * iv + yv + // = yR1.bR - yR1.bI - xI1.aI, yI1.bR + yI1.bI + xR1.aI, ... 
+ yv[0] = _mm256_fmadd_pd( alphaIv, iv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( alphaIv, iv[1], yv[1] ); + yv[2] = _mm256_fmadd_pd( alphaIv, iv[2], yv[2] ); + + // Storing the result to memory + _mm256_storeu_pd( ( y0 ), yv[0] ); + _mm256_storeu_pd( ( y0 + 1 * n_elem_per_reg ), yv[1] ); + _mm256_storeu_pd( ( y0 + 2 * n_elem_per_reg ), yv[2] ); + + // Adjusting the pointers for the next iteration + y0 += 3 * n_elem_per_reg; + x0 += 3 * n_elem_per_reg; + } - yv[0] = _mm256_add_pd( yv[0], iv[0] ); - yv[1] = _mm256_add_pd( yv[1], iv[1] ); - yv[2] = _mm256_add_pd( yv[2], iv[2] ); + // Processing 4 elements per loop, 4 FMAs + for ( ; ( i + 3 ) < n; i += 4 ) + { + // Load the y vector, 6 elements in total + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 ); + yv[1] = _mm256_loadu_pd( y0 + 1 * n_elem_per_reg ); + + // Load the x vector, 6 elements in total + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 ); + xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg ); + + // Permute the vectors from y for the required compute + // iv = yI1 yR1 yI2 yR2 + iv[0] = _mm256_permute_pd( yv[0], 5 ); + iv[1] = _mm256_permute_pd( yv[1], 5 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ... + iv[0] = _mm256_mul_pd( betaIv, iv[0] ); + iv[1] = _mm256_mul_pd( betaIv, iv[1] ); + + // Using fmaddsub to scale with real component of beta + // and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmaddsub_pd( betaRv, yv[0], iv[0] ); + yv[1] = _mm256_fmaddsub_pd( betaRv, yv[1], iv[1] ); + + // Permute the loaded vectors from x for the required compute + // xv' = xI1 xR1 xI2 xR2 + iv[0] = _mm256_permute_pd( xv[0], 5 ); + iv[1] = _mm256_permute_pd( xv[1], 5 ); + + // yv = alphaRv * xv + yv + // = yR1.bR - yR1.bI + xR1.aR, yI1.bR + yI1.bI + xI1.aR, ... + yv[0] = _mm256_fmadd_pd( alphaRv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( alphaRv, xv[1], yv[1] ); + + // yv = alphaIv * iv + yv + // = yR1.bR - yR1.bI - xI1.aI, yI1.bR + yI1.bI + xR1.aI, ... + yv[0] = _mm256_fmadd_pd( alphaIv, iv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( alphaIv, iv[1], yv[1] ); + + // Storing the result to memory + _mm256_storeu_pd( ( y0 ), yv[0] ); + _mm256_storeu_pd( ( y0 + 1 * n_elem_per_reg ), yv[1] ); + + // Adjusting the pointers for the next iteration + y0 += 2 * n_elem_per_reg; + x0 += 2 * n_elem_per_reg; + } - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); - _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] ); - _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2] ); + // Processing 2 elements per loop, 3 FMAs + for ( ; ( i + 1 ) < n; i += 2 ) + { + // Load the y vector, 6 elements in total + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 ); + + // Load the x vector, 6 elements in total + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 ); + + // Permute the vectors from y for the required compute + // iv = yI1 yR1 yI2 yR2 + iv[0] = _mm256_permute_pd( yv[0], 5 ); + + // Scale the permuted vectors with imaginary component of beta + // iv = betaIv * yv + // = yI1.bI, yR1.bI, yI2.bI, yR2.bI, ... + iv[0] = _mm256_mul_pd( betaIv, iv[0] ); + + // Using fmaddsub to scale with real component of beta + // and sub/add to iv + // yv = betaRv * yv -/+ iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... 
+ yv[0] = _mm256_fmaddsub_pd( betaRv, yv[0], iv[0] ); + + // Permute the loaded vectors from x for the required compute + // xv' = xI1 xR1 xI2 xR2 + iv[0] = _mm256_permute_pd( xv[0], 5 ); + + // yv = alphaRv * xv + yv + // = yR1.bR - yR1.bI + xR1.aR, yI1.bR + yI1.bI + xI1.aR, ... + yv[0] = _mm256_fmadd_pd( alphaRv, xv[0], yv[0] ); + + // yv = alphaIv * iv + yv + // = yR1.bR - yR1.bI - xI1.aI, yI1.bR + yI1.bI + xR1.aI, ... + yv[0] = _mm256_fmadd_pd( alphaIv, iv[0], yv[0] ); + + // Storing the result to memory + _mm256_storeu_pd( ( y0 ), yv[0] ); + + // Adjusting the pointers for the next iteration + y0 += 1 * n_elem_per_reg; + x0 += 1 * n_elem_per_reg; + } - y0 += 3*n_elem_per_reg; - x0 += 3*n_elem_per_reg; } - // Processing 4 elements per loop, 4 FMAs - for ( ; ( i + 3 ) < n; i += 4 ) - { - // xv = xR1 xI1 xR2 xI2 - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - - // yv = yR1 yI1 yR2 yI2 - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - - // iv = betaRv * yv - // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... - iv[0] = _mm256_mul_pd( betaRv, yv[0] ); - iv[1] = _mm256_mul_pd( betaRv, yv[1] ); - - // yv' = yI1 yR1 yI2 yR2 - yv[0] = _mm256_permute_pd( yv[0], 5); - yv[1] = _mm256_permute_pd( yv[1], 5); - - // yv = betaIv * yv' + iv - // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... - yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); - yv[1] = _mm256_fmadd_pd( betaIv, yv[1], iv[1] ); - - // iv = alphaRv * xv - // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... - iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); - iv[1] = _mm256_mul_pd( alphaRv, xv[1] ); + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from AVX to SSE instructions. + _mm256_zeroupper(); - // xv' = xI1 xR1 xI2 xR2 - xv[0] = _mm256_permute_pd( xv[0], 5); - xv[1] = _mm256_permute_pd( xv[1], 5); + } - // yv = alphaIv * xv + yv - // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... - iv[0] = _mm256_fmadd_pd( alphaIv, xv[0], iv[0] ); - iv[1] = _mm256_fmadd_pd( alphaIv, xv[1], iv[1] ); + // Scratch registers to be used in case of non-unit strides or fringe case of 1. + __m128d x_elem, y_elem, x_perm, y_perm; + __m128d betaRv_128, betaIv_128; - yv[0] = _mm256_add_pd( yv[0], iv[0] ); - yv[1] = _mm256_add_pd( yv[1], iv[1] ); + // Casting the lower 128-bit lanes from betaRv and betaIv to its 128-bit alternative + // registers to avoid redundant broadcasts. + betaRv_128 = _mm256_castpd256_pd128( betaRv ); + betaIv_128 = _mm256_castpd256_pd128( betaIv ); - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); - _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] ); + // NOTE : We cannot similarly use _mm256_castpd256_pd128 to avoid loading alpha + // since alpha is loaded onto its YMM rgeisters on requirement basis. + // In case of directly falling to this compute(non-unit stride cases), + // alpha wouldn't have been loaded onto any YMM reigsters. 
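All of the unit-stride loops above use the same three-step pattern for the complex scalings beta*y and alpha*x: permute the register so real and imaginary parts swap within each pair, multiply the permuted copy by the broadcast imaginary component, then fold in the real component with fmaddsub, which subtracts in even lanes and adds in odd lanes and so yields ( yR.bR - yI.bI, yI.bR + yR.bI ) in one instruction. A minimal standalone sketch of one such beta*y update on a single YMM register (two double-complex elements; AVX2/FMA support is assumed and the function name zscal_beta_demo is illustrative):

#include <immintrin.h>

// Compute y := beta * y for two double-complex elements laid out as
// { yR0, yI0, yR1, yI1 }, using the permute + mul + fmaddsub scheme.
static inline __m256d zscal_beta_demo( __m256d y, double betaR, double betaI )
{
    __m256d bRv = _mm256_set1_pd( betaR );          // bR bR bR bR
    __m256d bIv = _mm256_set1_pd( betaI );          // bI bI bI bI

    // Swap real and imaginary parts within each pair: { yI0, yR0, yI1, yR1 }
    __m256d y_perm = _mm256_permute_pd( y, 0x5 );

    // { yI0.bI, yR0.bI, yI1.bI, yR1.bI }
    __m256d t = _mm256_mul_pd( bIv, y_perm );

    // Even lanes: yR.bR - yI.bI (real part); odd lanes: yI.bR + yR.bI (imag part)
    return _mm256_fmaddsub_pd( bRv, y, t );
}

The kernel interleaves four such updates per iteration (yv[0] through yv[3]) to keep the FMA units busy; the per-register arithmetic is identical.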
- y0 += 2*n_elem_per_reg; - x0 += 2*n_elem_per_reg; - } + // Changing betaIv_128 to { -bI bI } for the compute + x_elem = _mm_setzero_pd(); + betaIv_128 = _mm_addsub_pd( x_elem, betaIv_128 ); - // Processing 2 elements per loop, 3 FMAs - for ( ; ( i + 1 ) < n; i += 2 ) + // In case of alpha being 0, we just need to scale y by beta + if ( is_alpha_zero ) + { + // Iterate over y, one element at a time + for ( ; i < n; i += 1 ) { - // xv = xR1 xI1 xR2 xI2 - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - - // yv = yR1 yI1 yR2 yI2 - yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - - // iv = betaRv * yv - // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... - iv[0] = _mm256_mul_pd( betaRv, yv[0] ); - - // yv' = yI1 yR1 yI2 yR2 - yv[0] = _mm256_permute_pd( yv[0], 5); - - // yv = betaIv * yv' + iv - // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... - yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); - - // iv = alphaRv * xv - // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... - iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); - - // xv' = xI1 xR1 xI2 xR2 - xv[0] = _mm256_permute_pd( xv[0], 5); - - // yv = alphaIv * xv + yv - // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... - iv[0] = _mm256_fmadd_pd( alphaIv, xv[0], iv[0] ); + // Load an element from y + // y_elem = yR1 yI1 + y_elem = _mm_loadu_pd( y0 ); - yv[0] = _mm256_add_pd( yv[0], iv[0] ); - - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); - - y0 += 1*n_elem_per_reg; - x0 += 1*n_elem_per_reg; - } - - // Issue vzeroupper instruction to clear upper lanes of ymm registers. - // This avoids a performance penalty caused by false dependencies when - // transitioning from AVX to SSE instructions (which may occur as soon - // as the n_left cleanup loop below if BLIS is compiled with - // -mfpmath=sse). - _mm256_zeroupper(); - - if ( !bli_is_conj( conjx_use ) ) - { - for ( ; i < n ; ++i ) - { - const double yRc = *y0; - const double yIc = *( y0 + 1 ); + // Permute y in accordance to its compute + // y_perm = yI1 yR1 + y_perm = _mm_permute_pd( y_elem, 0x1 ); - // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) - *y0 = ( betaR * yRc ) - ( betaI * yIc ) + - ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); - // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) - *(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) + - ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + // Scale y_perm by the imaginary + // component of beta + // y_perm = -yI1.bI, yR1.bI + y_perm = _mm_mul_pd( betaIv_128, y_perm ); - x0 += 2; - y0 += 2; - } - } - else - { - for ( ; i < n ; ++i ) - { - const double yRc = *y0; - const double yIc = *( y0 + 1 ); + // Use fmadd to scale with real component of + // beta and add with intermediate result + // y_elem = yR1.bR - yI1.bI, yI1.bR + yR1.bI + y_elem = _mm_fmadd_pd( betaRv_128, y_elem, y_perm ); - // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) - *y0 = ( betaR * yRc ) - ( betaI * yIc ) + - ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); - // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) - *(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) - - ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + // Storing the result to memory + _mm_storeu_pd( y0, y_elem ); - x0 += 2; - y0 += 2; - } + // Adjusting the pointer for the next iteration + y0 += incy * 2; } } else { - // for non-unit increments, use scaler code - if ( !bli_is_conj( conjx_use ) ) + // Scratch registers to store real and imaginary components + // of alpha onto XMM registers + __m128d alphaRv_128, alphaIv_128; + + // Broadcasting real and imaginary components of alpha + x_elem = _mm_setzero_pd(); + alphaRv_128 = _mm_loaddup_pd( &alphaR 
); + alphaIv_128 = _mm_loaddup_pd( &alphaI ); + + // The changes on alphaRv_128 and alphaIv_128 are as follows : + // If conjugate is required: + // alphaRv_128 = aR -aR + // Else : + // alphaIv_128 = -aI aI + if( bli_is_conj( conjx_use ) ) { - for ( i = 0; i < n ; ++i ) - { - const double yRc = *y0; - const double yIc = *( y0 + 1 ); - - // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) - *y0 = ( betaR * yRc ) - ( betaI * yIc ) + - ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); - // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) - *(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) + - ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); - - x0 += incx * 2; - y0 += incy * 2; - } + alphaRv_128 = _mm_addsub_pd( x_elem, alphaRv_128 ); + alphaRv_128 = _mm_permute_pd( alphaRv_128, 0x1 ); } else { - for ( i = 0; i < n ; ++i ) - { - const double yRc = *y0; - const double yIc = *( y0 + 1 ); - - // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) - *y0 = ( betaR * yRc ) - ( betaI * yIc ) + - ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); - // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) - *(y0 + 1) = ( betaR * yIc ) + ( betaI * yRc ) - - ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + alphaIv_128 = _mm_addsub_pd( x_elem, alphaIv_128 ); + } - x0 += incx * 2; - y0 += incy * 2; - } + // Iterating over x and y vectors, on element at a time + for ( ; i < n; i += 1 ) + { + // Load an element from x and y + // y_elem = yR1 yI1 + // x_elem = xR1 xI1 + y_elem = _mm_loadu_pd( y0 ); + x_elem = _mm_loadu_pd( x0 ); + + // Permute y in accordance to its compute + // y_perm = yI1 yR1 + // x_perm = xR1 xI1 + y_perm = _mm_permute_pd( y_elem, 0x1 ); + x_perm = _mm_permute_pd( x_elem, 0x1 ); + + // Scale y_perm and x_perm by the imaginary + // component of beta and alpha + // y_perm = -yI1.bI, yR1.bI + // x_perm = -xI1.aI, xR1.aI + y_perm = _mm_mul_pd( betaIv_128, y_perm ); + x_perm = _mm_mul_pd( alphaIv_128, x_perm ); + + // Use fmadd to scale with y_elem with + // real component of beta and add with + // intermediate result. Similarly do + // for x_elem. + // y_elem = yR1.bR - yI1.bI, yI1.bR + yR1.bI + // x_elem = xR1.aR - xI1.aI, xI1.aR + xR1.aI + y_elem = _mm_fmadd_pd( betaRv_128, y_elem, y_perm ); + x_elem = _mm_fmadd_pd( alphaRv_128, x_elem, x_perm ); + + // Add the computed x and y vectors, store on y. + y_elem = _mm_add_pd( y_elem, x_elem ); + + // Storing the result to memory + _mm_storeu_pd( y0, y_elem ); + + // Adjusting the pointer for the next iteration + x0 += incx * 2; + y0 += incy * 2; } } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) } diff --git a/kernels/zen/1/bli_axpbyv_zen_int10.c b/kernels/zen/1/bli_axpbyv_zen_int10.c index 787f325ba3..02abdb4f2a 100644 --- a/kernels/zen/1/bli_axpbyv_zen_int10.c +++ b/kernels/zen/1/bli_axpbyv_zen_int10.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/1/bli_axpyv_zen_int.c b/kernels/zen/1/bli_axpyv_zen_int.c index 2b1a738da7..4b51f5a5fa 100644 --- a/kernels/zen/1/bli_axpyv_zen_int.c +++ b/kernels/zen/1/bli_axpyv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ Copyright (C) 2016 - 2023, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without diff --git a/kernels/zen/1/bli_copyv_zen_int.c b/kernels/zen/1/bli_copyv_zen_int.c index 9ffde188e8..d940cefc52 100644 --- a/kernels/zen/1/bli_copyv_zen_int.c +++ b/kernels/zen/1/bli_copyv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/1/bli_dotv_zen_int.c b/kernels/zen/1/bli_dotv_zen_int.c index 145b8fe6a5..fa4a42856f 100644 --- a/kernels/zen/1/bli_dotv_zen_int.c +++ b/kernels/zen/1/bli_dotv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2016 - 2023, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without diff --git a/kernels/zen/1/bli_dotv_zen_int10.c b/kernels/zen/1/bli_dotv_zen_int10.c index c239612006..663969a50c 100644 --- a/kernels/zen/1/bli_dotv_zen_int10.c +++ b/kernels/zen/1/bli_dotv_zen_int10.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2016 - 2023, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -53,6 +53,17 @@ typedef union double d[4] __attribute__((aligned(64))); } v4df_t; + +//Loads lower 3 64-bit double precision elements into ymm register +static int64_t mask_3[4] = {-1, -1, -1, 0}; +//Loads lower 2 64-bit double precision elements into ymm register +static int64_t mask_2[4] = {-1, -1, 0, 0}; +//Loads lower 1 64-bit double precision elements into ymm register +static int64_t mask_1[4] = {-1, 0, 0, 0}; +//Loads 4 64-bit double precision elements into ymm register +static int64_t mask_0[4] = {0, 0, 0, 0}; + +static int64_t *mask_ptr[] = {mask_0, mask_1, mask_2, mask_3}; // ----------------------------------------------------------------------------- void bli_sdotv_zen_int10 @@ -421,12 +432,15 @@ void bli_ddotv_zen_int10 y0 += 1*n_elem_per_reg; } - for ( ; (i + 0) < n; i += 1 ) + if(i < n) { - rho0 += (*x0) * (*y0); + __m256i maskVec = _mm256_loadu_si256( (__m256i *)mask_ptr[(n - i)]); - x0 += 1; - y0 += 1; + xv[0] = _mm256_maskload_pd( x0, maskVec ); + yv[0] = _mm256_maskload_pd( y0, maskVec ); + + rhov[0].v = _mm256_fmadd_pd( xv[0], yv[0], rhov[0].v ); + i = n; } // Manually add the results from above to finish the sum. diff --git a/kernels/zen/1/bli_dotxv_zen_int.c b/kernels/zen/1/bli_dotxv_zen_int.c index a0ddaaf549..2b96fa4d6e 100644 --- a/kernels/zen/1/bli_dotxv_zen_int.c +++ b/kernels/zen/1/bli_dotxv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2016 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
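In the ddotv change above, the leftover count n - i (at most three elements for this loop structure) indexes into mask_ptr, and _mm256_maskload_pd loads only the selected lanes while zeroing the rest, so a single FMA finishes the accumulation in place of a scalar tail loop. A self-contained sketch of the same remainder technique, assuming AVX2/FMA support (the names dot_masks and ddot_masked_demo are illustrative, not taken from the kernel):

#include <immintrin.h>
#include <stdint.h>

// Masks selecting the lowest 0..3 double-precision lanes of a YMM register.
// A lane is loaded only when the corresponding 64-bit entry has its high bit set.
static const int64_t dot_masks[4][4] =
{
    {  0,  0,  0,  0 },   // 0 leftover elements
    { -1,  0,  0,  0 },   // 1 leftover element
    { -1, -1,  0,  0 },   // 2 leftover elements
    { -1, -1, -1,  0 },   // 3 leftover elements
};

// Dot product of two unit-stride double vectors; the fringe of n % 4
// elements is handled with one masked load instead of a scalar loop.
static double ddot_masked_demo( const double *x, const double *y, int n )
{
    __m256d acc = _mm256_setzero_pd();
    int i = 0;

    for ( ; i + 3 < n; i += 4 )
    {
        __m256d xv = _mm256_loadu_pd( x + i );
        __m256d yv = _mm256_loadu_pd( y + i );
        acc = _mm256_fmadd_pd( xv, yv, acc );
    }

    if ( i < n )
    {
        __m256i m  = _mm256_loadu_si256( (const __m256i *)dot_masks[n - i] );
        __m256d xv = _mm256_maskload_pd( x + i, m );  // masked-out lanes read 0.0
        __m256d yv = _mm256_maskload_pd( y + i, m );
        acc = _mm256_fmadd_pd( xv, yv, acc );
    }

    // Horizontal sum of the four accumulator lanes.
    __m128d lo = _mm256_castpd256_pd128( acc );
    __m128d hi = _mm256_extractf128_pd( acc, 1 );
    __m128d s  = _mm_add_pd( lo, hi );
    s = _mm_add_sd( s, _mm_unpackhi_pd( s, s ) );
    return _mm_cvtsd_f64( s );
}

Because the masked-out lanes come back as zero, they contribute nothing to the FMA, which is what makes the single-iteration fringe safe.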
Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without diff --git a/kernels/zen/1/bli_norm2_zen_int.c b/kernels/zen/1/bli_norm2_zen_int.c index b388dfb754..aa13f72061 100644 --- a/kernels/zen/1/bli_norm2_zen_int.c +++ b/kernels/zen/1/bli_norm2_zen_int.c @@ -50,6 +50,14 @@ typedef union double d[4] __attribute__( ( aligned( 64 ) ) ); } v4df_t; +// Union data structure to access SSE registers +// One 128-bit AVX register holds 2 DP elements. +typedef union +{ + __m128d v; + double d[2] __attribute__( ( aligned( 64 ) ) ); +} v2df_t; + // Return a mask which indicates either: // v <= t or v >= T #define CMP256_sf( v, t, T ) \ @@ -58,6 +66,9 @@ typedef union #define CMP256_df( v, t, T ) \ _mm256_or_pd( _mm256_cmp_pd( v, t, _CMP_LE_OS ), _mm256_cmp_pd( v, T, _CMP_GE_OS ) ); +#define CMP128_df( v, t, T ) \ + _mm_or_pd( _mm_cmp_pd( v, t, _CMP_LE_OS ), _mm_cmp_pd( v, T, _CMP_GE_OS ) ); + // Returns true if any of the values in the mask vector a is true, // and false, otherwise. // In more detail, __mm256_testz_ps() performs the bitwise (a AND b) operation and returns: @@ -75,6 +86,7 @@ typedef union // 1 (true) if the mask is true for at least one element in a. static inline bool bli_horizontal_or_sf( __m256 a ) { return ! _mm256_testz_ps( a, a ); } static inline bool bli_horizontal_or_df( __m256d a ) { return ! _mm256_testz_pd( a, a ); } +static inline bool bli_horizontal_or_df_128( __m128d a ) { return ! _mm_testz_pd( a, a ); } float horizontal_add_sf(__m256 const a) { __m256 t1 = _mm256_hadd_ps(a, a); @@ -97,57 +109,8 @@ void bli_snorm2fv_unb_var1_avx2 float sumsq = 0.0f; dim_t i = 0; - dim_t n_remainder = 0; - float *x_buf = x; - - // Memory pool declarations for packing vector X. - // Initialize mem pool buffer to NULL and size to 0. - // "buf" and "size" fields are assigned once memory - // is allocated from the pool in bli_membrk_acquire_m(). - // This will ensure bli_mem_is_alloc() will be passed on - // an allocated memory if created or a NULL. - mem_t mem_bufX = {0}; - rntm_t rntm; - - // Packing for non-unit strided vector x. - if ( incx != 1 ) - { - // In order to get the buffer from pool via rntm access to memory broker - //is needed. Following are initializations for rntm. - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - // Calculate the size required for "n" float elements in vector x. - size_t buffer_size = n * sizeof( float ); - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_snorm2fv_unb_var1_avx2(): get mem pool block\n" ); - #endif - - // Acquire a Buffer(n*size(float)) from the memory broker - // and save the associated mem_t entry to mem_bufX. - bli_membrk_acquire_m - ( - &rntm, - buffer_size, - BLIS_BUFFER_FOR_B_PANEL, - &mem_bufX - ); - - // Continue packing X if buffer memory is allocated. - if ( ( bli_mem_is_alloc( &mem_bufX ) ) ) - { - x_buf = bli_mem_buffer( &mem_bufX ); - // Pack vector x with non-unit stride to a temp buffer x_buf with unit stride. - for ( dim_t x_index = 0; x_index < n; x_index++ ) - { - *( x_buf + x_index ) = *( x + ( x_index * incx ) ); - } - } - } - - float *xt = x_buf; + float *xt = x; // Compute the sum of squares on 3 accumulators to avoid overflow // and underflow, depending on the vector element value. @@ -168,7 +131,7 @@ void bli_snorm2fv_unb_var1_avx2 float abs_chi; bool isbig = false; - if ( n >= 64 ) + if ( ( n >= 64 ) && ( incx == 1 ) ) { // Constants used for comparisons. 
v8sf_t temp, thres_sml_vec, thres_big_vec, zerov; @@ -217,62 +180,11 @@ void bli_snorm2fv_unb_var1_avx2 mask_vec1.v = _mm256_cmp_ps(x1v.v, x1v.v, _CMP_UNORD_Q); mask_vec2.v = _mm256_cmp_ps(x2v.v, x2v.v, _CMP_UNORD_Q); mask_vec3.v = _mm256_cmp_ps(x3v.v, x3v.v, _CMP_UNORD_Q); - if ( bli_horizontal_or_sf( mask_vec0.v ) ) - { - *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_snorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - if ( bli_horizontal_or_sf( mask_vec1.v ) ) + if ( bli_horizontal_or_sf( mask_vec0.v ) || bli_horizontal_or_sf( mask_vec1.v ) + || bli_horizontal_or_sf( mask_vec2.v ) || bli_horizontal_or_sf( mask_vec3.v ) ) { *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_snorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - if ( bli_horizontal_or_sf( mask_vec2.v ) ) - { - *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_snorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - if ( bli_horizontal_or_sf( mask_vec3.v ) ) - { - *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_snorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; @@ -480,47 +392,10 @@ void bli_snorm2fv_unb_var1_avx2 mask_vec0.v = _mm256_cmp_ps(x0v.v, x0v.v, _CMP_UNORD_Q); mask_vec1.v = _mm256_cmp_ps(x1v.v, x1v.v, _CMP_UNORD_Q); mask_vec2.v = _mm256_cmp_ps(x2v.v, x2v.v, _CMP_UNORD_Q); - if ( bli_horizontal_or_sf( mask_vec0.v ) ) - { - *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_snorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - if ( bli_horizontal_or_sf( mask_vec1.v ) ) + if ( bli_horizontal_or_sf( mask_vec0.v ) || bli_horizontal_or_sf( mask_vec1.v ) + || bli_horizontal_or_sf( mask_vec2.v ) ) { *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_snorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - if ( bli_horizontal_or_sf( mask_vec2.v ) ) - { - *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_snorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; @@ -683,32 +558,9 @@ void bli_snorm2fv_unb_var1_avx2 // Check if any of the values is a NaN and if so, return. 
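// The NaN screen used throughout these norm2 kernels relies on a value
// comparing unordered with itself only when it is NaN: _CMP_UNORD_Q sets a
// lane to all-ones exactly for NaN inputs, and _mm256_testz_ps() (wrapped by
// bli_horizontal_or_sf above) reports whether any lane is set. A minimal
// sketch of that check, assuming immintrin.h and stdbool.h are in scope
// (the helper name any_nan_demo is illustrative, not part of the kernel):
static inline bool any_nan_demo( __m256 v )
{
    // Each lane is all-ones iff the corresponding element of v is NaN.
    __m256 m = _mm256_cmp_ps( v, v, _CMP_UNORD_Q );

    // testz returns 1 only when every lane of the mask is zero,
    // i.e. when no NaN was found.
    return ! _mm256_testz_ps( m, m );
}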
mask_vec0.v = _mm256_cmp_ps(x0v.v, x0v.v, _CMP_UNORD_Q); mask_vec1.v = _mm256_cmp_ps(x1v.v, x1v.v, _CMP_UNORD_Q); - if ( bli_horizontal_or_sf( mask_vec0.v ) ) + if ( bli_horizontal_or_sf( mask_vec0.v ) || bli_horizontal_or_sf( mask_vec1.v ) ) { *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_snorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - if ( bli_horizontal_or_sf( mask_vec1.v ) ) - { - *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_snorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; @@ -825,14 +677,6 @@ void bli_snorm2fv_unb_var1_avx2 if ( bli_horizontal_or_sf( mask_vec0.v ) ) { *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_snorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; @@ -906,75 +750,35 @@ void bli_snorm2fv_unb_var1_avx2 sum_big = horizontal_add_sf(sum_big_vec0.v); } - n_remainder = n - i; - bool hasInf = false; - - if ( ( n_remainder > 0 ) ) + // Put first the most likely to happen to avoid evaluations on if statements. + for ( ; i < n; i++) { - // Put first the most likely to happen to avoid evaluations on if statements. - for (i = 0; i < n_remainder; i++) + abs_chi = bli_fabs( *xt ); + // If any of the elements is NaN, then return NaN as a result. + if ( bli_isnan( abs_chi ) ) { - abs_chi = bli_fabs( *xt ); - // If any of the elements is NaN, then return NaN as a result. - if ( bli_isnan( abs_chi ) ) - { - *norm = abs_chi; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_snorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } + *norm = abs_chi; - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - // Else, if any of the elements is an Inf, then return +Inf as a result. - if ( bli_isinf( abs_chi ) ) - { - *norm = abs_chi; - // Instead of returning immediately, use this flag - // to denote that there is an Inf element in the vector. - // That is used to avoid cases where there is a NaN which comes - // after an Inf. - hasInf = true; - } - // Most likely case: medium values, not over/under-flow. - if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) - { - sum_med += abs_chi * abs_chi; - } - // Case where there could be an overflow. Scaling is required. - else if ( abs_chi > thres_big ) - { - sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); - isbig = true; - } - // Case where there could be an underflow. Scaling is required. - else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) - { - sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); - } - xt++; + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + return; } - } - // Early return if there is an Inf. - if ( hasInf ) - { - - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) + // Most likely case: medium values, not over/under-flow. 
+ if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_snorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); + sum_med += abs_chi * abs_chi; } - - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; + // Case where there could be an overflow. Scaling is required. + else if ( abs_chi > thres_big ) + { + sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); + isbig = true; + } + // Case where there could be an underflow. Scaling is required. + else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) + { + sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); + } + xt += incx; } // Combine accumulators. @@ -1024,15 +828,6 @@ void bli_snorm2fv_unb_var1_avx2 *norm = scale * sqrtf( sumsq ); - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_snorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; @@ -1051,57 +846,8 @@ void bli_scnorm2fv_unb_var1_avx2 float sumsq = 0.0f; dim_t i = 0; - dim_t n_remainder = 0; - scomplex *x_buf = x; - - // Memory pool declarations for packing vector X. - // Initialize mem pool buffer to NULL and size to 0. - // "buf" and "size" fields are assigned once memory - // is allocated from the pool in bli_membrk_acquire_m(). - // This will ensure bli_mem_is_alloc() will be passed on - // an allocated memory if created or a NULL. - mem_t mem_bufX = {0}; - rntm_t rntm; - - // Packing for non-unit strided vector x. - if ( incx != 1 ) - { - // In order to get the buffer from pool via rntm access to memory broker - //is needed. Following are initializations for rntm. - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - // Calculate the size required for "n" scomplex elements in vector x. - size_t buffer_size = n * sizeof( scomplex ); - - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): get mem pool block\n" ); - #endif - - // Acquire a Buffer(n*size(scomplex)) from the memory broker - // and save the associated mem_t entry to mem_bufX. - bli_membrk_acquire_m - ( - &rntm, - buffer_size, - BLIS_BUFFER_FOR_B_PANEL, - &mem_bufX - ); - - // Continue packing X if buffer memory is allocated. - if ( ( bli_mem_is_alloc( &mem_bufX ) ) ) - { - x_buf = bli_mem_buffer( &mem_bufX ); - // Pack vector x with non-unit stride to a temp buffer x_buf with unit stride. - for ( dim_t x_index = 0; x_index < n; x_index++ ) - { - *( x_buf + x_index ) = *( x + ( x_index * incx ) ); - } - } - } - - scomplex *xt = x_buf; + scomplex *xt = x; // Compute the sum of squares on 3 accumulators to avoid overflow // and underflow, depending on the vector element value. @@ -1122,7 +868,7 @@ void bli_scnorm2fv_unb_var1_avx2 float abs_chi; bool isbig = false; - if ( n >= 64 ) + if ( ( n >= 64 ) && ( incx == 1 ) ) { // Constants used for comparisons. 
v8sf_t temp, thres_sml_vec, thres_big_vec, zerov; @@ -1171,62 +917,10 @@ void bli_scnorm2fv_unb_var1_avx2 mask_vec1.v = _mm256_cmp_ps(x1v.v, x1v.v, _CMP_UNORD_Q); mask_vec2.v = _mm256_cmp_ps(x2v.v, x2v.v, _CMP_UNORD_Q); mask_vec3.v = _mm256_cmp_ps(x3v.v, x3v.v, _CMP_UNORD_Q); - if ( bli_horizontal_or_sf( mask_vec0.v ) ) + if ( bli_horizontal_or_sf( mask_vec0.v ) || bli_horizontal_or_sf( mask_vec1.v ) + || bli_horizontal_or_sf( mask_vec2.v ) || bli_horizontal_or_sf( mask_vec3.v ) ) { *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - if ( bli_horizontal_or_sf( mask_vec1.v ) ) - { - *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - if ( bli_horizontal_or_sf( mask_vec2.v ) ) - { - *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - if ( bli_horizontal_or_sf( mask_vec3.v ) ) - { - *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; @@ -1435,47 +1129,10 @@ void bli_scnorm2fv_unb_var1_avx2 mask_vec0.v = _mm256_cmp_ps(x0v.v, x0v.v, _CMP_UNORD_Q); mask_vec1.v = _mm256_cmp_ps(x1v.v, x1v.v, _CMP_UNORD_Q); mask_vec2.v = _mm256_cmp_ps(x2v.v, x2v.v, _CMP_UNORD_Q); - if ( bli_horizontal_or_sf( mask_vec0.v ) ) - { - *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - if ( bli_horizontal_or_sf( mask_vec1.v ) ) - { - *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - if ( bli_horizontal_or_sf( mask_vec2.v ) ) + if ( bli_horizontal_or_sf( mask_vec0.v ) || bli_horizontal_or_sf( mask_vec1.v ) + || bli_horizontal_or_sf( mask_vec2.v ) ) { *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. 
- bli_membrk_release( &rntm , &mem_bufX ); - } AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; @@ -1638,32 +1295,9 @@ void bli_scnorm2fv_unb_var1_avx2 // Check if any of the values is a NaN and if so, return. mask_vec0.v = _mm256_cmp_ps(x0v.v, x0v.v, _CMP_UNORD_Q); mask_vec1.v = _mm256_cmp_ps(x1v.v, x1v.v, _CMP_UNORD_Q); - if ( bli_horizontal_or_sf( mask_vec0.v ) ) + if ( bli_horizontal_or_sf( mask_vec0.v ) || bli_horizontal_or_sf( mask_vec1.v ) ) { *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - if ( bli_horizontal_or_sf( mask_vec1.v ) ) - { - *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; @@ -1779,14 +1413,6 @@ void bli_scnorm2fv_unb_var1_avx2 if ( bli_horizontal_or_sf( mask_vec0.v ) ) { *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; @@ -1860,122 +1486,66 @@ void bli_scnorm2fv_unb_var1_avx2 sum_big = horizontal_add_sf(sum_big_vec0.v); } - n_remainder = n - i; - bool hasInf = false; double chi_r, chi_i; - if ( ( n_remainder > 0 ) ) + // Put first the most likely to happen to avoid evaluations on if statements. + for ( ; i < n; i++) { - // Put first the most likely to happen to avoid evaluations on if statements. - for (i = 0; i < n_remainder; i++) + // Get real and imaginary component of the vector element. + bli_csgets(*xt, chi_r, chi_i); + // Start with accumulating the real component of the vector element. + abs_chi = bli_fabs( chi_r ); + // If any of the elements is NaN, then return NaN as a result. + if ( bli_isnan( abs_chi ) ) { - // Get real and imaginary component of the vector element. - bli_csgets(*xt, chi_r, chi_i); - // Start with accumulating the real component of the vector element. - abs_chi = bli_fabs( chi_r ); - // If any of the elements is NaN, then return NaN as a result. - if ( bli_isnan( abs_chi ) ) - { - *norm = abs_chi; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } + *norm = abs_chi; - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - // Else, if any of the elements is an Inf, then return +Inf as a result. - if ( bli_isinf( abs_chi ) ) - { - *norm = abs_chi; - // Instead of returning immediately, use this flag - // to denote that there is an Inf element in the vector. - // That is used to avoid cases where there is a NaN which comes - // after an Inf. - hasInf = true; - } - // Most likely case: medium values, not over/under-flow. - if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) - { - sum_med += abs_chi * abs_chi; - } - // Case where there could be an overflow. Scaling is required. 
- else if ( abs_chi > thres_big ) - { - sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); - isbig = true; - } - // Case where there could be an underflow. Scaling is required. - else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) - { - sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); - } - // Accumulate the imaginary component of the vector element. - abs_chi = bli_fabs( chi_i ); - // If any of the elements is NaN, then return NaN as a result. - if ( bli_isnan( abs_chi ) ) - { - *norm = abs_chi; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - // Else, if any of the elements is an Inf, then return +Inf as a result. - if ( bli_isinf( abs_chi ) ) - { - *norm = abs_chi; - // Instead of returning immediately, use this flag - // to denote that there is an Inf element in the vector. - // That is used to avoid cases where there is a NaN which comes - // after an Inf. - hasInf = true; - } - // Most likely case: medium values, not over/under-flow. - if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) - { - sum_med += abs_chi * abs_chi; - } - // Case where there could be an overflow. Scaling is required. - else if ( abs_chi > thres_big ) - { - sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); - isbig = true; - } - // Case where there could be an underflow. Scaling is required. - else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) - { - sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); - } + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + return; + } + // Most likely case: medium values, not over/under-flow. + if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) + { + sum_med += abs_chi * abs_chi; + } + // Case where there could be an overflow. Scaling is required. + else if ( abs_chi > thres_big ) + { + sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); + isbig = true; + } + // Case where there could be an underflow. Scaling is required. + else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) + { + sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); + } + // Accumulate the imaginary component of the vector element. + abs_chi = bli_fabs( chi_i ); + // If any of the elements is NaN, then return NaN as a result. + if ( bli_isnan( abs_chi ) ) + { + *norm = abs_chi; - xt++; + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + return; } - } - // Early return if there is an Inf. - if ( hasInf ) - { - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) + // Most likely case: medium values, not over/under-flow. + if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); + sum_med += abs_chi * abs_chi; + } + // Case where there could be an overflow. Scaling is required. + else if ( abs_chi > thres_big ) + { + sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); + isbig = true; + } + // Case where there could be an underflow. Scaling is required. + else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) + { + sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); } - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; + xt += incx; } // Combine accumulators. 
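For the complex kernels the same remainder loop runs twice per element, once on the real and once on the imaginary component, and a NaN now triggers an immediate return instead of the old deferred hasInf bookkeeping; an Inf simply falls into the big accumulator and propagates through the final square root. A compact sketch of that per-component accumulation is below, using a plain struct in place of BLIS's scomplex and a hypothetical helper name; the caller is assumed to set *norm = NAN when the helper reports a NaN.

#include <math.h>
#include <stdbool.h>
#include <stddef.h>

typedef struct { float real, imag; } scomplex_ish;   // stand-in for BLIS scomplex

// Illustrative only: accumulate |Re|^2 and |Im|^2 of each remaining element,
// returning false as soon as a NaN component is seen.
static bool accum_complex_tail_sketch( size_t n, const scomplex_ish *x, size_t incx,
                                       float thres_sml, float thres_big,
                                       float scale_sml, float scale_big,
                                       float *sum_sml, float *sum_med, float *sum_big,
                                       bool *isbig )
{
    for ( size_t i = 0; i < n; ++i, x += incx )
    {
        const float comp[2] = { x->real, x->imag };
        for ( int c = 0; c < 2; ++c )
        {
            float abs_chi = fabsf( comp[c] );
            if ( isnan( abs_chi ) ) return false;   // NaN short-circuits everything

            if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) )
                *sum_med += abs_chi * abs_chi;
            else if ( abs_chi > thres_big )
            {
                *sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big );
                *isbig = true;
            }
            else if ( !( *isbig ) && ( abs_chi < thres_sml ) )
                *sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml );
        }
    }
    return true;
}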
@@ -2025,15 +1595,6 @@ void bli_scnorm2fv_unb_var1_avx2 *norm = scale * sqrtf( sumsq ); - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_scnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; @@ -2051,64 +1612,14 @@ void bli_dnorm2fv_unb_var1_avx2 AOCL_DTL_TRACE_ENTRY( AOCL_DTL_LEVEL_TRACE_3 ); double sumsq = 0; - dim_t i = 0; - dim_t n_remainder = 0; - double *x_buf = x; - - // Memory pool declarations for packing vector X. - // Initialize mem pool buffer to NULL and size to 0. - // "buf" and "size" fields are assigned once memory - // is allocated from the pool in bli_membrk_acquire_m(). - // This will ensure bli_mem_is_alloc() will be passed on - // an allocated memory if created or a NULL. - mem_t mem_bufX = {0}; - rntm_t rntm; - - // Packing for non-unit strided vector x. - if ( incx != 1 ) - { - // In order to get the buffer from pool via rntm access to memory broker - //is needed. Following are initializations for rntm. - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - // Calculate the size required for "n" double elements in vector x. - size_t buffer_size = n * sizeof( double ); - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dnorm2fv_unb_var1(): get mem pool block\n" ); - #endif - - // Acquire a Buffer(n*size(double)) from the memory broker - // and save the associated mem_t entry to mem_bufX. - bli_membrk_acquire_m - ( - &rntm, - buffer_size, - BLIS_BUFFER_FOR_B_PANEL, - &mem_bufX - ); - - // Continue packing X if buffer memory is allocated. - if ( ( bli_mem_is_alloc( &mem_bufX ) ) ) - { - x_buf = bli_mem_buffer( &mem_bufX ); - // Pack vector x with non-unit stride to a temp buffer x_buf with unit stride. - for ( dim_t x_index = 0; x_index < n; x_index++ ) - { - *( x_buf + x_index ) = *( x + ( x_index * incx ) ); - } - } - } - - double *xt = x_buf; + double *xt = x; // Compute the sum of squares on 3 accumulators to avoid overflow // and underflow, depending on the vector element value. // Accumulator for small values; using scaling to avoid underflow. double sum_sml = 0; - // Accumulator for medium values; no scaling required. + // Accumulator for medium values; no scaling required. double sum_med = 0; // Accumulator for big values; using scaling to avoid overflow. double sum_big = 0; @@ -2120,21 +1631,21 @@ void bli_dnorm2fv_unb_var1_avx2 const double scale_big = pow( ( double )FLT_RADIX, - ceil( ( DBL_MAX_EXP + 52 ) * 0.5 ) ); double scale; - double abs_chi; bool isbig = false; - if ( n > 4 ) - { - // Constants used for comparisons. - v4df_t temp, thres_sml_vec, thres_big_vec, zerov, ymm0, ymm1; - temp.v = _mm256_set1_pd( -0.0 ); - thres_sml_vec.v = _mm256_set1_pd( thres_sml ); - thres_big_vec.v = _mm256_set1_pd( thres_big ); - v4df_t x0v, x1v, mask_vec0, mask_vec1; - zerov.v = _mm256_setzero_pd(); + dim_t i = 0; + if( incx == 1 ) + { + // AVX-2 code-section // Partial sums used for scaling. - v4df_t sum_med_vec0, sum_big_vec0, sum_sml_vec0, sum_med_vec1, sum_big_vec1, sum_sml_vec1; + v4df_t sum_med_vec0, sum_big_vec0, sum_sml_vec0; + v4df_t sum_med_vec1, sum_big_vec1, sum_sml_vec1; + + // Vectors used for comparisons and getting absolute values. 
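Two branch-free idioms do most of the work inside the vector loops: comparing a register with itself under _CMP_UNORD_Q raises exactly the lanes that hold a NaN, and clearing the sign bit with _mm256_andnot_pd(-0.0, x) yields |x|. The sketch below shows both on a single __m256d; the any_lane_set_pd helper is only a movemask-based stand-in for BLIS's bli_horizontal_or_df, whose definition is not part of this patch.

#include <immintrin.h>
#include <stdbool.h>

// Stand-in for bli_horizontal_or_df(): true if any lane of the mask is set.
static inline bool any_lane_set_pd( __m256d mask )
{
    return _mm256_movemask_pd( mask ) != 0;
}

// Illustrative only: write |x| lane-wise and return true when no lane is NaN;
// return false when at least one lane is NaN (the kernel then sets *norm = NAN).
static inline bool abs_and_nan_screen_sketch( __m256d x, __m256d *abs_x )
{
    // A lane compares "unordered" with itself only if it is NaN.
    __m256d nan_mask = _mm256_cmp_pd( x, x, _CMP_UNORD_Q );
    if ( any_lane_set_pd( nan_mask ) )
        return false;

    // andnot( -0.0, x ) keeps every bit of x except the sign bit.
    *abs_x = _mm256_andnot_pd( _mm256_set1_pd( -0.0 ), x );
    return true;
}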
+ v4df_t thres_sml_vec, thres_big_vec, scale_sml_vec, scale_big_vec; + v4df_t temp, zerov; + sum_med_vec0.v = _mm256_setzero_pd(); sum_big_vec0.v = _mm256_setzero_pd(); sum_sml_vec0.v = _mm256_setzero_pd(); @@ -2142,51 +1653,42 @@ void bli_dnorm2fv_unb_var1_avx2 sum_big_vec1.v = _mm256_setzero_pd(); sum_sml_vec1.v = _mm256_setzero_pd(); - for (; ( i + 8 ) <= n; i = i + 8) + // Pre-broadcasting the thresholds and scale factors before entering the loops + thres_sml_vec.v = _mm256_broadcast_sd( &thres_sml ); + thres_big_vec.v = _mm256_broadcast_sd( &thres_big ); + scale_sml_vec.v = _mm256_broadcast_sd( &scale_sml ); + scale_big_vec.v = _mm256_broadcast_sd( &scale_big ); + + // This is used to convert the values in a vector to their absolute value + temp.v = _mm256_set1_pd( -0.0 ); + + // Vectors used for loading from memory and setting masks + v4df_t x0v, x1v, mask_vec0, mask_vec1; + + for ( ; ( i + 8 ) <= n; i = i + 8 ) { x0v.v = _mm256_loadu_pd( xt ); x1v.v = _mm256_loadu_pd( xt + 4 ); - // Getting the abs of the vector elements. - x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); - x1v.v = _mm256_andnot_pd( temp.v, x1v.v ); - // Check if any of the values is a NaN and if so, return. - mask_vec0.v = _mm256_cmp_pd(x0v.v, x0v.v, _CMP_UNORD_Q); - mask_vec1.v = _mm256_cmp_pd(x1v.v, x1v.v, _CMP_UNORD_Q); - if ( bli_horizontal_or_df( mask_vec0.v ) ) - { - *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } + mask_vec0.v = _mm256_cmp_pd( x0v.v, x0v.v, _CMP_UNORD_Q ); + mask_vec1.v = _mm256_cmp_pd( x1v.v, x1v.v, _CMP_UNORD_Q ); - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - if ( bli_horizontal_or_df( mask_vec1.v ) ) + // Checking for the presence of atleast one NaN + if ( bli_horizontal_or_df( mask_vec0.v ) || bli_horizontal_or_df( mask_vec1.v ) ) { *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; } + // Getting the abs of the vector elements. + x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); + x1v.v = _mm256_andnot_pd( temp.v, x1v.v ); + // Mask vectors which indicate whether - // xi<=thres_sml or xi>=thres_big. + // xi <= thres_sml or xi >= thres_big. mask_vec0.v = CMP256_df( x0v.v, thres_sml_vec.v, thres_big_vec.v ); mask_vec1.v = CMP256_df( x1v.v, thres_sml_vec.v, thres_big_vec.v ); @@ -2199,39 +1701,38 @@ void bli_dnorm2fv_unb_var1_avx2 { // Mask vector which indicate whether xi > thres_big. mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_big_vec.v, _CMP_GT_OQ ); + zerov.v = _mm256_setzero_pd(); if ( bli_horizontal_or_df( mask_vec0.v ) ) { isbig = true; // Fill sum_med vector without scaling. - ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); - sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + zerov.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_med_vec0.v ); // Fill sum_big vector using scaling. 
- temp.v = _mm256_set1_pd( scale_big ); - ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); - ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); - sum_big_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_big_vec0.v ); - temp.v = _mm256_set1_pd( -0.0 ); + zerov.v = _mm256_setzero_pd(); + zerov.v = _mm256_blendv_pd( zerov.v, scale_big_vec.v, mask_vec0.v ); + zerov.v = _mm256_mul_pd( x0v.v, zerov.v ); + sum_big_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_big_vec0.v ); } else { // Mask vector which indicates whether xi > thres_small. mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_sml_vec.v, _CMP_LT_OQ ); // Fill sum_med vector without scaling. - ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); - sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + zerov.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_med_vec0.v ); // Accumulate small values only if there have not been any big values so far. if ( !isbig ) { // Fill sum_sml vector using scaling. - temp.v = _mm256_set1_pd( scale_sml ); - ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); - ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); - sum_sml_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_sml_vec0.v ); - temp.v = _mm256_set1_pd( -0.0 ); + zerov.v = _mm256_setzero_pd(); + zerov.v = _mm256_blendv_pd( zerov.v, scale_sml_vec.v, mask_vec0.v ); + zerov.v = _mm256_mul_pd( x0v.v, zerov.v ); + sum_sml_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_sml_vec0.v ); } } } @@ -2246,38 +1747,38 @@ void bli_dnorm2fv_unb_var1_avx2 // Mask vector which indicate whether xi > thres_big. mask_vec1.v = _mm256_cmp_pd( x1v.v, thres_big_vec.v, _CMP_GT_OQ ); + zerov.v = _mm256_setzero_pd(); + if ( bli_horizontal_or_df( mask_vec1.v ) ) { isbig = true; // Fill sum_med vector without scaling. - ymm1.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); - sum_med_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_med_vec1.v ); + zerov.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); + sum_med_vec1.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_med_vec1.v ); // Fill sum_big vector using scaling. - temp.v = _mm256_set1_pd( scale_big ); - ymm1.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec1.v ); - ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); - sum_big_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_big_vec1.v ); - temp.v = _mm256_set1_pd( -0.0 ); + zerov.v = _mm256_setzero_pd(); + zerov.v = _mm256_blendv_pd( zerov.v, scale_big_vec.v, mask_vec1.v ); + zerov.v = _mm256_mul_pd( x1v.v, zerov.v ); + sum_big_vec1.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_big_vec1.v ); } else { // Mask vector which indicates whether xi > thres_small. mask_vec1.v = _mm256_cmp_pd( x1v.v, thres_sml_vec.v, _CMP_LT_OQ ); // Fill sum_med vector without scaling. - ymm1.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); - sum_med_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_med_vec1.v ); + zerov.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); + sum_med_vec1.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_med_vec1.v ); // Accumulate small values only if there have not been any big values so far. if ( !isbig ) { // Fill sum_sml vector using scaling. 
- temp.v = _mm256_set1_pd( scale_sml ); - ymm1.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec1.v ); - ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); - sum_sml_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_sml_vec1.v ); - temp.v = _mm256_set1_pd( -0.0 ); + zerov.v = _mm256_setzero_pd(); + zerov.v = _mm256_blendv_pd( zerov.v, scale_sml_vec.v, mask_vec1.v ); + zerov.v = _mm256_mul_pd( x1v.v, zerov.v ); + sum_sml_vec1.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_sml_vec1.v ); } } } @@ -2293,18 +1794,11 @@ void bli_dnorm2fv_unb_var1_avx2 x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); // Check if any of the values is a NaN and if so, return. - mask_vec0.v = _mm256_cmp_pd(x0v.v, x0v.v, _CMP_UNORD_Q); + mask_vec0.v = _mm256_cmp_pd( x0v.v, x0v.v, _CMP_UNORD_Q ); + if ( bli_horizontal_or_df( mask_vec0.v ) ) { *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; @@ -2323,39 +1817,38 @@ void bli_dnorm2fv_unb_var1_avx2 { // Mask vector which indicate whether xi > thres_big. mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_big_vec.v, _CMP_GT_OQ ); + zerov.v = _mm256_setzero_pd(); if ( bli_horizontal_or_df( mask_vec0.v ) ) { isbig = true; // Fill sum_med vector without scaling. - ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); - sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + zerov.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_med_vec0.v ); // Fill sum_big vector using scaling. - temp.v = _mm256_set1_pd( scale_big ); - ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); - ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); - sum_big_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_big_vec0.v ); - temp.v = _mm256_set1_pd( -0.0 ); + zerov.v = _mm256_setzero_pd(); + zerov.v = _mm256_blendv_pd( zerov.v, scale_big_vec.v, mask_vec0.v ); + zerov.v = _mm256_mul_pd( x0v.v, zerov.v ); + sum_big_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_big_vec0.v ); } else { // Mask vector which indicates whether xi > thres_small. mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_sml_vec.v, _CMP_LT_OQ ); // Fill sum_med vector without scaling. - ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); - sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + zerov.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_med_vec0.v ); // Accumulate small values only if there have not been any big values so far. if ( !isbig ) { // Fill sum_sml vector using scaling. 
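The rewritten loop bodies drop the ymm0/ymm1 temporaries and the per-iteration _mm256_set1_pd calls: the scale factors are broadcast once before the loop and zerov is reused as scratch, while the routing itself is unchanged. A compare produces a lane mask, _mm256_blendv_pd sends each lane either to the unscaled medium accumulator or to the scaled big accumulator, and _mm256_fmadd_pd squares and accumulates. Below is a condensed sketch of that routing for one vector of |x|, written with bare __m256d values rather than the v4df_t union; the small-value branch is symmetric, using _CMP_LT_OQ against thres_sml and scale_sml.

#include <immintrin.h>

// Illustrative only: given abs_x (already non-negative and NaN-free), add its
// contribution to either the medium or the big partial sum, lane by lane.
static inline void route_med_big_sketch( __m256d abs_x, __m256d thres_big_v,
                                         __m256d scale_big_v,
                                         __m256d *sum_med_v, __m256d *sum_big_v )
{
    // All-ones lanes where |x| > thres_big.
    __m256d big_mask = _mm256_cmp_pd( abs_x, thres_big_v, _CMP_GT_OQ );

    // Medium contribution: zero out the "big" lanes, then fused square-add.
    __m256d med_vals = _mm256_blendv_pd( abs_x, _mm256_setzero_pd(), big_mask );
    *sum_med_v = _mm256_fmadd_pd( med_vals, med_vals, *sum_med_v );

    // Big contribution: keep the scale factor only in the "big" lanes, so the
    // multiply leaves |x|*scale_big there and zero everywhere else.
    __m256d big_scale = _mm256_blendv_pd( _mm256_setzero_pd(), scale_big_v, big_mask );
    __m256d big_vals  = _mm256_mul_pd( abs_x, big_scale );
    *sum_big_v = _mm256_fmadd_pd( big_vals, big_vals, *sum_big_v );
}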
- temp.v = _mm256_set1_pd( scale_sml ); - ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); - ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); - sum_sml_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_sml_vec0.v ); - temp.v = _mm256_set1_pd( -0.0 ); + zerov.v = _mm256_setzero_pd(); + zerov.v = _mm256_blendv_pd( zerov.v, scale_sml_vec.v, mask_vec0.v ); + zerov.v = _mm256_mul_pd( x0v.v, zerov.v ); + sum_sml_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_sml_vec0.v ); } } } @@ -2374,74 +1867,37 @@ void bli_dnorm2fv_unb_var1_avx2 + sum_big_vec0.v[2] + sum_big_vec0.v[3]; } - n_remainder = n - i; - bool hasInf = false; - if ( ( n_remainder > 0 ) ) + // Dealing with fringe cases + double abs_chi; + for( ; i < n; i += 1 ) { - // Put first the most likely to happen to avoid evaluations on if statements. - for (i = 0; i < n_remainder; i++) + abs_chi = bli_fabs( *xt ); + // Any thread encountering a NAN sets the sum_med accumalator to NAN + if ( bli_isnan( abs_chi ) ) { - abs_chi = bli_fabs( *xt ); - // If any of the elements is NaN, then return NaN as a result. - if ( bli_isnan( abs_chi ) ) - { - *norm = abs_chi; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } + *norm = NAN; - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - // Else, if any of the elements is an Inf, then return +Inf as a result. - if ( bli_isinf( abs_chi ) ) - { - *norm = abs_chi; - // Instead of returning immediately, use this flag - // to denote that there is an Inf element in the vector. - // That is used to avoid cases where there is a NaN which comes - // after an Inf. - hasInf = true; - } - // Most likely case: medium values, not over/under-flow. - if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) - { - sum_med += abs_chi * abs_chi; - } - // Case where there could be an overflow. Scaling is required. - else if ( abs_chi > thres_big ) - { - sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); - isbig = true; - } - // Case where there could be an underflow. Scaling is required. - else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) - { - sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); - } - xt++; + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + return; } - } - - // Early return if there is an Inf. - if ( hasInf ) - { - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) + // Most likely case: medium values, not over/under-flow. + else if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); + sum_med += abs_chi * abs_chi; + } + // Case where there could be an overflow. Scaling is required. + else if ( abs_chi > thres_big ) + { + sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); + isbig = true; + } + // Case where there could be an underflow. Scaling is required. + else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) + { + sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); } - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; + xt += incx; } // Combine accumulators. 
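The hunk context stops just before the "Combine accumulators" step, which this patch does not modify. For orientation only, the sketch below reconstructs the usual way the three partial sums fold into the final scale * sqrt(sumsq); it is a simplified version of the standard three-accumulator combine, not a copy of the BLIS code, and it ignores the mixed medium/small corner case that the real implementation handles more carefully.

#include <math.h>
#include <stdbool.h>

// Illustrative only: fold (sum_sml, sum_med, sum_big) into (scale, sumsq)
// so that the 2-norm is scale * sqrt(sumsq).
static double combine_accumulators_sketch( double sum_sml, double sum_med,
                                           double sum_big, bool isbig,
                                           double scale_sml, double scale_big )
{
    double scale, sumsq;

    if ( isbig )
    {
        // sum_big holds sum((x*scale_big)^2); bring sum_med onto that scale.
        scale = 1.0 / scale_big;
        sumsq = sum_big + ( sum_med * scale_big ) * scale_big;
    }
    else if ( sum_sml > 0.0 && sum_med == 0.0 )
    {
        // Only tiny values were seen; undo the small-value scaling.
        scale = 1.0 / scale_sml;
        sumsq = sum_sml;
    }
    else
    {
        // Dominant values were mid-range: no scaling needed.
        scale = 1.0;
        sumsq = sum_med;
    }

    return scale * sqrt( sumsq );
}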
@@ -2492,15 +1948,6 @@ void bli_dnorm2fv_unb_var1_avx2 *norm = scale * sqrt( sumsq ); - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dnorm2fv_unb_var1(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; @@ -2518,64 +1965,14 @@ void bli_dznorm2fv_unb_var1_avx2 AOCL_DTL_TRACE_ENTRY( AOCL_DTL_LEVEL_TRACE_3 ); double sumsq = 0; - dim_t i = 0; - dim_t n_remainder = 0; - dcomplex *x_buf = x; - - // Memory pool declarations for packing vector X. - // Initialize mem pool buffer to NULL and size to 0. - // "buf" and "size" fields are assigned once memory - // is allocated from the pool in bli_membrk_acquire_m(). - // This will ensure bli_mem_is_alloc() will be passed on - // an allocated memory if created or a NULL. - mem_t mem_bufX = {0}; - rntm_t rntm; - - // Packing for non-unit strided vector x. - if ( incx != 1 ) - { - // In order to get the buffer from pool via rntm access to memory broker - //is needed. Following are initializations for rntm. - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - // Calculate the size required for "n" dcomplex elements in vector x. - size_t buffer_size = n * sizeof( dcomplex ); - - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dznorm2fv_unb_var1(): get mem pool block\n" ); - #endif - - // Acquire a Buffer(n*size(dcomplex)) from the memory broker - // and save the associated mem_t entry to mem_bufX. - bli_membrk_acquire_m - ( - &rntm, - buffer_size, - BLIS_BUFFER_FOR_B_PANEL, - &mem_bufX - ); - - // Continue packing X if buffer memory is allocated. - if ( ( bli_mem_is_alloc( &mem_bufX ) ) ) - { - x_buf = bli_mem_buffer( &mem_bufX ); - // Pack vector x with non-unit stride to a temp buffer x_buf with unit stride. - for ( dim_t x_index = 0; x_index < n; x_index++ ) - { - *( x_buf + x_index ) = *( x + ( x_index * incx ) ); - } - } - } - dcomplex *xt = x_buf; + dcomplex *xt = x; // Compute the sum of squares on 3 accumulators to avoid overflow // and underflow, depending on the vector element value. // Accumulator for small values; using scaling to avoid underflow. double sum_sml = 0; - // Accumulator for medium values; no scaling required. + // Accumulator for medium values; no scaling required. double sum_med = 0; // Accumulator for big values; using scaling to avoid overflow. double sum_big = 0; @@ -2587,21 +1984,20 @@ void bli_dznorm2fv_unb_var1_avx2 const double scale_big = pow( ( double )FLT_RADIX, - ceil( ( DBL_MAX_EXP + 52 ) * 0.5 ) ); double scale; - double abs_chi; bool isbig = false; - if ( n > 2 ) - { - // Constants used for comparisons. - v4df_t temp, thres_sml_vec, thres_big_vec, zerov, ymm0, ymm1; - temp.v = _mm256_set1_pd( -0.0 ); - thres_sml_vec.v = _mm256_set1_pd( thres_sml ); - thres_big_vec.v = _mm256_set1_pd( thres_big ); - v4df_t x0v, x1v, mask_vec0, mask_vec1; - zerov.v = _mm256_setzero_pd(); + dim_t i = 0; + if ( incx == 1 ) + { // Partial sums used for scaling. - v4df_t sum_med_vec0, sum_big_vec0, sum_sml_vec0, sum_med_vec1, sum_big_vec1, sum_sml_vec1; + v4df_t sum_med_vec0, sum_big_vec0, sum_sml_vec0; + v4df_t sum_med_vec1, sum_big_vec1, sum_sml_vec1; + + // Vectors used for comparisons and getting absolute values. 
+ v4df_t thres_sml_vec, thres_big_vec, scale_sml_vec, scale_big_vec; + v4df_t temp, zerov; + sum_med_vec0.v = _mm256_setzero_pd(); sum_big_vec0.v = _mm256_setzero_pd(); sum_sml_vec0.v = _mm256_setzero_pd(); @@ -2609,51 +2005,42 @@ void bli_dznorm2fv_unb_var1_avx2 sum_big_vec1.v = _mm256_setzero_pd(); sum_sml_vec1.v = _mm256_setzero_pd(); - for (; ( i + 4 ) <= n; i = i + 4) - { - x0v.v = _mm256_loadu_pd( (double*) xt ); - x1v.v = _mm256_loadu_pd( (double*) (xt + 2) ); + // Pre-broadcasting the thresholds and scale factors before entering the loops + thres_sml_vec.v = _mm256_broadcast_sd( &thres_sml ); + thres_big_vec.v = _mm256_broadcast_sd( &thres_big ); + scale_sml_vec.v = _mm256_broadcast_sd( &scale_sml ); + scale_big_vec.v = _mm256_broadcast_sd( &scale_big ); - // Getting the abs of the vector elements. - x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); - x1v.v = _mm256_andnot_pd( temp.v, x1v.v ); + // This is used to convert the values in a vector to their absolute value + temp.v = _mm256_set1_pd( -0.0 ); + + // Vectors used for loading from memory and setting masks + v4df_t x0v, x1v, mask_vec0, mask_vec1; + + for ( ; ( i + 4 ) <= n; i += 4 ) + { + x0v.v = _mm256_loadu_pd( ( const double * )xt ); + x1v.v = _mm256_loadu_pd( ( const double * )( xt + 2 ) ); // Check if any of the values is a NaN and if so, return. - mask_vec0.v = _mm256_cmp_pd(x0v.v, x0v.v, _CMP_UNORD_Q); - mask_vec1.v = _mm256_cmp_pd(x1v.v, x1v.v, _CMP_UNORD_Q); - if ( bli_horizontal_or_df( mask_vec0.v ) ) - { - *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dznorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } + mask_vec0.v = _mm256_cmp_pd( x0v.v, x0v.v, _CMP_UNORD_Q ); + mask_vec1.v = _mm256_cmp_pd( x1v.v, x1v.v, _CMP_UNORD_Q ); - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - if ( bli_horizontal_or_df( mask_vec1.v ) ) + // Checking for the presence of atleast one NaN + if ( bli_horizontal_or_df( mask_vec0.v ) || bli_horizontal_or_df( mask_vec1.v ) ) { *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dznorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; } + // Getting the abs of the vector elements. + x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); + x1v.v = _mm256_andnot_pd( temp.v, x1v.v ); + // Mask vectors which indicate whether - // xi<=thres_sml or xi>=thres_big. + // xi <= thres_sml or xi >= thres_big. mask_vec0.v = CMP256_df( x0v.v, thres_sml_vec.v, thres_big_vec.v ); mask_vec1.v = CMP256_df( x1v.v, thres_sml_vec.v, thres_big_vec.v ); @@ -2666,39 +2053,38 @@ void bli_dznorm2fv_unb_var1_avx2 { // Mask vector which indicate whether xi > thres_big. mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_big_vec.v, _CMP_GT_OQ ); + zerov.v = _mm256_setzero_pd(); if ( bli_horizontal_or_df( mask_vec0.v ) ) { isbig = true; // Fill sum_med vector without scaling. - ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); - sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + zerov.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_med_vec0.v ); // Fill sum_big vector using scaling. 
- temp.v = _mm256_set1_pd( scale_big ); - ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); - ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); - sum_big_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_big_vec0.v ); - temp.v = _mm256_set1_pd( -0.0 ); + zerov.v = _mm256_setzero_pd(); + zerov.v = _mm256_blendv_pd( zerov.v, scale_big_vec.v, mask_vec0.v ); + zerov.v = _mm256_mul_pd( x0v.v, zerov.v ); + sum_big_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_big_vec0.v ); } else { // Mask vector which indicates whether xi > thres_small. mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_sml_vec.v, _CMP_LT_OQ ); // Fill sum_med vector without scaling. - ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); - sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + zerov.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_med_vec0.v ); // Accumulate small values only if there have not been any big values so far. if ( !isbig ) { // Fill sum_sml vector using scaling. - temp.v = _mm256_set1_pd( scale_sml ); - ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); - ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); - sum_sml_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_sml_vec0.v ); - temp.v = _mm256_set1_pd( -0.0 ); + zerov.v = _mm256_setzero_pd(); + zerov.v = _mm256_blendv_pd( zerov.v, scale_sml_vec.v, mask_vec0.v ); + zerov.v = _mm256_mul_pd( x0v.v, zerov.v ); + sum_sml_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_sml_vec0.v ); } } } @@ -2713,38 +2099,38 @@ void bli_dznorm2fv_unb_var1_avx2 // Mask vector which indicate whether xi > thres_big. mask_vec1.v = _mm256_cmp_pd( x1v.v, thres_big_vec.v, _CMP_GT_OQ ); + zerov.v = _mm256_setzero_pd(); + if ( bli_horizontal_or_df( mask_vec1.v ) ) { isbig = true; // Fill sum_med vector without scaling. - ymm1.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); - sum_med_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_med_vec1.v ); + zerov.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); + sum_med_vec1.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_med_vec1.v ); // Fill sum_big vector using scaling. - temp.v = _mm256_set1_pd( scale_big ); - ymm1.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec1.v ); - ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); - sum_big_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_big_vec1.v ); - temp.v = _mm256_set1_pd( -0.0 ); + zerov.v = _mm256_setzero_pd(); + zerov.v = _mm256_blendv_pd( zerov.v, scale_big_vec.v, mask_vec1.v ); + zerov.v = _mm256_mul_pd( x1v.v, zerov.v ); + sum_big_vec1.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_big_vec1.v ); } else { // Mask vector which indicates whether xi > thres_small. mask_vec1.v = _mm256_cmp_pd( x1v.v, thres_sml_vec.v, _CMP_LT_OQ ); // Fill sum_med vector without scaling. - ymm1.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); - sum_med_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_med_vec1.v ); + zerov.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); + sum_med_vec1.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_med_vec1.v ); // Accumulate small values only if there have not been any big values so far. if ( !isbig ) { // Fill sum_sml vector using scaling. 
- temp.v = _mm256_set1_pd( scale_sml ); - ymm1.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec1.v ); - ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); - sum_sml_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_sml_vec1.v ); - temp.v = _mm256_set1_pd( -0.0 ); + zerov.v = _mm256_setzero_pd(); + zerov.v = _mm256_blendv_pd( zerov.v, scale_sml_vec.v, mask_vec1.v ); + zerov.v = _mm256_mul_pd( x1v.v, zerov.v ); + sum_sml_vec1.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_sml_vec1.v ); } } } @@ -2752,26 +2138,19 @@ void bli_dznorm2fv_unb_var1_avx2 xt += 4; } - for ( ; ( i + 2 ) <= n; i = i + 2 ) + for ( ; ( i + 2 ) <= n; i += 2 ) { - x0v.v = _mm256_loadu_pd( (double*) xt ); + x0v.v = _mm256_loadu_pd( ( const double * )xt ); // Getting the abs of the vector elements. x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); // Check if any of the values is a NaN and if so, return. - mask_vec0.v = _mm256_cmp_pd(x0v.v, x0v.v, _CMP_UNORD_Q); + mask_vec0.v = _mm256_cmp_pd( x0v.v, x0v.v, _CMP_UNORD_Q ); + if ( bli_horizontal_or_df( mask_vec0.v ) ) { *norm = NAN; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dznorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; @@ -2790,39 +2169,38 @@ void bli_dznorm2fv_unb_var1_avx2 { // Mask vector which indicate whether xi > thres_big. mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_big_vec.v, _CMP_GT_OQ ); + zerov.v = _mm256_setzero_pd(); if ( bli_horizontal_or_df( mask_vec0.v ) ) { isbig = true; // Fill sum_med vector without scaling. - ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); - sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + zerov.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_med_vec0.v ); // Fill sum_big vector using scaling. - temp.v = _mm256_set1_pd( scale_big ); - ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); - ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); - sum_big_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_big_vec0.v ); - temp.v = _mm256_set1_pd( -0.0 ); + zerov.v = _mm256_setzero_pd(); + zerov.v = _mm256_blendv_pd( zerov.v, scale_big_vec.v, mask_vec0.v ); + zerov.v = _mm256_mul_pd( x0v.v, zerov.v ); + sum_big_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_big_vec0.v ); } else { // Mask vector which indicates whether xi > thres_small. mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_sml_vec.v, _CMP_LT_OQ ); // Fill sum_med vector without scaling. - ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); - sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + zerov.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_med_vec0.v ); // Accumulate small values only if there have not been any big values so far. if ( !isbig ) { // Fill sum_sml vector using scaling. 
- temp.v = _mm256_set1_pd( scale_sml ); - ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); - ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); - sum_sml_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_sml_vec0.v ); - temp.v = _mm256_set1_pd( -0.0 ); + zerov.v = _mm256_setzero_pd(); + zerov.v = _mm256_blendv_pd( zerov.v, scale_sml_vec.v, mask_vec0.v ); + zerov.v = _mm256_mul_pd( x0v.v, zerov.v ); + sum_sml_vec0.v = _mm256_fmadd_pd( zerov.v, zerov.v, sum_sml_vec0.v ); } } } @@ -2841,125 +2219,69 @@ void bli_dznorm2fv_unb_var1_avx2 + sum_big_vec0.v[2] + sum_big_vec0.v[3]; } - n_remainder = n - i; - bool hasInf = false; + // Scalar loop to handle the fringe cases double chi_r, chi_i; - if ( ( n_remainder > 0 ) ) + double abs_chi; + for ( ; i < n; i++) { - // Put first the most likely to happen to avoid evaluations on if statements. - for (i = 0; i < n_remainder; i++) - { - // Get real and imaginary component of the vector element. - bli_zdgets(*xt, chi_r, chi_i); + // Get real and imaginary component of the vector element. + bli_zdgets(*xt, chi_r, chi_i); - // Start with accumulating the real component of the vector element. - abs_chi = bli_fabs( chi_r ); - // If any of the elements is NaN, then return NaN as a result. - if ( bli_isnan( abs_chi ) ) - { - *norm = abs_chi; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dznorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - // Else, if any of the elements is an Inf, then return +Inf as a result. - if ( bli_isinf( abs_chi ) ) - { - *norm = abs_chi; - // Instead of returning immediately, use this flag - // to denote that there is an Inf element in the vector. - // That is used to avoid cases where there is a NaN which comes - // after an Inf. - hasInf = true; - } - // Most likely case: medium values, not over/under-flow. - if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) - { - sum_med += abs_chi * abs_chi; - } - // Case where there could be an overflow. Scaling is required. - else if ( abs_chi > thres_big ) - { - sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); - isbig = true; - } - // Case where there could be an underflow. Scaling is required. - else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) - { - sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); - } + // Start with accumulating the real component of the vector element. + abs_chi = bli_fabs( chi_r ); + // If any of the elements is NaN, then return NaN as a result. + if ( bli_isnan( abs_chi ) ) + { + *norm = abs_chi; - // Accumulate the imaginary component of the vector element. - abs_chi = bli_fabs( chi_i ); - // If any of the elements is NaN, then return NaN as a result. - if ( bli_isnan( abs_chi ) ) - { - *norm = abs_chi; - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dznorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + return; + } + // Most likely case: medium values, not over/under-flow. + if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) + { + sum_med += abs_chi * abs_chi; + } + // Case where there could be an overflow. Scaling is required. 
+ else if ( abs_chi > thres_big ) + { + sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); + isbig = true; + } + // Case where there could be an underflow. Scaling is required. + else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) + { + sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); + } - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; - } - // Else, if any of the elements is an Inf, then return +Inf as a result. - if ( bli_isinf( abs_chi ) ) - { - *norm = abs_chi; - // Instead of returning immediately, use this flag - // to denote that there is an Inf element in the vector. - // That is used to avoid cases where there is a NaN which comes - // after an Inf. - hasInf = true; - } - // Most likely case: medium values, not over/under-flow. - if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) - { - sum_med += abs_chi * abs_chi; - } - // Case where there could be an overflow. Scaling is required. - else if ( abs_chi > thres_big ) - { - sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); - isbig = true; - } - // Case where there could be an underflow. Scaling is required. - else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) - { - sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); - } + // Accumulate the imaginary component of the vector element. + abs_chi = bli_fabs( chi_i ); + // If any of the elements is NaN, then return NaN as a result. + if ( bli_isnan( abs_chi ) ) + { + *norm = abs_chi; - xt++; + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + return; } - } - - // Early return if there is an Inf. - if ( hasInf ) - { - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) + // Most likely case: medium values, not over/under-flow. + if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) + { + sum_med += abs_chi * abs_chi; + } + // Case where there could be an overflow. Scaling is required. + else if ( abs_chi > thres_big ) + { + sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); + isbig = true; + } + // Case where there could be an underflow. Scaling is required. + else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dnorm2fv_unb_var1_avx2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); + sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); } - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); - return; + xt += incx; } // Combine accumulators. @@ -3001,6 +2323,7 @@ void bli_dznorm2fv_unb_var1_avx2 sumsq = sum_sml; } } + else { // If all values are mid-range: @@ -3010,15 +2333,6 @@ void bli_dznorm2fv_unb_var1_avx2 *norm = scale * sqrt( sumsq ); - if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dznorm2fv_unb_var1(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool. - bli_membrk_release( &rntm , &mem_bufX ); - } - AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; diff --git a/kernels/zen/1/bli_scalv_zen_int.c b/kernels/zen/1/bli_scalv_zen_int.c index 9f76e88e18..fa337c247f 100644 --- a/kernels/zen/1/bli_scalv_zen_int.c +++ b/kernels/zen/1/bli_scalv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2017 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without diff --git a/kernels/zen/1/bli_scalv_zen_int10.c b/kernels/zen/1/bli_scalv_zen_int10.c index 2d96d756c1..e760367060 100644 --- a/kernels/zen/1/bli_scalv_zen_int10.c +++ b/kernels/zen/1/bli_scalv_zen_int10.c @@ -609,8 +609,8 @@ void bli_zdscalv_zen_int10 for ( ; ( i + 29 ) < n; i += 30 ) { - xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg ); + xv[0] = _mm256_loadu_pd( x0 ); + xv[1] = _mm256_loadu_pd( x0 + n_elem_per_reg ); xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg ); xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg ); xv[4] = _mm256_loadu_pd( x0 + 4 * n_elem_per_reg ); @@ -641,8 +641,8 @@ void bli_zdscalv_zen_int10 xv[13] = _mm256_mul_pd( alphav, xv[13] ); xv[14] = _mm256_mul_pd( alphav, xv[14] ); - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] ); + _mm256_storeu_pd( x0, xv[0] ); + _mm256_storeu_pd( (x0 + n_elem_per_reg), xv[1] ); _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] ); _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] ); _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), xv[4] ); @@ -662,8 +662,8 @@ void bli_zdscalv_zen_int10 for ( ; ( i + 23 ) < n; i += 24 ) { - xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg ); + xv[0] = _mm256_loadu_pd( x0 ); + xv[1] = _mm256_loadu_pd( x0 + n_elem_per_reg ); xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg ); xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg ); xv[4] = _mm256_loadu_pd( x0 + 4 * n_elem_per_reg ); @@ -688,8 +688,8 @@ void bli_zdscalv_zen_int10 xv[10] = _mm256_mul_pd( alphav, xv[10] ); xv[11] = _mm256_mul_pd( alphav, xv[11] ); - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] ); + _mm256_storeu_pd( x0, xv[0] ); + _mm256_storeu_pd( (x0 + n_elem_per_reg), xv[1] ); _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] ); _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] ); _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), xv[4] ); @@ -706,8 +706,8 @@ void bli_zdscalv_zen_int10 for ( ; ( i + 15 ) < n; i += 16 ) { - xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg ); + xv[0] = _mm256_loadu_pd( x0 ); + xv[1] = _mm256_loadu_pd( x0 + n_elem_per_reg ); xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg ); xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg ); xv[4] = _mm256_loadu_pd( x0 + 4 * n_elem_per_reg ); @@ -724,8 +724,8 @@ void bli_zdscalv_zen_int10 xv[6] = _mm256_mul_pd( alphav, xv[6] ); xv[7] = _mm256_mul_pd( alphav, xv[7] ); - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] ); + _mm256_storeu_pd( x0, xv[0] ); + _mm256_storeu_pd( (x0 + n_elem_per_reg), xv[1] ); _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] ); _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] ); _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), xv[4] ); @@ -738,8 +738,8 @@ void bli_zdscalv_zen_int10 for ( ; ( i + 7 ) < n; i += 8 ) { - xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg ); + xv[0] = _mm256_loadu_pd( x0 ); + xv[1] = _mm256_loadu_pd( x0 + n_elem_per_reg ); xv[2] = _mm256_loadu_pd( x0 + 2 * n_elem_per_reg ); xv[3] = _mm256_loadu_pd( x0 + 3 * n_elem_per_reg ); @@ -748,8 +748,8 @@ void bli_zdscalv_zen_int10 xv[2] = _mm256_mul_pd( alphav, xv[2] ); xv[3] = _mm256_mul_pd( alphav, xv[3] 
); - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] ); + _mm256_storeu_pd( x0, xv[0] ); + _mm256_storeu_pd( (x0 + n_elem_per_reg), xv[1] ); _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] ); _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] ); @@ -758,35 +758,27 @@ void bli_zdscalv_zen_int10 for ( ; ( i + 3 ) < n; i += 4 ) { - xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1 * n_elem_per_reg ); + xv[0] = _mm256_loadu_pd( x0 ); + xv[1] = _mm256_loadu_pd( x0 + n_elem_per_reg ); xv[0] = _mm256_mul_pd( alphav, xv[0] ); xv[1] = _mm256_mul_pd( alphav, xv[1] ); - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] ); + _mm256_storeu_pd( x0, xv[0] ); + _mm256_storeu_pd( (x0 + n_elem_per_reg), xv[1] ); x0 += 2 * n_elem_per_reg; } for ( ; ( i + 1 ) < n; i += 2 ) { - xv[0] = _mm256_loadu_pd( x0 + 0 * n_elem_per_reg ); + xv[0] = _mm256_loadu_pd( x0 ); xv[0] = _mm256_mul_pd( alphav, xv[0] ); - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); + _mm256_storeu_pd( x0, xv[0] ); - x0 += 1 * n_elem_per_reg; - } - - for ( ; i < n; i++ ) - { - ( *x0 ) *= alphac; - ( *( x0 + 1 ) ) *= alphac; - - x0 += 2 * incx; + x0 += n_elem_per_reg; } // Issue vzeroupper instruction to clear upper lanes of ymm registers. @@ -796,15 +788,22 @@ void bli_zdscalv_zen_int10 // -mfpmath=sse). _mm256_zeroupper(); } - else + + /* In double complex data type the computation of + unit stride elements can still be vectorized using SSE*/ + __m128d alpha_reg, x_vec; + + alpha_reg = _mm_set1_pd((*alpha).real); + + for (; i < n; ++i) { - for ( ; i < n; ++i ) - { - ( *x0 ) *= alphac; - ( *( x0 + 1 ) ) *= alphac; + x_vec = _mm_loadu_pd(x0); - x0 += 2 * incx; - } + x_vec = _mm_mul_pd(x_vec, alpha_reg); + + _mm_storeu_pd(x0, x_vec); + + x0 += 2 * incx; } } @@ -817,10 +816,20 @@ void bli_zscalv_zen_int cntx_t* restrict cntx ) { - // If the vector dimension is zero, or if alpha is unit, return early. - if (bli_zero_dim1(n) || PASTEMAC(z, eq1)(*alpha)) - return; + /* + Undefined behaviour + ------------------- + + 1. This layer is not BLAS complaint and the kernel results in + undefined behaviour when n <= 0 and incx <= 1. The expectation + is that the application/higher-layer invoking this layer should + the arg checks. 
+ */ + // if (bli_zero_dim1(n) || PASTEMAC(z, eq1)(*alpha)) + // return; + // To Do: This call to SETV needs to be removed for BLAS compliance + // Currently removing this is resulting in ZHERK failures if (PASTEMAC(z, eq0)(*alpha)) { // Expert interface of setv is invoked when alpha is zero @@ -834,8 +843,7 @@ void bli_zscalv_zen_int zero, x, incx, cntx, - NULL - ); + NULL); return; } @@ -896,33 +904,29 @@ void bli_zscalv_zen_int x_vec_ymm[2] = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); x_vec_ymm[3] = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); - temp_ymm[0] = _mm256_mul_pd(x_vec_ymm[0], alpha_real_ymm); - temp_ymm[1] = _mm256_mul_pd(x_vec_ymm[0], alpha_imag_ymm); - temp_ymm[2] = _mm256_mul_pd(x_vec_ymm[1], alpha_real_ymm); - temp_ymm[3] = _mm256_mul_pd(x_vec_ymm[1], alpha_imag_ymm); - temp_ymm[4] = _mm256_mul_pd(x_vec_ymm[2], alpha_real_ymm); - temp_ymm[5] = _mm256_mul_pd(x_vec_ymm[2], alpha_imag_ymm); - temp_ymm[6] = _mm256_mul_pd(x_vec_ymm[3], alpha_real_ymm); - temp_ymm[7] = _mm256_mul_pd(x_vec_ymm[3], alpha_imag_ymm); + temp_ymm[0] = _mm256_mul_pd(x_vec_ymm[0], alpha_imag_ymm); + temp_ymm[1] = _mm256_mul_pd(x_vec_ymm[1], alpha_imag_ymm); + temp_ymm[2] = _mm256_mul_pd(x_vec_ymm[2], alpha_imag_ymm); + temp_ymm[3] = _mm256_mul_pd(x_vec_ymm[3], alpha_imag_ymm); - temp_ymm[1] = _mm256_permute_pd(temp_ymm[1], 0b0101); - temp_ymm[3] = _mm256_permute_pd(temp_ymm[3], 0b0101); - temp_ymm[5] = _mm256_permute_pd(temp_ymm[5], 0b0101); - temp_ymm[7] = _mm256_permute_pd(temp_ymm[7], 0b0101); + temp_ymm[4] = _mm256_permute_pd(temp_ymm[0], 0b0101); + temp_ymm[5] = _mm256_permute_pd(temp_ymm[1], 0b0101); + temp_ymm[6] = _mm256_permute_pd(temp_ymm[2], 0b0101); + temp_ymm[7] = _mm256_permute_pd(temp_ymm[3], 0b0101); /* - a[i+63:i] := b[i+63:i] - c[i+63:i] for odd indices - a[i+63:i] := b[i+63:i] + c[i+63:i] for even indices + a[i+63:i] := alpha_real * b[i+63:i] - c[i+63:i] for odd indices + a[i+63:i] := alpha_real * b[i+63:i] + c[i+63:i] for even indices */ - temp_ymm[0] = _mm256_addsub_pd(temp_ymm[0], temp_ymm[1]); - temp_ymm[2] = _mm256_addsub_pd(temp_ymm[2], temp_ymm[3]); - temp_ymm[4] = _mm256_addsub_pd(temp_ymm[4], temp_ymm[5]); - temp_ymm[6] = _mm256_addsub_pd(temp_ymm[6], temp_ymm[7]); + temp_ymm[0] = _mm256_fmaddsub_pd(x_vec_ymm[0], alpha_real_ymm, temp_ymm[4]); + temp_ymm[1] = _mm256_fmaddsub_pd(x_vec_ymm[1], alpha_real_ymm, temp_ymm[5]); + temp_ymm[2] = _mm256_fmaddsub_pd(x_vec_ymm[2], alpha_real_ymm, temp_ymm[6]); + temp_ymm[3] = _mm256_fmaddsub_pd(x_vec_ymm[3], alpha_real_ymm, temp_ymm[7]); _mm256_storeu_pd(x0, temp_ymm[0]); - _mm256_storeu_pd(x0 + n_elem_per_reg, temp_ymm[2]); - _mm256_storeu_pd(x0 + 2 * n_elem_per_reg, temp_ymm[4]); - _mm256_storeu_pd(x0 + 3 * n_elem_per_reg, temp_ymm[6]); + _mm256_storeu_pd(x0 + n_elem_per_reg, temp_ymm[1]); + _mm256_storeu_pd(x0 + 2 * n_elem_per_reg, temp_ymm[2]); + _mm256_storeu_pd(x0 + 3 * n_elem_per_reg, temp_ymm[3]); x0 += 4 * n_elem_per_reg; } @@ -932,19 +936,17 @@ void bli_zscalv_zen_int x_vec_ymm[0] = _mm256_loadu_pd(x0); x_vec_ymm[1] = _mm256_loadu_pd(x0 + n_elem_per_reg); - temp_ymm[0] = _mm256_mul_pd(x_vec_ymm[0], alpha_real_ymm); - temp_ymm[1] = _mm256_mul_pd(x_vec_ymm[0], alpha_imag_ymm); - temp_ymm[2] = _mm256_mul_pd(x_vec_ymm[1], alpha_real_ymm); - temp_ymm[3] = _mm256_mul_pd(x_vec_ymm[1], alpha_imag_ymm); + temp_ymm[0] = _mm256_mul_pd(x_vec_ymm[0], alpha_imag_ymm); + temp_ymm[1] = _mm256_mul_pd(x_vec_ymm[1], alpha_imag_ymm); - temp_ymm[1] = _mm256_permute_pd(temp_ymm[1], 0b0101); - temp_ymm[3] = _mm256_permute_pd(temp_ymm[3], 0b0101); + 
temp_ymm[2] = _mm256_permute_pd(temp_ymm[0], 0b0101); + temp_ymm[3] = _mm256_permute_pd(temp_ymm[1], 0b0101); - temp_ymm[0] = _mm256_addsub_pd(temp_ymm[0], temp_ymm[1]); - temp_ymm[2] = _mm256_addsub_pd(temp_ymm[2], temp_ymm[3]); + temp_ymm[0] = _mm256_fmaddsub_pd(x_vec_ymm[0], alpha_real_ymm, temp_ymm[2]); + temp_ymm[1] = _mm256_fmaddsub_pd(x_vec_ymm[1], alpha_real_ymm, temp_ymm[3]); _mm256_storeu_pd(x0, temp_ymm[0]); - _mm256_storeu_pd(x0 + n_elem_per_reg, temp_ymm[2]); + _mm256_storeu_pd(x0 + n_elem_per_reg, temp_ymm[1]); x0 += 2 * n_elem_per_reg; } @@ -953,44 +955,42 @@ void bli_zscalv_zen_int { x_vec_ymm[0] = _mm256_loadu_pd(x0); - temp_ymm[0] = _mm256_mul_pd(x_vec_ymm[0], alpha_real_ymm); - temp_ymm[1] = _mm256_mul_pd(x_vec_ymm[0], alpha_imag_ymm); + temp_ymm[0] = _mm256_mul_pd(x_vec_ymm[0], alpha_imag_ymm); - temp_ymm[1] = _mm256_permute_pd(temp_ymm[1], 0b0101); + temp_ymm[1] = _mm256_permute_pd(temp_ymm[0], 0b0101); - temp_ymm[0] = _mm256_addsub_pd(temp_ymm[0], temp_ymm[1]); + temp_ymm[0] = _mm256_fmaddsub_pd(x_vec_ymm[0], alpha_real_ymm, temp_ymm[1]); _mm256_storeu_pd(x0, temp_ymm[0]); x0 += n_elem_per_reg; } - } - // Issue vzeroupper instruction to clear upper lanes of ymm registers. - // This avoids a performance penalty caused by false dependencies when - // transitioning from AVX to SSE instructions (which may occur later, - // especially if BLIS is compiled with -mfpmath=sse). - _mm256_zeroupper(); + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from AVX to SSE instructions (which may occur later, + // especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); + } /* In double complex data type the computation of unit stride elements can still be vectorized using SSE*/ - __m128d temp_ymm[2], alpha_real_ymm, alpha_imag_ymm, x_vec_ymm; + __m128d temp_xmm[2], alpha_real_xmm, alpha_imag_xmm, x_vec_xmm; - alpha_real_ymm = _mm_set1_pd(real); - alpha_imag_ymm = _mm_set1_pd(imag); + alpha_real_xmm = _mm_set1_pd(real); + alpha_imag_xmm = _mm_set1_pd(imag); for (; i < n; i++) { - x_vec_ymm = _mm_loadu_pd(x0); + x_vec_xmm = _mm_loadu_pd(x0); - temp_ymm[0] = _mm_mul_pd(x_vec_ymm, alpha_real_ymm); - temp_ymm[1] = _mm_mul_pd(x_vec_ymm, alpha_imag_ymm); + temp_xmm[0] = _mm_permute_pd(x_vec_xmm, 0b01); - temp_ymm[1] = _mm_permute_pd(temp_ymm[1], 0b01); + temp_xmm[1] = _mm_mul_pd(temp_xmm[0], alpha_imag_xmm); - temp_ymm[0] = _mm_addsub_pd(temp_ymm[0], temp_ymm[1]); + temp_xmm[0] = _mm_fmaddsub_pd(x_vec_xmm, alpha_real_xmm, temp_xmm[1]); - _mm_storeu_pd(x0, temp_ymm[0]); + _mm_storeu_pd(x0, temp_xmm[0]); x0 += 2 * incx; } diff --git a/kernels/zen/1/bli_setv_zen_int.c b/kernels/zen/1/bli_setv_zen_int.c index 16e02c94da..5ebd061cdd 100644 --- a/kernels/zen/1/bli_setv_zen_int.c +++ b/kernels/zen/1/bli_setv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/1/bli_swapv_zen_int8.c b/kernels/zen/1/bli_swapv_zen_int8.c index 205638e8c9..ba7c92593c 100644 --- a/kernels/zen/1/bli_swapv_zen_int8.c +++ b/kernels/zen/1/bli_swapv_zen_int8.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. 
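The zscalv rewrite replaces the two multiplies plus addsub with one multiply by the imaginary part, a permute, and a fused _mm256_fmaddsub_pd (or _mm_fmaddsub_pd in the scalar tail), which produces the full complex product in fewer instructions. The sketch below isolates the 128-bit tail form of that idiom; dcomplex_ish and zscal_tail_sketch are illustrative stand-ins (BLIS's dcomplex likewise exposes .real/.imag), sizes use size_t instead of dim_t, and FMA support is assumed, as it already is by these kernels.

#include <immintrin.h>
#include <stddef.h>

typedef struct { double real, imag; } dcomplex_ish;   // stand-in for BLIS dcomplex

// Illustrative only: x[i] *= alpha for dcomplex x and dcomplex alpha, one
// complex element (two doubles) per __m128d. 'incx' is in complex elements.
static void zscal_tail_sketch( size_t n, dcomplex_ish alpha,
                               dcomplex_ish *x, size_t incx )
{
    const __m128d alpha_r = _mm_set1_pd( alpha.real );
    const __m128d alpha_i = _mm_set1_pd( alpha.imag );
    double *x0 = ( double * )x;

    for ( size_t i = 0; i < n; ++i )
    {
        __m128d xv  = _mm_loadu_pd( x0 );            // [ xr, xi ]
        __m128d swp = _mm_permute_pd( xv, 0b01 );    // [ xi, xr ]
        __m128d ti  = _mm_mul_pd( swp, alpha_i );    // [ xi*ai, xr*ai ]

        // fmaddsub: lane 0 = xr*ar - xi*ai, lane 1 = xi*ar + xr*ai
        __m128d res = _mm_fmaddsub_pd( xv, alpha_r, ti );

        _mm_storeu_pd( x0, res );
        x0 += 2 * incx;                              // next complex element
    }
}

When alpha is purely real, as in the zdscalv fringe loop above, the same tail collapses to a single _mm_mul_pd against _mm_set1_pd(alpha.real).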
- Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/1f/CMakeLists.txt b/kernels/zen/1f/CMakeLists.txt deleted file mode 100644 index 5da0c9e7b0..0000000000 --- a/kernels/zen/1f/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -##Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved.## - -add_library(zen_1f - OBJECT - ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_8.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotxf_zen_int_8.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_5.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_4.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_6.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpy2v_zen_int.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotxaxpyf_zen_int_8.c - ) -target_compile_options(zen_1f PRIVATE /arch:AVX2) -if(BUILD_SHARED_LIBS) - target_compile_definitions(zen_1f PUBLIC -DBLIS_IS_BUILDING_LIBRARY) -endif() diff --git a/kernels/zen/1f/bli_axpy2v_zen_int.c b/kernels/zen/1f/bli_axpy2v_zen_int.c index ba92066a43..9d0d42dd3d 100644 --- a/kernels/zen/1f/bli_axpy2v_zen_int.c +++ b/kernels/zen/1f/bli_axpy2v_zen_int.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/1f/bli_axpyf_zen_int_4.c b/kernels/zen/1f/bli_axpyf_zen_int_4.c index 36d94712aa..43236887d9 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_4.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_4.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/1f/bli_axpyf_zen_int_5.c b/kernels/zen/1f/bli_axpyf_zen_int_5.c index 8fea5f6498..6b23fe6c45 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_5.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_5.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -667,7 +667,7 @@ void bli_daxpyf_zen_int_5 // ----------------------------------------------------------------------------- -static void bli_daxpyf_zen_int_16x2 +void bli_daxpyf_zen_int_16x2 ( conj_t conja, conj_t conjx, @@ -1003,21 +1003,6 @@ void bli_daxpyf_zen_int_16x4 // operation as a loop over axpyv. 
if ( b_n != fuse_fac ) { - if (b_n & 2) - { - bli_daxpyf_zen_int_16x2( conja, - conjx, - m, 2, - alpha, a, inca, lda, - x, incx, - y, incy, - cntx - ); - b_n -= 2; - a += 2*lda; - x += 2 * incx; - } - daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) diff --git a/kernels/zen/1f/bli_axpyf_zen_int_6.c b/kernels/zen/1f/bli_axpyf_zen_int_6.c index 6da5d99e6d..27cf3b7d89 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_6.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_6.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -185,7 +185,7 @@ void bli_saxpyf_zen_int_6 // If there are leftover iterations, perform them with scalar code. for ( ; (i + 0) < m ; ++i ) { - float y0c = *y0; + double y0c = *y0; const float a0c = *a0; const float a1c = *(a0+ 1*lda); @@ -211,7 +211,7 @@ void bli_saxpyf_zen_int_6 { for ( i = 0; (i + 0) < m ; ++i ) { - float y0c = *y0; + double y0c = *y0; const float a0c = *a0; const float a1c = *(a0+ 1*lda); const float a2c = *(a0+ 2*lda); diff --git a/kernels/zen/1f/bli_axpyf_zen_int_8.c b/kernels/zen/1f/bli_axpyf_zen_int_8.c index 27dafb28fc..3da593cf74 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_8.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin - Copyright (C) 2016 - 2022, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -296,17 +296,91 @@ void bli_daxpyf_zen_int_8 // If either dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim2( m, b_n ) || PASTEMAC(d,eq0)( *alpha ) ) return; - // If b_n is not equal to the fusing factor, then perform the entire - // operation as a loop over axpyv. + /* + If b_n is not equal to the fusing factor, then perform the entire + operation as axpyv or perform the operation using axpyf kernels with + lower fuse factor. + */ if ( b_n != fuse_fac ) { - daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); + if (b_n >= 5) + { + dim_t fuse_fac = 5; - for ( i = 0; i < b_n; ++i ) + bli_daxpyf_zen_int_5 + ( + conja, + conjx, + m, + fuse_fac, + alpha, + a, inca, lda, + x, incx, + y, incy, + cntx + ); + + a = a + (fuse_fac * lda); + x = x + (fuse_fac * incx); + + b_n -= fuse_fac; + } + + if (b_n == 4) { - double* a1 = a + (0 )*inca + (i )*lda; - double* chi1 = x + (i )*incx; - double* y1 = y + (0 )*incy; + dim_t fuse_fac = 4; + + bli_daxpyf_zen_int_16x4 + ( + conja, + conjx, + m, + fuse_fac, + alpha, + a, inca, lda, + x, incx, + y, incy, + cntx + ); + + a = a + (fuse_fac * lda); + x = x + (fuse_fac * incx); + + b_n -= fuse_fac; + } + + if (b_n >= 2) + { + dim_t fuse_fac = 2; + + bli_daxpyf_zen_int_16x2 + ( + conja, + conjx, + m, fuse_fac, + alpha, a, inca, lda, + x, incx, + y, incy, + cntx + ); + + a = a + (fuse_fac * lda); + x = x + (fuse_fac * incx); + + b_n -= fuse_fac; + + } + + if (b_n == 1) + { + // Query the context if it is NULL. 
This will be necessary for Zen architectures + if (cntx == NULL) cntx = bli_gks_query_cntx(); + + daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt(BLIS_DOUBLE, BLIS_AXPYV_KER, cntx); + + double* a1 = a; + double* chi1 = x; + double* y1 = y; double alpha_chi1; PASTEMAC(d,copycjs)( conjx, *chi1, alpha_chi1 ); diff --git a/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c b/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c index ba92d493ea..fbd354593c 100644 --- a/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/1f/bli_dotxf_zen_int_8.c b/kernels/zen/1f/bli_dotxf_zen_int_8.c index 815e388f21..bb39992de8 100644 --- a/kernels/zen/1f/bli_dotxf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxf_zen_int_8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin - Copyright (C) 2017 - 22, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -458,15 +458,69 @@ void bli_ddotxf_zen_int_8 return; } - // If b_n is not equal to the fusing factor, then perform the entire - // operation as a loop over dotxv. + /* + If b_n is not equal to the fusing factor, then perform the entire + operation as dotxv or perform the operation using dotxf kernels with + lower fuse factor. + */ if (b_n != fuse_fac) { - for (dim_t i = 0; i < b_n; ++i) + if (b_n >= 4) { - double *a1 = a + (0) * inca + (i)*lda; - double *x1 = x + (0) * incx; - double *psi1 = y + (i)*incy; + dim_t fuse = 4; + + bli_ddotxf_zen_int_4 + ( + conjat, + conjx, + m, + fuse, + alpha, + a, inca, lda, + x, incx, + beta, + y, incy, + cntx + ); + + // Increment the pointers + a = a + (fuse)*lda; + y = y + (fuse)*incy; + + // Decrement to point to the remaining compute left + b_n -= 4; + } + + if (b_n >= 2) + { + dim_t fuse = 2; + + bli_ddotxf_zen_int_2 + ( + conjat, + conjx, + m, + fuse, + alpha, + a, inca, lda, + x, incx, + beta, + y, incy, + cntx + ); + + // Increment the pointers + a = a + (fuse)*lda; + y = y + (fuse)*incy; + + b_n -= 2; + } + + if (b_n == 1) + { + double *a1 = a; + double *x1 = x; + double *psi1 = y; bli_ddotxv_zen_int( conjat, @@ -1562,419 +1616,458 @@ void bli_zdotxf_zen_int_6 cntx_t* restrict cntx ) { - /** - * Handles only unit stride cases and 6 column at a time - * b_n check for columns to be 6. - */ - if ( (inca == 1) && (incx == 1) && (incy == 1) && (b_n == 6) ) + /* If the vectors are empty or if alpha is zero, return early */ + if ( bli_zero_dim1( m ) || PASTEMAC(z,eq0)( *alpha ) ) { - /* Temporary rho buffer holds computed dot product result */ - dcomplex r[ 6 ]; + bli_zscalv_zen_int + ( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx + ); - /* If beta is zero, clear y. Otherwise, scale by beta. */ - if ( PASTEMAC(z,eq0)( *beta ) ) - { - for ( dim_t i = 0; i < 6; ++i ) - { - PASTEMAC(z,set0s)( y[i] ); - } - } - else + return; + } + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over dotxv. 
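For reference, the daxpyf and ddotxf changes above replace the plain per-column fallback with a peel-off by smaller fusing factors (for example 4, then 2, then a single column in the ddotxf case). A tiny sketch of just that control flow; printf stands in for the real fused kernels and the helper name is assumed, not part of the patch:

#include <stdio.h>

// Which fused chunk sizes a b_n smaller than the fusing factor decomposes into:
//   b_n = 7 -> 4 + 2 + 1,  b_n = 6 -> 4 + 2,  b_n = 5 -> 4 + 1,  b_n = 3 -> 2 + 1, ...
static void show_fuse_decomposition( int b_n )
{
    int col = 0;
    if ( b_n - col >= 4 ) { printf( "fuse-4 kernel on columns %d..%d\n", col, col + 3 ); col += 4; }
    if ( b_n - col >= 2 ) { printf( "fuse-2 kernel on columns %d..%d\n", col, col + 1 ); col += 2; }
    if ( b_n - col == 1 ) { printf( "single-column kernel (dotxv/axpyv) on column %d\n", col ); }
}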
+ if ( b_n != 6 ) + { + for ( dim_t i = 0; i < b_n; ++i ) { - for ( dim_t i = 0; i < 6; ++i ) - { - PASTEMAC(z,scals)( *beta, y[i] ); - } + dcomplex* restrict a1 = a + (0 )*inca + (i )*lda; + dcomplex* restrict x1 = x + (0 )*incx; + dcomplex* restrict psi1 = y + (i )*incy; + + bli_zdotxv_zen_int + ( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx + ); } - /* If the vectors are empty or if alpha is zero, return early*/ - if ( bli_zero_dim1( m ) || PASTEMAC(z,eq0)( *alpha ) ) return; + return; + } - /* Initialize r vector to 0. */ - for ( dim_t i = 0; i < 6; ++i ) PASTEMAC(z,set0s)( r[i] ); + dim_t rem = m; - /* If a must be conjugated, we do so indirectly by first - * toggling the effective conjugation of x and then conjugating - * the resulting do products. - * Rather conjugating each element of a matrix, final computed result - * can be conjugated at the end of loop. This takes off the overhead - * of conjugating each element inside the loop and improves the - * performance. - */ - conj_t conjx_use = conjx; + double *restrict av[6]; + double *restrict x_temp = (double *)(x); - if ( bli_is_conj( conjat ) ) - { - bli_toggle_conj( &conjx_use ); - } + av[0] = (double *)(a + 0 * lda); + av[1] = (double *)(a + 1 * lda); + av[2] = (double *)(a + 2 * lda); + av[3] = (double *)(a + 3 * lda); + av[4] = (double *)(a + 4 * lda); + av[5] = (double *)(a + 5 * lda); - /* Setting rho vectors to 0 */ - v4df_t rho0v; rho0v.v = _mm256_setzero_pd(); - v4df_t rho1v; rho1v.v = _mm256_setzero_pd(); - v4df_t rho2v; rho2v.v = _mm256_setzero_pd(); - v4df_t rho3v; rho3v.v = _mm256_setzero_pd(); - v4df_t rho4v; rho4v.v = _mm256_setzero_pd(); - v4df_t rho5v; rho5v.v = _mm256_setzero_pd(); + dcomplex res[6]; - v4df_t rho6v; rho6v.v = _mm256_setzero_pd(); - v4df_t rho7v; rho7v.v = _mm256_setzero_pd(); - v4df_t rho8v; rho8v.v = _mm256_setzero_pd(); - v4df_t rho9v; rho9v.v = _mm256_setzero_pd(); - v4df_t rho10v; rho10v.v = _mm256_setzero_pd(); - v4df_t rho11v; rho11v.v = _mm256_setzero_pd(); + res[0] = res[1] = res[2] = res[3] = res[4] = res[5] = (*bli_z0); - /* Holds 2 dcomplex element of x vector - * for computing dot product with A tile - */ - v4df_t x0v, x1v; - /* Holds 2x6 tile of matrix A */ - v4df_t a0v, a1v, a2v, a3v, a4v, a5v; - /** - * Since complex datatype multiplication is - * being held in two sets of rho vectors. - * Where first set holds the computaion with - * real part of vector x and other holds - * imaginary part of vector x. - * For final computation, based on conj sign - * of imaginary component needs to be toggled. 
- */ - __m256d no_conju = _mm256_setr_pd(-1, 1, -1, 1); - __m256d conju = _mm256_setr_pd(1, -1, 1, -1); - dim_t iter = m / 2; - dim_t rem = m % 2; - dim_t i = 0; + conj_t conjx_use = conjx; - if ( bli_is_noconj( conjx_use ) ) + if (bli_is_conj(conjat)) + { + bli_toggle_conj(&conjx_use); + } + + if (incx == 1 && inca == 1) + { + rem = m % 2; + v4df_t rhov[12], a_vec[6], xv[2], conj_mul; + + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); + rhov[8].v = _mm256_setzero_pd(); + rhov[9].v = _mm256_setzero_pd(); + rhov[10].v = _mm256_setzero_pd(); + rhov[11].v = _mm256_setzero_pd(); + + for (dim_t i = 0; (i + 1) < m; i += 2) { - if(iter) - { - for ( ; (i+1) < m; i+=2) - { - /*Load 2 dcomplex elements from - * vector x - */ - x0v.v = _mm256_loadu_pd( - (double *)(x + i) ); - /* x1v.v holds imaginary part of dcomplex - * elements from vector x - * It will do following operation. - * R0 I0 R1 I1 => I0 I0 I1 I1 - * - */ - x1v.v = _mm256_permute_pd( x0v.v, 15 ); - /* x1v.v holds real part of dcomplex - * elements from vector x - * It will do following operation. - * R0 I0 R1 I1 => R0 R0 R1 R1 - */ - x0v.v = _mm256_permute_pd( x0v.v, 0 ); - - /*Load 2x6 tile of matrix A*/ - a0v.v = _mm256_loadu_pd( (double *) - (a + i + 0 * lda) ); - a1v.v = _mm256_loadu_pd( (double *) - (a + i + 1 * lda) ); - a2v.v = _mm256_loadu_pd( (double *) - (a + i + 2 * lda) ); - a3v.v = _mm256_loadu_pd( (double *) - (a + i + 3 * lda) ); - a4v.v = _mm256_loadu_pd( (double *) - (a + i + 4 * lda) ); - a5v.v = _mm256_loadu_pd( (double *) - (a + i + 5 * lda) ); - - // perform: rho?v += a?v * x0v; - rho0v.v = _mm256_fmadd_pd( a0v.v, - x0v.v, rho0v.v ); - rho6v.v = _mm256_fmadd_pd( a0v.v, - x1v.v, rho6v.v ); - - rho1v.v = _mm256_fmadd_pd( a1v.v, - x0v.v, rho1v.v ); - rho7v.v = _mm256_fmadd_pd( a1v.v, - x1v.v, rho7v.v ); - - rho2v.v = _mm256_fmadd_pd( a2v.v, - x0v.v, rho2v.v ); - rho8v.v = _mm256_fmadd_pd( a2v.v, - x1v.v, rho8v.v ); - - rho3v.v = _mm256_fmadd_pd( a3v.v, - x0v.v, rho3v.v ); - rho9v.v = _mm256_fmadd_pd( a3v.v, - x1v.v, rho9v.v ); - - rho4v.v = _mm256_fmadd_pd( a4v.v, - x0v.v, rho4v.v ); - rho10v.v = _mm256_fmadd_pd( a4v.v, - x1v.v, rho10v.v ); - - rho5v.v = _mm256_fmadd_pd( a5v.v, - x0v.v, rho5v.v ); - rho11v.v = _mm256_fmadd_pd( a5v.v, - x1v.v, rho11v.v ); - } + // Load 2 dcomplex elements from vector x + xv[0].v = _mm256_loadu_pd(x_temp); - /*Swapping position of real and imag component - * for horizontal addition to get the final - * dot product computation - * rho register are holding computation which needs - * to be arranged in following manner. 
- * Ra0*Ix0 | Ia0*Ix0 | Ra1*Ix1 | Ia1*Ix1 - * || - * \/ - * Ia0*Ix0 | Ra0*Ix0 | Ia1*Ix1 | Ra1*Ix1 - */ - rho6v.v = _mm256_permute_pd(rho6v.v, 0x05); - rho7v.v = _mm256_permute_pd(rho7v.v, 0x05); - rho8v.v = _mm256_permute_pd(rho8v.v, 0x05); - rho9v.v = _mm256_permute_pd(rho9v.v, 0x05); - rho10v.v = _mm256_permute_pd(rho10v.v, 0x05); - rho11v.v = _mm256_permute_pd(rho11v.v, 0x05); + // xv[1].v - R0 I0 R1 I1 => I0 I0 I1 I1 + xv[1].v = _mm256_permute_pd(xv[0].v, 15); - /*Negating imaginary part for computing - * the final result of dcomplex multiplication - */ - rho6v.v = _mm256_mul_pd(rho6v.v, no_conju); - rho7v.v = _mm256_mul_pd(rho7v.v, no_conju); - rho8v.v = _mm256_mul_pd(rho8v.v, no_conju); - rho9v.v = _mm256_mul_pd(rho9v.v, no_conju); - rho10v.v = _mm256_mul_pd(rho10v.v, no_conju); - rho11v.v = _mm256_mul_pd(rho11v.v, no_conju); - - rho0v.v = _mm256_add_pd(rho0v.v, rho6v.v); - rho1v.v = _mm256_add_pd(rho1v.v, rho7v.v); - rho2v.v = _mm256_add_pd(rho2v.v, rho8v.v); - rho3v.v = _mm256_add_pd(rho3v.v, rho9v.v); - rho4v.v = _mm256_add_pd(rho4v.v, rho10v.v); - rho5v.v = _mm256_add_pd(rho5v.v, rho11v.v); - - /*rho0, rho1, rho2 holds final dot product - * result of 6 dcomplex elements. - */ - rho0v.d[0] += rho0v.d[2]; - rho0v.d[1] += rho0v.d[3]; + // xv[0].v - R0 I0 R1 I1 => R0 R0 R1 R1 + xv[0].v = _mm256_permute_pd(xv[0].v, 0); - rho0v.d[2] = rho1v.d[0] + rho1v.d[2]; - rho0v.d[3] = rho1v.d[1] + rho1v.d[3]; + a_vec[0].v = _mm256_loadu_pd((double *)(av[0])); + a_vec[1].v = _mm256_loadu_pd((double *)(av[1])); + a_vec[2].v = _mm256_loadu_pd((double *)(av[2])); + a_vec[3].v = _mm256_loadu_pd((double *)(av[3])); + a_vec[4].v = _mm256_loadu_pd((double *)(av[4])); + a_vec[5].v = _mm256_loadu_pd((double *)(av[5])); - rho1v.d[0] = rho2v.d[0] + rho2v.d[2]; - rho1v.d[1] = rho2v.d[1] + rho2v.d[3]; + // perform: rho?v += a?v * xv[0]; + rhov[0].v = _mm256_fmadd_pd(a_vec[0].v, xv[0].v, rhov[0].v); + rhov[6].v = _mm256_fmadd_pd(a_vec[0].v, xv[1].v, rhov[6].v); - rho1v.d[2] = rho3v.d[0] + rho3v.d[2]; - rho1v.d[3] = rho3v.d[1] + rho3v.d[3]; + rhov[1].v = _mm256_fmadd_pd(a_vec[1].v, xv[0].v, rhov[1].v); + rhov[7].v = _mm256_fmadd_pd(a_vec[1].v, xv[1].v, rhov[7].v); - rho2v.d[0] = rho4v.d[0] + rho4v.d[2]; - rho2v.d[1] = rho4v.d[1] + rho4v.d[3]; + rhov[2].v = _mm256_fmadd_pd(a_vec[2].v, xv[0].v, rhov[2].v); + rhov[8].v = _mm256_fmadd_pd(a_vec[2].v, xv[1].v, rhov[8].v); - rho2v.d[2] = rho5v.d[0] + rho5v.d[2]; - rho2v.d[3] = rho5v.d[1] + rho5v.d[3]; + rhov[3].v = _mm256_fmadd_pd(a_vec[3].v, xv[0].v, rhov[3].v); + rhov[9].v = _mm256_fmadd_pd(a_vec[3].v, xv[1].v, rhov[9].v); - /*Computed dot product result is being stored - * in temp buffer r for further computation. 
- */ - _mm256_storeu_pd((double *)r, rho0v.v); - _mm256_storeu_pd((double *)(r+2) , rho1v.v); - _mm256_storeu_pd((double *)(r+4) , rho2v.v); + rhov[4].v = _mm256_fmadd_pd(a_vec[4].v, xv[0].v, rhov[4].v); + rhov[10].v = _mm256_fmadd_pd(a_vec[4].v, xv[1].v, rhov[10].v); - } - /*handles remainder cases*/ - if(rem) - { - PRAGMA_SIMD - for(dim_t p = 0; p < 6 ; p++) - { - PASTEMAC(z,axpys)( a[i + p*lda] - , x[i], r[p] ); - } - } + rhov[5].v = _mm256_fmadd_pd(a_vec[5].v, xv[0].v, rhov[5].v); + rhov[11].v = _mm256_fmadd_pd(a_vec[5].v, xv[1].v, rhov[11].v); + + av[0] += 4; + av[1] += 4; + av[2] += 4; + av[3] += 4; + av[4] += 4; + av[5] += 4; + + x_temp += 4; + } + + if (bli_is_noconj(conjx_use)) + { + conj_mul.v = _mm256_setr_pd(-1, 1, -1, 1); } else { - if(iter) - { - for ( ; (i+1) < m; i+=2) - { - /*Load 2 dcomplex elements from - * vector x - */ - x0v.v = _mm256_loadu_pd( (double *) - (x + i) ); - /* x1v.v holds imaginary part of dcomplex - * elements from vector x - */ - x1v.v = _mm256_permute_pd( x0v.v, 15 ); - /* x1v.v holds real part of dcomplex - * elements from vector x - */ - x0v.v = _mm256_permute_pd( x0v.v, 0 ); - - /*Load 2x6 tile of matrix A*/ - a0v.v = _mm256_loadu_pd( (double *) - (a + i + 0 * lda)); - a1v.v = _mm256_loadu_pd( (double *) - (a + i + 1 * lda)); - a2v.v = _mm256_loadu_pd( (double *) - (a + i + 2 * lda)); - a3v.v = _mm256_loadu_pd( (double *) - (a + i + 3 * lda)); - a4v.v = _mm256_loadu_pd( (double *) - (a + i + 4 * lda)); - a5v.v = _mm256_loadu_pd( (double *) - (a + i + 5 * lda)); - - // perform: rho?v += a?v * x0v; - rho0v.v = _mm256_fmadd_pd( a0v.v, - x0v.v, rho0v.v ); - rho6v.v = _mm256_fmadd_pd( a0v.v, - x1v.v, rho6v.v ); - - rho1v.v = _mm256_fmadd_pd( a1v.v, - x0v.v, rho1v.v ); - rho7v.v = _mm256_fmadd_pd( a1v.v, - x1v.v, rho7v.v ); - - rho2v.v = _mm256_fmadd_pd( a2v.v, - x0v.v, rho2v.v ); - rho8v.v = _mm256_fmadd_pd( a2v.v, - x1v.v, rho8v.v ); - - rho3v.v = _mm256_fmadd_pd( a3v.v, - x0v.v, rho3v.v ); - rho9v.v = _mm256_fmadd_pd( a3v.v, - x1v.v, rho9v.v ); - - rho4v.v = _mm256_fmadd_pd( a4v.v, - x0v.v, rho4v.v ); - rho10v.v = _mm256_fmadd_pd( a4v.v, - x1v.v, rho10v.v ); - - rho5v.v = _mm256_fmadd_pd( a5v.v, - x0v.v, rho5v.v ); - rho11v.v = _mm256_fmadd_pd( a5v.v, - x1v.v, rho11v.v ); - } + conj_mul.v = _mm256_setr_pd(1, -1, 1, -1); + } - /*Swapping position of real and imag component - * for horizontal addition to get the final - * dot product computation - * rho register are holding computation which needs - * to be arranged in following manner. - * Ra0*Ix0 | Ia0*Ix0 | Ra1*Ix1 | Ia1*Ix1 - * || - * \/ - * Ia0*Ix0 | Ra0*Ix0 | Ia1*Ix1 | Ra1*Ix1 - */ - rho6v.v = _mm256_permute_pd(rho6v.v, 0x05); - rho7v.v = _mm256_permute_pd(rho7v.v, 0x05); - rho8v.v = _mm256_permute_pd(rho8v.v, 0x05); - rho9v.v = _mm256_permute_pd(rho9v.v, 0x05); - rho10v.v = _mm256_permute_pd(rho10v.v, 0x05); - rho11v.v = _mm256_permute_pd(rho11v.v, 0x05); + /*Swapping position of real and imag component + * for horizontal addition to get the final + * dot product computation + * rho register are holding computation which needs + * to be arranged in following manner. 
+ * Ra0*Ix0 | Ia0*Ix0 | Ra1*Ix1 | Ia1*Ix1 + * || + * \/ + * Ia0*Ix0 | Ra0*Ix0 | Ia1*Ix1 | Ra1*Ix1 + */ + rhov[6].v = _mm256_permute_pd(rhov[6].v, 0x05); + rhov[7].v = _mm256_permute_pd(rhov[7].v, 0x05); + rhov[8].v = _mm256_permute_pd(rhov[8].v, 0x05); + rhov[9].v = _mm256_permute_pd(rhov[9].v, 0x05); + rhov[10].v = _mm256_permute_pd(rhov[10].v, 0x05); + rhov[11].v = _mm256_permute_pd(rhov[11].v, 0x05); + + /* + Modifying the imag sign according to the conj value + */ + rhov[6].v = _mm256_mul_pd(rhov[6].v, conj_mul.v); + rhov[7].v = _mm256_mul_pd(rhov[7].v, conj_mul.v); + rhov[8].v = _mm256_mul_pd(rhov[8].v, conj_mul.v); + rhov[9].v = _mm256_mul_pd(rhov[9].v, conj_mul.v); + rhov[10].v = _mm256_mul_pd(rhov[10].v, conj_mul.v); + rhov[11].v = _mm256_mul_pd(rhov[11].v, conj_mul.v); + + rhov[0].v = _mm256_add_pd(rhov[0].v, rhov[6].v); + rhov[1].v = _mm256_add_pd(rhov[1].v, rhov[7].v); + rhov[2].v = _mm256_add_pd(rhov[2].v, rhov[8].v); + rhov[3].v = _mm256_add_pd(rhov[3].v, rhov[9].v); + rhov[4].v = _mm256_add_pd(rhov[4].v, rhov[10].v); + rhov[5].v = _mm256_add_pd(rhov[5].v, rhov[11].v); + + /*rho0, rho1, rho2 holds final dot product + * result of 6 dcomplex elements. + */ + rhov[0].d[0] += rhov[0].d[2]; + rhov[0].d[1] += rhov[0].d[3]; - /*Negating imaginary part for computing - * the final result of dcomplex multiplication - */ - rho6v.v = _mm256_mul_pd(rho6v.v, conju); - rho7v.v = _mm256_mul_pd(rho7v.v, conju); - rho8v.v = _mm256_mul_pd(rho8v.v, conju); - rho9v.v = _mm256_mul_pd(rho9v.v, conju); - rho10v.v = _mm256_mul_pd(rho10v.v, conju); - rho11v.v = _mm256_mul_pd(rho11v.v, conju); - - rho0v.v = _mm256_add_pd(rho0v.v, rho6v.v); - rho1v.v = _mm256_add_pd(rho1v.v, rho7v.v); - rho2v.v = _mm256_add_pd(rho2v.v, rho8v.v); - rho3v.v = _mm256_add_pd(rho3v.v, rho9v.v); - rho4v.v = _mm256_add_pd(rho4v.v, rho10v.v); - rho5v.v = _mm256_add_pd(rho5v.v, rho11v.v); - - /*rho0, rho1, rho2 holds final dot product - * result of 6 dcomplex elements. - */ - rho0v.d[0] += rho0v.d[2]; - rho0v.d[1] += rho0v.d[3]; + rhov[0].d[2] = rhov[1].d[0] + rhov[1].d[2]; + rhov[0].d[3] = rhov[1].d[1] + rhov[1].d[3]; - rho0v.d[2] = rho1v.d[0] + rho1v.d[2]; - rho0v.d[3] = rho1v.d[1] + rho1v.d[3]; + rhov[1].d[0] = rhov[2].d[0] + rhov[2].d[2]; + rhov[1].d[1] = rhov[2].d[1] + rhov[2].d[3]; - rho1v.d[0] = rho2v.d[0] + rho2v.d[2]; - rho1v.d[1] = rho2v.d[1] + rho2v.d[3]; + rhov[1].d[2] = rhov[3].d[0] + rhov[3].d[2]; + rhov[1].d[3] = rhov[3].d[1] + rhov[3].d[3]; - rho1v.d[2] = rho3v.d[0] + rho3v.d[2]; - rho1v.d[3] = rho3v.d[1] + rho3v.d[3]; + rhov[2].d[0] = rhov[4].d[0] + rhov[4].d[2]; + rhov[2].d[1] = rhov[4].d[1] + rhov[4].d[3]; - rho2v.d[0] = rho4v.d[0] + rho4v.d[2]; - rho2v.d[1] = rho4v.d[1] + rho4v.d[3]; + rhov[2].d[2] = rhov[5].d[0] + rhov[5].d[2]; + rhov[2].d[3] = rhov[5].d[1] + rhov[5].d[3]; - rho2v.d[2] = rho5v.d[0] + rho5v.d[2]; - rho2v.d[3] = rho5v.d[1] + rho5v.d[3]; + /* + Computed dot product result is being stored + in temp buffer r for further computation. + */ + _mm256_storeu_pd((double *)res, rhov[0].v); + _mm256_storeu_pd((double *)(res + 2), rhov[1].v); + _mm256_storeu_pd((double *)(res + 4), rhov[2].v); + } - /*Computed dot product result is being stored - * in temp buffer r for further computation. - */ - _mm256_storeu_pd((double *)r, rho0v.v); - _mm256_storeu_pd((double *)(r+2) , rho1v.v); - _mm256_storeu_pd((double *)(r+4) , rho2v.v); + // This section will have the whole of compute when incx != 1 || inca != 1 + if (rem) + { + // Issue vzeroupper instruction to clear upper lanes of ymm registers. 
+ // This avoids a performance penalty caused by false dependencies when + // transitioning from AVX to SSE instructions (which may occur later, + // especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); + + v2df_t rhov[12], a_vec[6], xv[2], conj_mul; + + rhov[0].v = _mm_setzero_pd(); + rhov[1].v = _mm_setzero_pd(); + rhov[2].v = _mm_setzero_pd(); + rhov[3].v = _mm_setzero_pd(); + rhov[4].v = _mm_setzero_pd(); + rhov[5].v = _mm_setzero_pd(); + rhov[6].v = _mm_setzero_pd(); + rhov[7].v = _mm_setzero_pd(); + rhov[8].v = _mm_setzero_pd(); + rhov[9].v = _mm_setzero_pd(); + rhov[10].v = _mm_setzero_pd(); + rhov[11].v = _mm_setzero_pd(); + + for (dim_t i = 0; i < rem; i++) + { + // Load 2 dcomplex elements from vector x + xv[0].v = _mm_loadu_pd(x_temp); - } - if(rem) - { - PRAGMA_SIMD - for(dim_t p = 0; p < 6 ; p++) - { - PASTEMAC(z,axpyjs)(a[i + p*lda] - , x[i], r[p] ); - } - } - } + // xv[1].v - R0 I0 R1 I1 => I0 I0 I1 I1 + xv[1].v = _mm_permute_pd(xv[0].v, 0b11); - if ( bli_is_conj( conjat ) ) - for ( dim_t i = 0; i < 6; ++i ) - { - PASTEMAC(z,conjs)( r[i] ); - } + // xv[0].v - R0 I0 R1 I1 => R0 R0 R1 R1 + xv[0].v = _mm_permute_pd(xv[0].v, 0b00); - /*scaling dot product result with alpha and - * adding the result to vector - */ - for ( dim_t i = 0; i < 6; ++i ) + a_vec[0].v = _mm_loadu_pd((double *)(av[0])); + a_vec[1].v = _mm_loadu_pd((double *)(av[1])); + a_vec[2].v = _mm_loadu_pd((double *)(av[2])); + a_vec[3].v = _mm_loadu_pd((double *)(av[3])); + a_vec[4].v = _mm_loadu_pd((double *)(av[4])); + a_vec[5].v = _mm_loadu_pd((double *)(av[5])); + + // perform: rho?v += a?v * xv[0]; + rhov[0].v = _mm_fmadd_pd(a_vec[0].v, xv[0].v, rhov[0].v); + rhov[6].v = _mm_fmadd_pd(a_vec[0].v, xv[1].v, rhov[6].v); + + rhov[1].v = _mm_fmadd_pd(a_vec[1].v, xv[0].v, rhov[1].v); + rhov[7].v = _mm_fmadd_pd(a_vec[1].v, xv[1].v, rhov[7].v); + + rhov[2].v = _mm_fmadd_pd(a_vec[2].v, xv[0].v, rhov[2].v); + rhov[8].v = _mm_fmadd_pd(a_vec[2].v, xv[1].v, rhov[8].v); + + rhov[3].v = _mm_fmadd_pd(a_vec[3].v, xv[0].v, rhov[3].v); + rhov[9].v = _mm_fmadd_pd(a_vec[3].v, xv[1].v, rhov[9].v); + + rhov[4].v = _mm_fmadd_pd(a_vec[4].v, xv[0].v, rhov[4].v); + rhov[10].v = _mm_fmadd_pd(a_vec[4].v, xv[1].v, rhov[10].v); + + rhov[5].v = _mm_fmadd_pd(a_vec[5].v, xv[0].v, rhov[5].v); + rhov[11].v = _mm_fmadd_pd(a_vec[5].v, xv[1].v, rhov[11].v); + + av[0] += 2 * inca; + av[1] += 2 * inca; + av[2] += 2 * inca; + av[3] += 2 * inca; + av[4] += 2 * inca; + av[5] += 2 * inca; + + x_temp += 2 * incx; + } + + if (bli_is_noconj(conjx_use)) + { + conj_mul.v = _mm_setr_pd(-1, 1); + } + else { - PASTEMAC(z,axpys)( *alpha, r[i], y[i] ); + conj_mul.v = _mm_setr_pd(1, -1); } + + rhov[6].v = _mm_permute_pd(rhov[6].v, 0b01); + rhov[7].v = _mm_permute_pd(rhov[7].v, 0b01); + rhov[8].v = _mm_permute_pd(rhov[8].v, 0b01); + rhov[9].v = _mm_permute_pd(rhov[9].v, 0b01); + rhov[10].v = _mm_permute_pd(rhov[10].v, 0b01); + rhov[11].v = _mm_permute_pd(rhov[11].v, 0b01); + + /* + Modifying the imag sign according to the conj value + */ + rhov[6].v = _mm_mul_pd(rhov[6].v, conj_mul.v); + rhov[7].v = _mm_mul_pd(rhov[7].v, conj_mul.v); + rhov[8].v = _mm_mul_pd(rhov[8].v, conj_mul.v); + rhov[9].v = _mm_mul_pd(rhov[9].v, conj_mul.v); + rhov[10].v = _mm_mul_pd(rhov[10].v, conj_mul.v); + rhov[11].v = _mm_mul_pd(rhov[11].v, conj_mul.v); + + rhov[0].v = _mm_add_pd(rhov[0].v, rhov[6].v); + rhov[1].v = _mm_add_pd(rhov[1].v, rhov[7].v); + rhov[2].v = _mm_add_pd(rhov[2].v, rhov[8].v); + rhov[3].v = _mm_add_pd(rhov[3].v, rhov[9].v); + rhov[4].v = 
_mm_add_pd(rhov[4].v, rhov[10].v); + rhov[5].v = _mm_add_pd(rhov[5].v, rhov[11].v); + + rhov[6].v = _mm_loadu_pd((double *)(res)); + rhov[7].v = _mm_loadu_pd((double *)(res + 1)); + rhov[8].v = _mm_loadu_pd((double *)(res + 2)); + rhov[9].v = _mm_loadu_pd((double *)(res + 3)); + rhov[10].v = _mm_loadu_pd((double *)(res + 4)); + rhov[11].v = _mm_loadu_pd((double *)(res + 5)); + + rhov[0].v = _mm_add_pd(rhov[0].v, rhov[6].v); + rhov[1].v = _mm_add_pd(rhov[1].v, rhov[7].v); + rhov[2].v = _mm_add_pd(rhov[2].v, rhov[8].v); + rhov[3].v = _mm_add_pd(rhov[3].v, rhov[9].v); + rhov[4].v = _mm_add_pd(rhov[4].v, rhov[10].v); + rhov[5].v = _mm_add_pd(rhov[5].v, rhov[11].v); + + /* + Computed dot product result is being stored + in temp buffer r for further computation. + */ + _mm_storeu_pd((double *)res, rhov[0].v); + _mm_storeu_pd((double *)(res + 1), rhov[1].v); + _mm_storeu_pd((double *)(res + 2), rhov[2].v); + _mm_storeu_pd((double *)(res + 3), rhov[3].v); + _mm_storeu_pd((double *)(res + 4), rhov[4].v); + _mm_storeu_pd((double *)(res + 5), rhov[5].v); + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from AVX to SSE instructions (which may occur later, + // especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); } - else + + // Multiplying 'A' * 'x' by 'alpha' + __m256d alpha_r, alpha_i, temp_v[3]; + v4df_t rhov[3]; + + rhov[0].v = _mm256_loadu_pd((double *)(res)); + rhov[1].v = _mm256_loadu_pd((double *)(res + 2)); + rhov[2].v = _mm256_loadu_pd((double *)(res + 4)); + + if (bli_is_conj(conjat)) { - /* Query the context for the kernel function pointer. */ - const num_t dt = PASTEMAC(z,type); - PASTECH(z,dotxv_ker_ft) kfp_dv - = - bli_cntx_get_l1v_ker_dt( dt, BLIS_DOTXV_KER, cntx ); + __m256d conj_mul = _mm256_setr_pd(1, -1, 1, -1); - for ( dim_t i = 0; i < b_n; ++i ) - { - dcomplex* restrict a1 = a + (0 )*inca + (i )*lda; - dcomplex* restrict x1 = x + (0 )*incx; - dcomplex* restrict psi1 = y + (i )*incy; + rhov[0].v = _mm256_mul_pd(rhov[0].v, conj_mul); + rhov[1].v = _mm256_mul_pd(rhov[1].v, conj_mul); + rhov[2].v = _mm256_mul_pd(rhov[2].v, conj_mul); + } - kfp_dv - ( - conjat, - conjx, - m, - alpha, - a1, inca, - x1, incx, - beta, - psi1, - cntx - ); + alpha_r = _mm256_broadcast_sd(&((*alpha).real)); + alpha_i = _mm256_broadcast_sd(&((*alpha).imag)); + + temp_v[0] = _mm256_mul_pd(rhov[0].v, alpha_i); + temp_v[1] = _mm256_mul_pd(rhov[1].v, alpha_i); + temp_v[2] = _mm256_mul_pd(rhov[2].v, alpha_i); + + temp_v[0] = _mm256_permute_pd(temp_v[0], 0b0101); + temp_v[1] = _mm256_permute_pd(temp_v[1], 0b0101); + temp_v[2] = _mm256_permute_pd(temp_v[2], 0b0101); + + rhov[0].v = _mm256_fmaddsub_pd(rhov[0].v, alpha_r, temp_v[0]); + rhov[1].v = _mm256_fmaddsub_pd(rhov[1].v, alpha_r, temp_v[1]); + rhov[2].v = _mm256_fmaddsub_pd(rhov[2].v, alpha_r, temp_v[2]); + + // When 'beta' is not zero we need to multiply scale 'y' by 'beta' + if (!PASTEMAC(z, eq0)(*beta)) + { + v4df_t yv[3]; + __m256d beta_r, beta_i; + + beta_r = _mm256_broadcast_sd(&((*beta).real)); + beta_i = _mm256_broadcast_sd(&((*beta).imag)); + + if (incy == 1) + { + yv[0].v = _mm256_loadu_pd((double *)(y)); + yv[1].v = _mm256_loadu_pd((double *)(y + 2)); + yv[2].v = _mm256_loadu_pd((double *)(y + 4)); } + else + { + /* + This can be done using SSE instructions + but has been kept as scalar code to avoid + mixing SSE with AVX + */ + yv[0].d[0] = (*(y + 0 * incy)).real; + yv[0].d[1] = (*(y + 0 * incy)).imag; + 
yv[0].d[2] = (*(y + 1 * incy)).real; + yv[0].d[3] = (*(y + 1 * incy)).imag; + + yv[1].d[0] = (*(y + 2 * incy)).real; + yv[1].d[1] = (*(y + 2 * incy)).imag; + yv[1].d[2] = (*(y + 3 * incy)).real; + yv[1].d[3] = (*(y + 3 * incy)).imag; + + yv[2].d[0] = (*(y + 4 * incy)).real; + yv[2].d[1] = (*(y + 4 * incy)).imag; + yv[2].d[2] = (*(y + 5 * incy)).real; + yv[2].d[3] = (*(y + 5 * incy)).imag; + } + + temp_v[0] = _mm256_mul_pd(yv[0].v, beta_i); + temp_v[1] = _mm256_mul_pd(yv[1].v, beta_i); + temp_v[2] = _mm256_mul_pd(yv[2].v, beta_i); + + temp_v[0] = _mm256_permute_pd(temp_v[0], 0b0101); + temp_v[1] = _mm256_permute_pd(temp_v[1], 0b0101); + temp_v[2] = _mm256_permute_pd(temp_v[2], 0b0101); + + yv[0].v = _mm256_fmaddsub_pd(yv[0].v, beta_r, temp_v[0]); + yv[1].v = _mm256_fmaddsub_pd(yv[1].v, beta_r, temp_v[1]); + yv[2].v = _mm256_fmaddsub_pd(yv[2].v, beta_r, temp_v[2]); + + // Here we 'rhov' has 'alpha' * 'A' * 'x' that is added with 'y' + rhov[0].v = _mm256_add_pd(yv[0].v, rhov[0].v); + rhov[1].v = _mm256_add_pd(yv[1].v, rhov[1].v); + rhov[2].v = _mm256_add_pd(yv[2].v, rhov[2].v); } + if (incy == 1) + { + _mm256_storeu_pd((double *)y, rhov[0].v); + _mm256_storeu_pd((double *)(y + 2), rhov[1].v); + _mm256_storeu_pd((double *)(y + 4), rhov[2].v); + } + else + { + (*(y + 0 * incy)).real = rhov[0].d[0]; + (*(y + 0 * incy)).imag = rhov[0].d[1]; + (*(y + 1 * incy)).real = rhov[0].d[2]; + (*(y + 1 * incy)).imag = rhov[0].d[3]; + + (*(y + 2 * incy)).real = rhov[1].d[0]; + (*(y + 2 * incy)).imag = rhov[1].d[1]; + (*(y + 3 * incy)).real = rhov[1].d[2]; + (*(y + 3 * incy)).imag = rhov[1].d[3]; + + (*(y + 4 * incy)).real = rhov[2].d[0]; + (*(y + 4 * incy)).imag = rhov[2].d[1]; + (*(y + 5 * incy)).real = rhov[2].d[2]; + (*(y + 5 * incy)).imag = rhov[2].d[3]; + } } - /** * Performs dotxf operation on scomplex. * x and y are vectors and a is the matrix. diff --git a/kernels/zen/2/CMakeLists.txt b/kernels/zen/2/CMakeLists.txt deleted file mode 100644 index c9c9220609..0000000000 --- a/kernels/zen/2/CMakeLists.txt +++ /dev/null @@ -1,25 +0,0 @@ -##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## - -add_library(zen_2 - OBJECT - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_zen_int_4.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_int_4.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_zen_int_amd.c - ) -target_compile_options(zen_2 PRIVATE /arch:AVX2) -if(BUILD_SHARED_LIBS) - target_compile_definitions(zen_2 PUBLIC -DBLIS_IS_BUILDING_LIBRARY) -endif() -# For any other TARGET_ARCH, it would fail to configure. -# Select AMD specific sources for AMD configurations. -#[=[if(${TARGET_ARCH} STREQUAL zen OR -${TARGET_ARCH} STREQUAL zen2 OR -${TARGET_ARCH} STREQUAL zen3 OR -${TARGET_ARCH} STREQUAL zen4 OR -${TARGET_ARCH} STREQUAL amdzen) - target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_zen_int_amd.c - ) -endif()]=] diff --git a/kernels/zen/2/bli_gemv_zen_ref.c b/kernels/zen/2/bli_gemv_zen_ref.c index 0d71522c3c..0e53a5240f 100644 --- a/kernels/zen/2/bli_gemv_zen_ref.c +++ b/kernels/zen/2/bli_gemv_zen_ref.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
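For reference, the rewritten bli_zdotxf_zen_int_6 above keeps separate accumulators for a*Re(x) and a*Im(x) and applies the effective conjugation sign only once when combining them. A minimal sketch of that combine step, assuming AVX2; the helper name is illustrative and not part of the patch:

#include <immintrin.h>

// rho_r accumulates a * Re(x) and rho_i accumulates a * Im(x), two dcomplex
// lanes per register. The combine swaps rho_i within each pair and flips one
// sign per pair before adding:
//   no conj:  [ ar*xr - ai*xi, ai*xr + ar*xi ]
//   conj(x):  [ ar*xr + ai*xi, ai*xr - ar*xi ]
static inline __m256d combine_complex_acc( __m256d rho_r, __m256d rho_i, int conjugate_x )
{
    __m256d sign;

    rho_i = _mm256_permute_pd( rho_i, 0x05 );     // swap real/imag positions

    if ( conjugate_x ) sign = _mm256_setr_pd(  1.0, -1.0,  1.0, -1.0 );
    else               sign = _mm256_setr_pd( -1.0,  1.0, -1.0,  1.0 );

    return _mm256_add_pd( rho_r, _mm256_mul_pd( rho_i, sign ) );
}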
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/CMakeLists.txt b/kernels/zen/3/CMakeLists.txt deleted file mode 100644 index 97a067bb64..0000000000 --- a/kernels/zen/3/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -##Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved.## - -add_library(zen_3 - OBJECT - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_small.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_small.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_avx2_k1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_avx2_k1.c - ) -target_compile_options(zen_3 PRIVATE /arch:AVX2) -if(BUILD_SHARED_LIBS) - target_compile_definitions(zen_3 PUBLIC -DBLIS_IS_BUILDING_LIBRARY) -endif() -add_subdirectory(sup) diff --git a/kernels/zen/3/bli_dgemm_avx2_k1.c b/kernels/zen/3/bli_dgemm_avx2_k1.c index b225fdad1a..e97a754ee9 100644 --- a/kernels/zen/3/bli_dgemm_avx2_k1.c +++ b/kernels/zen/3/bli_dgemm_avx2_k1.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,7 +40,7 @@ #define D_MR 8 #define D_NR 6 -void bli_dgemm_8x6_avx2_k1_nn +err_t bli_dgemm_8x6_avx2_k1_nn ( dim_t m, dim_t n, @@ -58,7 +58,7 @@ void bli_dgemm_8x6_avx2_k1_nn alpha_val = *alpha; if((m == 0) || (n == 0) || (((alpha_val == 0.0) || (k == 0)) && (beta_val == 1.0))){ - return; + return BLIS_FAILURE; } dim_t m_remainder = (m % D_MR); @@ -1090,5 +1090,5 @@ void bli_dgemm_8x6_avx2_k1_nn } n_remainder = n_remainder - 2; } - return; + return BLIS_SUCCESS; } diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index 477c710471..5701d8d93c 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -306,12 +306,12 @@ static err_t bli_sgemm_small bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); // Get the current size of the buffer pool for A block packing. - // We will use the same size to avoid pool re-initialization - siz_t buffer_size = bli_pool_block_size(bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + // We will use the same size to avoid pool re-initialization + siz_t buffer_size = bli_pool_block_size(bli_pba_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_pba(&rntm))); // Based on the available memory in the buffer we will decide if // we want to do packing or not. @@ -336,8 +336,8 @@ static err_t bli_sgemm_small printf( "bli_sgemm_small: Requesting mem pool block of size %lu\n", buffer_size); #endif // Get the buffer from the pool, if there is no pool with - // required size, it will be created. - bli_membrk_acquire_m(&rntm, + // required size, it will be created. 
+ bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -1737,7 +1737,7 @@ static err_t bli_sgemm_small #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_sgemm_small(): releasing mem pool block\n" ); #endif - bli_membrk_release(&rntm, + bli_pba_release(&rntm, &local_mem_buf_A_s); } @@ -1854,13 +1854,13 @@ err_t bli_dgemm_small bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); // Get the current size of the buffer pool for A block packing. // We will use the same size to avoid pool re-initliazaton siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_pba_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_pba(&rntm))); // // This kernel assumes that "A" will be unpackged if N <= 3. @@ -1885,7 +1885,7 @@ err_t bli_dgemm_small printf( "bli_dgemm_small: Requesting mem pool block of size %lu\n", buffer_size); #endif // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -2392,944 +2392,1124 @@ err_t bli_dgemm_small } m_remainder = M - row_idx; - - if (m_remainder >= 12) + if(m_remainder) { - m_remainder -= 12; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + // Sets up the mask for loading relevant remainder elements in load direction + // int64_t array of size 4 represents the mask for 4 elements of AVX2 vector register. + // + // Low end High end * Low end High end + // ________________________ * ________________________ + // | | | | | * | | | | | + // | 1 | 2 | 3 | 4 | ----> Source vector * | 1 | 2 | 3 | 4 | ----> Source vector + // |_____|_____|_____|_____| * |_____|_____|_____|_____| + // * + // ________________________ * ________________________ + // | | | | | * | | | | | + // | -1 | -1 | -1 | 0 | ----> Mask vector( mask_3 ) | -1 | -1 | 0 | 0 | ----> Mask vector( mask_2 ) + // |_____|_____|_____|_____| * |_____|_____|_____|_____| + // * + // ________________________ * ________________________ + // | | | | | * | | | | | + // | 1 | 2 | 3 | 0 | ----> Destination vector * | 1 | 2 | 0 | 0 | ----> Destination vector + // |_____|_____|_____|_____| * |_____|_____|_____|_____| + // + // -1 sets all the bits to 1. + // + dim_t m_rem = 0; + int64_t mask_4[4] = {0}; + mask_4[0] = -1; + mask_4[1] = -1; + mask_4[2] = -1; + mask_4[3] = -1; + + int64_t mask_3[4] = {0}; + mask_3[0] = -1; + mask_3[1] = -1; + mask_3[2] = -1; + mask_3[3] = 0; + + int64_t mask_2[4] = {0}; + mask_2[0] = -1; + mask_2[1] = -1; + mask_2[2] = 0; + mask_2[3] = 0; + + int64_t mask_1[4] = {0}; + mask_1[0] = -1; + mask_1[1] = 0; + mask_1[2] = 0; + mask_1[3] = 0; + + int64_t *mask_ptr[] = {mask_4, mask_1, mask_2, mask_3, mask_4}; + if(m_remainder > 12) { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) + // Handles edge cases where remainder elements are between 12-16(13, 14, 15). 
+ // Here m_rem gives index in mask_ptr that points which mask to be used based + // on remainder elements which could be 1, 2, or 3 here. + m_rem = (m_remainder % 12); + __m256i maskVec = _mm256_loadu_si256( (__m256i *)mask_ptr[m_rem]); + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - // ymm4 += ymm0 * ymm3; - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - // ymm8 += ymm1 * ymm3; - ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); - // ymm12 += ymm2 * ymm3; - ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); - - ymm3 = _mm256_loadu_pd(tA + 4); - // ymm5 += ymm0 * ymm3; - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - // ymm9 += ymm1 * ymm3; - ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9); - // ymm13 += ymm2 * ymm3; - ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13); - - ymm3 = _mm256_loadu_pd(tA + 8); - // ymm6 += ymm0 * ymm3; - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - // ymm10 += ymm1 * ymm3; - ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10); - // ymm14 += ymm2 * ymm3; - ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm6 = _mm256_mul_pd(ymm6, ymm0); - ymm8 = _mm256_mul_pd(ymm8, ymm0); - ymm9 = _mm256_mul_pd(ymm9, ymm0); - ymm10 = _mm256_mul_pd(ymm10, ymm0); - ymm12 = _mm256_mul_pd(ymm12, ymm0); - ymm13 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - - // multiply C by beta and accumulate. - double *ttC = tC +ldc; - ymm2 = _mm256_loadu_pd(ttC); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(ttC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_pd(ttC + 8); - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); - - // multiply C by beta and accumulate. - ttC += ldc; - ymm2 = _mm256_loadu_pd(ttC); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(ttC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(ttC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; - } - _mm256_storeu_pd(tC, ymm4); - _mm256_storeu_pd(tC + 4, ymm5); - _mm256_storeu_pd(tC + 8, ymm6); + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); - tC += ldc; + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. 
+ // This loop is processing D_MR x K + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); + tB += tb_inc_row; - _mm256_storeu_pd(tC, ymm8); - _mm256_storeu_pd(tC + 4, ymm9); - _mm256_storeu_pd(tC + 8, ymm10); + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); - tC += ldc; + ymm3 = _mm256_loadu_pd(tA + 4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13); - _mm256_storeu_pd(tC, ymm12); - _mm256_storeu_pd(tC + 4, ymm13); - _mm256_storeu_pd(tC + 8, ymm14); - } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; + ymm3 = _mm256_loadu_pd(tA + 8); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14); - // clear scratch registers. - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); + // Masked load the relevant remainder elements only + // using maskVec. + ymm3 = _mm256_maskload_pd(tA + 12, maskVec); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11); + ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); + ymm7 = _mm256_mul_pd(ymm7, ymm0); + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm11 = _mm256_mul_pd(ymm11, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + ymm15 = _mm256_mul_pd(ymm15, ymm0); - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8); - ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12); + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(tC + 12, maskVec); + ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + + // multiply C by beta and accumulate, col 2. 
+ double* ttC = tC + ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(ttC + 12, maskVec); + ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11); + + // multiply C by beta and accumulate, col 3. + ttC += ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(ttC + 12, maskVec); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + } + _mm256_storeu_pd(tC, ymm4); + _mm256_storeu_pd(tC + 4, ymm5); + _mm256_storeu_pd(tC + 8, ymm6); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 12, maskVec, ymm7); - ymm3 = _mm256_loadu_pd(tA + 4); - ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9); - ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13); + tC += ldc; - ymm3 = _mm256_loadu_pd(tA + 8); - ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); - ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14); + _mm256_storeu_pd(tC, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + _mm256_storeu_pd(tC + 8, ymm10); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 12, maskVec, ymm11); - tA += lda; + tC += ldc; + _mm256_storeu_pd(tC, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 12, maskVec, ymm15); } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm8 = _mm256_mul_pd(ymm8, ymm0); - ymm9 = _mm256_mul_pd(ymm9, ymm0); - ymm10 = _mm256_mul_pd(ymm10, ymm0); - ymm12 = _mm256_mul_pd(ymm12, ymm0); - ymm13 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm14, ymm0); - + n_remainder = N - col_idx; - if(is_beta_non_zero) + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC + 0); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); - - double *ttC = tC + ldc; + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(ttC); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(ttC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(ttC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + // clear scratch registers. 
+ ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); - } - _mm256_storeu_pd(tC + 0, ymm8); - _mm256_storeu_pd(tC + 4, ymm9); - _mm256_storeu_pd(tC + 8, ymm10); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + tB += tb_inc_row; - tC += ldc; + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8); + ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12); - _mm256_storeu_pd(tC, ymm12); - _mm256_storeu_pd(tC + 4, ymm13); - _mm256_storeu_pd(tC + 8, ymm14); + ymm3 = _mm256_loadu_pd(tA + 4); + ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9); + ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13); - col_idx += 2; - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; + ymm3 = _mm256_loadu_pd(tA + 8); + ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); + ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14); - // clear scratch registers. - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); + // Masked load the relevant remainder elements only + // using maskVec. + ymm3 = _mm256_maskload_pd(tA + 12, maskVec); + ymm11 = _mm256_fmadd_pd(ymm0, ymm3, ymm11); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; + tA += lda; - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12); + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); - ymm3 = _mm256_loadu_pd(tA + 4); - ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); + //multiply A*B by alpha. + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm11 = _mm256_mul_pd(ymm11, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + ymm15 = _mm256_mul_pd(ymm15, ymm0); - ymm3 = _mm256_loadu_pd(tA + 8); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + if(is_beta_non_zero) + { + // multiply C by beta and accumulate, col 1. + ymm2 = _mm256_loadu_pd(tC + 0); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(tC + 12, maskVec); + ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11); + + // multiply C by beta and accumulate, col 2. 
+ double *ttC = tC + ldc; + + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(ttC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(ttC + 12, maskVec); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + } + + _mm256_storeu_pd(tC + 0, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + _mm256_storeu_pd(tC + 8, ymm10); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 12, maskVec, ymm11); - tA += lda; + tC += ldc; + _mm256_storeu_pd(tC, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 12, maskVec, ymm15); + col_idx += 2; } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm12 = _mm256_mul_pd(ymm12, ymm0); - ymm13 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - - if(is_beta_non_zero) + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC + 0); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; - } - _mm256_storeu_pd(tC + 0, ymm12); - _mm256_storeu_pd(tC + 4, ymm13); - _mm256_storeu_pd(tC + 8, ymm14); - } + // clear scratch registers. + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); - row_idx += 12; - } + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + tB += tb_inc_row; - if (m_remainder >= 8) - { - m_remainder -= 8; + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12); - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; + ymm3 = _mm256_loadu_pd(tA + 4); + ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + ymm3 = _mm256_loadu_pd(tA + 8); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; + // Masked load the relevant remainder elements only + // using maskVec. + ymm3 = _mm256_maskload_pd(tA + 12, maskVec); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - //broadcasted matrix B elements are multiplied - //with matrix A columns. 
- ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6); - ymm8 = _mm256_fmadd_pd(ymm2, ymm3, ymm8); + tA += lda; - ymm3 = _mm256_loadu_pd(tA + 4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); + //multiply A*B by alpha. + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); + ymm15 = _mm256_mul_pd(ymm15, ymm0); - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm6 = _mm256_mul_pd(ymm6, ymm0); - ymm7 = _mm256_mul_pd(ymm7, ymm0); - ymm8 = _mm256_mul_pd(ymm8, ymm0); - ymm9 = _mm256_mul_pd(ymm9, ymm0); + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC + 0); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm2 = _mm256_loadu_pd(tC + 8); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(tC + 12, maskVec); + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + } + + _mm256_storeu_pd(tC + 0, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + _mm256_storeu_pd(tC + 8, ymm14); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 12, maskVec, ymm15); + } + } + else if(m_remainder > 8) + { + // Handles edge cases where remainder elements are between 9-12(9, 10, 11, 12). + // Here m_rem gives index in mask_ptr that points which mask to be used based + // on remainder elements which could be 1, 2, 3 or 4 here. + m_rem = (m_remainder % 8); + __m256i maskVec = _mm256_loadu_si256( (__m256i *)mask_ptr[m_rem]); - if(is_beta_non_zero) + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - - double* ttC = tC + ldc; - - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(ttC); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_pd(ttC + 4); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; - ttC += ldc; + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(ttC); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(ttC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - } + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. 
+ ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); + tB += tb_inc_row; - _mm256_storeu_pd(tC, ymm4); - _mm256_storeu_pd(tC + 4, ymm5); + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); - tC += ldc; - _mm256_storeu_pd(tC, ymm6); - _mm256_storeu_pd(tC + 4, ymm7); + ymm3 = _mm256_loadu_pd(tA + 4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13); - tC += ldc; - _mm256_storeu_pd(tC, ymm8); - _mm256_storeu_pd(tC + 4, ymm9); + // Masked load the relevant remainder elements only + // using maskVec. + ymm3 = _mm256_maskload_pd(tA + 8, maskVec); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10); + ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14); - } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + + ymm2 = _mm256_loadu_pd(tC + 4); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(tC + 8, maskVec); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); + + // multiply C by beta and accumulate. + double *ttC = tC +ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(ttC + 8, maskVec); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); + + // multiply C by beta and accumulate. + ttC += ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(ttC + 8, maskVec); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + + } + _mm256_storeu_pd(tC, ymm4); + _mm256_storeu_pd(tC + 4, ymm5); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 8, maskVec, ymm6); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. 
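/*
 * A minimal sketch of the C-update epilogue used above, assuming maskVec
 * covers the valid lanes of the trailing vector: the accumulators are
 * scaled by alpha; when beta is non-zero the existing C values are folded
 * in with an FMA (masked load for the partial vector); the results are
 * then written back with a masked store so only valid rows of C change.
 * The helper name is illustrative, not BLIS API.
 */
#include <immintrin.h>

static inline void update_c_tail_sketch( double *c, int is_beta_non_zero,
                                         const double *alpha, const double *beta,
                                         __m256d acc0, __m256d acc1, __m256d acc_tail,
                                         __m256i maskVec )
{
    __m256d va = _mm256_broadcast_sd( alpha );
    __m256d vb = _mm256_broadcast_sd( beta );

    acc0     = _mm256_mul_pd( acc0,     va );
    acc1     = _mm256_mul_pd( acc1,     va );
    acc_tail = _mm256_mul_pd( acc_tail, va );

    if ( is_beta_non_zero )
    {
        acc0     = _mm256_fmadd_pd( _mm256_loadu_pd( c + 0 ),             vb, acc0 );
        acc1     = _mm256_fmadd_pd( _mm256_loadu_pd( c + 4 ),             vb, acc1 );
        acc_tail = _mm256_fmadd_pd( _mm256_maskload_pd( c + 8, maskVec ), vb, acc_tail );
    }

    _mm256_storeu_pd( c + 0, acc0 );
    _mm256_storeu_pd( c + 4, acc1 );
    _mm256_maskstore_pd( c + 8, maskVec, acc_tail );
}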
- ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; + tC += ldc; - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6); + _mm256_storeu_pd(tC, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 8, maskVec, ymm10); - ymm3 = _mm256_loadu_pd(tA + 4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + tC += ldc; - tA += lda; + _mm256_storeu_pd(tC, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 8, maskVec, ymm14); } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm6 = _mm256_mul_pd(ymm6, ymm0); - ymm7 = _mm256_mul_pd(ymm7, ymm0); - - if(is_beta_non_zero) + n_remainder = N - col_idx; + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - - double* ttC = tC + ldc; - - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(ttC); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_pd(ttC + 4); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); - } - _mm256_storeu_pd(tC, ymm4); - _mm256_storeu_pd(tC + 4, ymm5); - - tC += ldc; - _mm256_storeu_pd(tC, ymm6); - _mm256_storeu_pd(tC + 4, ymm7); + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; - col_idx += 2; + // clear scratch registers. + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + tB += tb_inc_row; - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8); + ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; + ymm3 = _mm256_loadu_pd(tA + 4); + ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9); + ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13); - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + // Masked load the relevant remainder elements only + // using maskVec. 
+ ymm3 = _mm256_maskload_pd(tA + 8, maskVec); + ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); + ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14); - ymm3 = _mm256_loadu_pd(tA + 4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + tA += lda; - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); + //multiply A*B by alpha. + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); - if(is_beta_non_zero) - { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - } - _mm256_storeu_pd(tC, ymm4); - _mm256_storeu_pd(tC + 4, ymm5); - } + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC + 0); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - row_idx += 8; - } + ymm2 = _mm256_loadu_pd(tC + 4); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(tC + 8, maskVec); + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); - if (m_remainder >= 4) - { - m_remainder -= 4; + double *ttC = tC + ldc; - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); + ymm2 = _mm256_loadu_pd(ttC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(ttC + 8, maskVec); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; + } + _mm256_storeu_pd(tC + 0, ymm8); + _mm256_storeu_pd(tC + 4, ymm9); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 8, maskVec, ymm10); - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6); + tC += ldc; - tA += lda; + _mm256_storeu_pd(tC, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 8, maskVec, ymm14); + + col_idx += 2; } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); + // if the N is not multiple of 3. + // handling edge case. 
+ if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm6 = _mm256_mul_pd(ymm6, ymm0); + // clear scratch registers. + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); - if(is_beta_non_zero) - { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + tB += tb_inc_row; - double* ttC = tC + ldc; + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(ttC); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + ymm3 = _mm256_loadu_pd(tA + 4); + ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); - ttC += ldc; + // Masked load the relevant remainder elements only + // using maskVec. + ymm3 = _mm256_maskload_pd(tA + 8, maskVec); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(ttC); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - } - _mm256_storeu_pd(tC, ymm4); + tA += lda; - tC += ldc; - _mm256_storeu_pd(tC, ymm5); + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); - tC += ldc; - _mm256_storeu_pd(tC, ymm6); - } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; + //multiply A*B by alpha. + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC + 0); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + ymm2 = _mm256_loadu_pd(tC + 4); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(tC + 8, maskVec); + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - tA += lda; + } + _mm256_storeu_pd(tC + 0, ymm12); + _mm256_storeu_pd(tC + 4, ymm13); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 8, maskVec, ymm14); } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); + } + else if(m_remainder > 4) + { + // Handles edge cases where remainder elements are between 5-8(5, 6, 7, 8). + // Here m_rem gives index in mask_ptr that points which mask to be used based + // on remainder elements which could be 1, 2, 3 or 4 here. 
+ m_rem = (m_remainder % 4); + __m256i maskVec = _mm256_loadu_si256( (__m256i *)mask_ptr[m_rem]); + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); - if(is_beta_non_zero) - { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); + tB += tb_inc_row; - double* ttC = tC + ldc; + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); + ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(ttC); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - } - _mm256_storeu_pd(tC, ymm4); + // Masked load the relevant remainder elements only + // using maskVec. + ymm3 = _mm256_maskload_pd(tA + 4, maskVec); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9); + ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13); - tC += ldc; - _mm256_storeu_pd(tC, ymm5); + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); - col_idx += 2; + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(tC + 4, maskVec); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + + // multiply C by beta and accumulate. + double *ttC = tC +ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(ttC + 4, maskVec); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); + + // multiply C by beta and accumulate. + ttC += ldc; + ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(ttC + 4, maskVec); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + } + _mm256_storeu_pd(tC, ymm4); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 4, maskVec, ymm5); - } - // if the N is not multiple of 3. 
- // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; + tC += ldc; - ymm4 = _mm256_setzero_pd(); + _mm256_storeu_pd(tC, ymm8); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 4, maskVec, ymm9); - for (k = 0; k < K; ++k) + tC += ldc; + + _mm256_storeu_pd(tC, ymm12); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 4, maskVec, ymm13); + } + n_remainder = N - col_idx; + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + // clear scratch registers. + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + tB += tb_inc_row; - ymm4 = _mm256_mul_pd(ymm4, ymm0); + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8); + ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12); + + // Masked load the relevant remainder elements only + // using maskVec. + ymm3 = _mm256_maskload_pd(tA + 4, maskVec); + ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9); + ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13); + tA += lda; + + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); - if(is_beta_non_zero) - { - // multiply C by beta and accumulate. - ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + //multiply A*B by alpha. + ymm8 = _mm256_mul_pd(ymm8, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); - } - _mm256_storeu_pd(tC, ymm4); - } + if(is_beta_non_zero) + { + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC + 0); + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(tC + 4, maskVec); + ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - row_idx += 4; - } - // M is not a multiple of 32. - // The handling of edge case where the remainder - // dimension is less than 8. The padding takes place - // to handle this case. - if ((m_remainder) && (lda > 3)) - { - double f_temp[8] = {0.0}; + double *ttC = tC + ldc; - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; + // multiply C by beta and accumulate. 
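/*
 * The removed path here stages the ragged edge through a zero-padded stack
 * array (f_temp) around every full-width load or store; the masked
 * intrinsics that replace it act on the matrix memory directly.  A minimal
 * sketch of the two forms, for a tail of n_tail (< 4) doubles; the helper
 * names are illustrative, not BLIS API.
 */
#include <immintrin.h>

/* Old style: copy the tail into a padded buffer, then do a full load. */
static inline __m256d load_tail_staged_sketch( const double *p, int n_tail )
{
    double f_temp[4] = { 0.0, 0.0, 0.0, 0.0 };
    for ( int i = 0; i < n_tail; i++ ) f_temp[i] = p[i];
    return _mm256_loadu_pd( f_temp );
}

/* New style: a single masked load, no copy and no read past the edge. */
static inline __m256d load_tail_masked_sketch( const double *p, __m256i maskVec )
{
    return _mm256_maskload_pd( p, maskVec );
}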
+ ymm2 = _mm256_loadu_pd(ttC); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(ttC + 4, maskVec); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - // clear scratch registers. - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + } + _mm256_storeu_pd(tC + 0, ymm8); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 4, maskVec, ymm9); - for (k = 0; k < (K - 1); ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; + tC += ldc; - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); + _mm256_storeu_pd(tC, ymm12); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 4, maskVec, ymm13); - tA += lda; + col_idx += 2; } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; - - for (int i = 0; i < m_remainder; i++) + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) { - f_temp[i] = tA[i]; - } - ymm3 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); - - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; - //multiply A*B by alpha. - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm7 = _mm256_mul_pd(ymm7, ymm0); - ymm9 = _mm256_mul_pd(ymm9, ymm0); + // clear scratch registers. + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); - if(is_beta_non_zero) - { - for (int i = 0; i < m_remainder; i++) + for (k = 0; k < K; ++k) { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + tB += tb_inc_row; + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm3 = _mm256_loadu_pd(tA); + ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12); - double* ttC = tC + ldc; + // Masked load the relevant remainder elements only + // using maskVec. + ymm3 = _mm256_maskload_pd(tA + 4, maskVec); + ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = ttC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + tA += lda; - ttC += ldc; - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = ttC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - } - _mm256_storeu_pd(f_temp, ymm5); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; } + // alpha, beta multiplication. 
+ ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); - tC += ldc; - _mm256_storeu_pd(f_temp, ymm7); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } + //multiply A*B by alpha. + ymm12 = _mm256_mul_pd(ymm12, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); - tC += ldc; - _mm256_storeu_pd(f_temp, ymm9); - for (int i = 0; i < m_remainder; i++) + if(is_beta_non_zero) { - tC[i] = f_temp[i]; - } + // multiply C by beta and accumulate. + ymm2 = _mm256_loadu_pd(tC + 0); + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(tC + 4, maskVec); + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + } + _mm256_storeu_pd(tC + 0, ymm12); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC + 4, maskVec, ymm13); + } } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) + else { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for (k = 0; k < (K - 1); ++k) + __m256i maskVec = _mm256_loadu_si256( (__m256i *)mask_ptr[m_remainder]); + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; - ymm3 = _mm256_loadu_pd(tA); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); - tA += lda; - } + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); + tB += tb_inc_row; - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; + //broadcasted matrix B elements are multiplied + //with matrix A columns. - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tA[i]; - } - ymm3 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + // Masked load the relevant remainder elements only + // using maskVec. + ymm3 = _mm256_maskload_pd(tA, maskVec); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6); - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm7 = _mm256_mul_pd(ymm7, ymm0); + //multiply A*B by alpha. 
+ ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); + ymm6 = _mm256_mul_pd(ymm6, ymm0); - if(is_beta_non_zero) - { - for (int i = 0; i < m_remainder; i++) + if(is_beta_non_zero) { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(tC, maskVec); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - double* ttC = tC + ldc; + double* ttC = tC + ldc; - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = ttC[i]; + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(ttC, maskVec); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + + ttC += ldc; + + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(ttC, maskVec); + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); } - ymm2 = _mm256_loadu_pd(f_temp); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7); + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC, maskVec, ymm4); + + tC += ldc; + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC, maskVec, ymm5); + tC += ldc; + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC, maskVec, ymm6); } - _mm256_storeu_pd(f_temp, ymm5); - for (int i = 0; i < m_remainder; i++) + n_remainder = N - col_idx; + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) { - tC[i] = f_temp[i]; - } + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; - tC += ldc; - _mm256_storeu_pd(f_temp, ymm7); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); + ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); + tB += tb_inc_row; - for (k = 0; k < (K - 1); ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; + //broadcasted matrix B elements are multiplied + //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + // Masked load the relevant remainder elements only + // using maskVec. + ymm3 = _mm256_maskload_pd(tA, maskVec); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - tA += lda; - } + tA += lda; + } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; + //multiply A*B by alpha. + ymm4 = _mm256_mul_pd(ymm4, ymm0); + ymm5 = _mm256_mul_pd(ymm5, ymm0); - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tA[i]; - } - ymm3 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + if(is_beta_non_zero) + { + // Masked load the relevant remainder elements of C matrix + // Scale by beta. 
+ ymm2 = _mm256_maskload_pd(tC, maskVec); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm0 = _mm256_broadcast_sd(alpha_cast); - ymm1 = _mm256_broadcast_sd(beta_cast); + double* ttC = tC + ldc; - // multiply C by beta and accumulate. - ymm5 = _mm256_mul_pd(ymm5, ymm0); + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(ttC, maskVec); + ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); + } + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC, maskVec, ymm4); - if(is_beta_non_zero) - { + tC += ldc; + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC, maskVec, ymm5); - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - } - _mm256_storeu_pd(f_temp, ymm5); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - } - m_remainder = 0; - } + col_idx += 2; - if (m_remainder) - { - double result; - for (; row_idx < M; row_idx += 1) - { - for (col_idx = 0; col_idx < N; col_idx += 1) + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) { //pointer math to point to proper memory tC = C + ldc * col_idx + row_idx; tB = B + tb_inc_col * col_idx; tA = A + row_idx; - result = 0; + ymm4 = _mm256_setzero_pd(); + for (k = 0; k < K; ++k) { - result += (*tA) * (*tB); - tA += lda; + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); tB += tb_inc_row; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + + // Masked load the relevant remainder elements only + // using maskVec. + ymm3 = _mm256_maskload_pd(tA, maskVec); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tA += lda; } + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_sd(alpha_cast); + ymm1 = _mm256_broadcast_sd(beta_cast); + + ymm4 = _mm256_mul_pd(ymm4, ymm0); - result *= (*alpha_cast); if(is_beta_non_zero) - (*tC) = (*tC) * (*beta_cast) + result; - else - (*tC) = result; + { + // Masked load the relevant remainder elements of C matrix + // Scale by beta. + ymm2 = _mm256_maskload_pd(tC, maskVec); + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); + + } + // Masked store the relevant remainder elements of C matrix + _mm256_maskstore_pd(tC, maskVec, ymm4); } } } - // Return the buffer to pool + // Return the buffer to pool if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_dgemm_small(): releasing mem pool block\n" ); #endif - bli_membrk_release(&rntm, + bli_pba_release(&rntm, &local_mem_buf_A_s); } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); @@ -4370,13 +4550,13 @@ err_t bli_dgemm_small_At bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); // Get the current size of the buffer pool for A block packing. // We will use the same size to avoid pool re-initliazaton siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_pba_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_pba(&rntm))); // // This kernel assumes that "A" will be unpackged if N <= 3. @@ -4401,7 +4581,7 @@ err_t bli_dgemm_small_At printf( "bli_dgemm_small: Requesting mem pool block of size %lu\n", buffer_size); #endif // Get the buffer from the pool. 
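/*
 * The bli_membrk_* -> bli_pba_* changes in this file only rename the
 * packed-buffer allocator API; the acquire/use/release pattern stays the
 * same.  A minimal sketch of that pattern, assuming bli_mem_buffer()
 * returns the raw pointer held by the mem_t handle; the function names
 * ending in _sketch are illustrative, not BLIS API.
 */
#include "blis.h"

static double* acquire_pack_buffer_sketch( rntm_t* rntm, mem_t* mem )
{
    bli_rntm_init_from_global( rntm );
    bli_rntm_set_num_threads_only( 1, rntm );
    bli_pba_rntm_set_pba( rntm );

    // Request a block of the size the pool already uses, so the pool is
    // not re-initialized.
    siz_t buffer_size = bli_pool_block_size(
        bli_pba_pool( bli_packbuf_index( BLIS_BITVAL_BUFFER_FOR_A_BLOCK ),
                      bli_rntm_pba( rntm ) ) );

    bli_pba_acquire_m( rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, mem );

    return ( double* ) bli_mem_buffer( mem );
}

static void release_pack_buffer_sketch( rntm_t* rntm, mem_t* mem )
{
    // Return the buffer to the pool once packing and compute are done.
    if ( bli_mem_is_alloc( mem ) )
        bli_pba_release( rntm, mem );
}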
- bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -5708,7 +5888,7 @@ err_t bli_dgemm_small_At #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_dgemm_small_At(): releasing mem pool block\n" ); #endif - bli_membrk_release(&rntm, + bli_pba_release(&rntm, &local_mem_buf_A_s); } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); @@ -5855,13 +6035,13 @@ err_t bli_zgemm_small bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); // Get the current size of the buffer pool for A block packing. // We will use the same size to avoid pool re-initliazaton siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_pba_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_pba(&rntm))); // // This kernel assumes that "A" will be unpackged if N <= 3. @@ -5883,7 +6063,7 @@ err_t bli_zgemm_small buffer_size); #endif // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -9694,7 +9874,7 @@ err_t bli_zgemm_small #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_zgemm_small(): releasing mem pool block\n" ); #endif - bli_membrk_release(&rntm, + bli_pba_release(&rntm, &local_mem_buf_A_s); } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); @@ -9819,13 +9999,13 @@ err_t bli_zgemm_small_At bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); // Get the current size of the buffer pool for A block packing. // We will use the same size to avoid pool re-initliazaton siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_pba_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_pba(&rntm))); // // This kernel assumes that "A" will be unpackged if N <= 3. @@ -9851,7 +10031,7 @@ err_t bli_zgemm_small_At buffer_size); #endif // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -13396,7 +13576,7 @@ err_t bli_zgemm_small_At #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_zgemm_small_At(): releasing mem pool block\n" ); #endif - bli_membrk_release(&rntm, + bli_pba_release(&rntm, &local_mem_buf_A_s); } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); diff --git a/kernels/zen/3/bli_gemm_tiny.c b/kernels/zen/3/bli_gemm_tiny.c new file mode 100644 index 0000000000..bf6ffa5cc2 --- /dev/null +++ b/kernels/zen/3/bli_gemm_tiny.c @@ -0,0 +1,629 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "immintrin.h" +#include "xmmintrin.h" +#include "blis.h" + +static dgemmsup_ker_ft kern_fp[] = +{ + bli_dgemmsup_rv_haswell_asm_6x8m, + bli_dgemmsup_rd_haswell_asm_6x8m, + bli_dgemmsup_rv_haswell_asm_6x8m, + bli_dgemmsup_rv_haswell_asm_6x8n, + bli_dgemmsup_rv_haswell_asm_6x8m, + bli_dgemmsup_rd_haswell_asm_6x8n, + bli_dgemmsup_rv_haswell_asm_6x8n, + bli_dgemmsup_rv_haswell_asm_6x8n +}; + +#if defined(BLIS_FAMILY_ZEN4) || defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_X86_64) +static err_t bli_dgemm_tiny_24x8_kernel + ( + conj_t conja, + conj_t conjb, + trans_t transa, + trans_t transb, + dim_t m, + dim_t n, + dim_t k, + const double* alpha, + const double* a, const inc_t rs_a0, const inc_t cs_a0, + const double* b, const inc_t rs_b0, const inc_t cs_b0, + const double* beta, + double* c, const inc_t rs_c0, const inc_t cs_c0 + ) +{ + double *a_local = (double *)a; + double *b_local = (double *)b; + double *c_local = (double *)c; + guint_t cs_a = cs_a0; + guint_t rs_a = rs_a0; + guint_t cs_b = cs_b0; + guint_t rs_b = rs_b0; + guint_t cs_c = cs_c0; + guint_t rs_c = rs_c0; + inc_t rs_a_local = rs_a0; + inc_t cs_a_local = cs_a0; + inc_t rs_b_local = rs_b0; + inc_t cs_b_local = cs_b0; + inc_t rs_c_local = rs_c0; + inc_t cs_c_local = cs_c0; + + gint_t M = m; + gint_t N = n; + gint_t K = k; + + inc_t storage = 0; + if(transb == BLIS_NO_TRANSPOSE || transb == BLIS_CONJ_NO_TRANSPOSE) + { + storage = 1 * (rs_b == 1); //1st bit + } + else if(transb == BLIS_TRANSPOSE || transb == BLIS_CONJ_TRANSPOSE) + { + storage = 1 * (cs_b == 1); //1st bit + rs_b = cs_b0; + cs_b = rs_b0; + } + + if(transa == BLIS_NO_TRANSPOSE || transa == BLIS_CONJ_NO_TRANSPOSE) + { + storage |= ((1 * (rs_a == 1)) << 1); //2nd bit + } + else if(transa == BLIS_TRANSPOSE || transa == BLIS_CONJ_TRANSPOSE) + { + storage |= ((1 * (cs_a == 1)) << 1); //2nd bit + rs_a = cs_a0; + cs_a = rs_a0; + } + + storage |= ((1 * (rs_c == 1)) << 2); //3rd bit + + stor3_t stor_id = (stor3_t) storage; + + //Early return, since we do not support dot product gemm kernels. 
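/*
 * A minimal sketch of the storage encoding built above: after any transpose
 * is folded into swapped strides, bit 0 records whether B has unit row
 * stride, bit 1 whether A does, bit 2 whether C does, giving an index in
 * 0..7.  The cast to stor3_t relies on the BLIS_RRR..BLIS_CCC enumeration
 * following that same ordering (assumed here, as implied by the cast); the
 * 6x8 path later uses the index to pick a sup kernel out of kern_fp.  The
 * helper name is illustrative, not BLIS API.
 */
static inline int storage_index_sketch( int b_unit_rs, int a_unit_rs, int c_unit_rs )
{
    return ( b_unit_rs ? 1 : 0 )           /* 1st bit */
         | ( ( a_unit_rs ? 1 : 0 ) << 1 )  /* 2nd bit */
         | ( ( c_unit_rs ? 1 : 0 ) << 2 ); /* 3rd bit */
}

/* e.g.  stor3_t stor_id = (stor3_t) storage_index_sketch( rs_b == 1,
                                                           rs_a == 1,
                                                           rs_c == 1 ); */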
+ if(stor_id == BLIS_CRC || stor_id == BLIS_RRC) + { + return BLIS_FAILURE; + } + + const bool is_rrr_rrc_rcr_crr = ( + stor_id == BLIS_RRR || + stor_id == BLIS_RRC || + stor_id == BLIS_RCR || + stor_id == BLIS_CRR + ); + + const bool is_rcc_crc_ccr_ccc = !is_rrr_rrc_rcr_crr; + const bool row_pref = false; + const bool col_pref = !row_pref; + + const bool is_primary = ( row_pref && is_rrr_rrc_rcr_crr ) || + ( col_pref && is_rcc_crc_ccr_ccc ); + + /** + * Based on matrix storage scheme and kernel preference, + * decision is made here that whether it is primary storage + * scheme or not. + */ + if ( !is_primary ) + { + /** + * For non-primary storage scheme, we configure parameters, + * for kernel re-use. + */ + a_local = (double *)b; + b_local = (double *)a; + rs_a_local = cs_b; + cs_a_local = rs_b; + rs_b_local = cs_a; + cs_b_local = rs_a; + rs_c_local = cs_c0; + cs_c_local = rs_c0; + M = n; + N = m; + + rs_a = rs_a_local; + cs_a = cs_a_local; + cs_c = cs_c_local; + rs_b = rs_b_local; + cs_b = cs_b_local; + rs_c = rs_c_local; + } + + double *A = a_local; + double *B = b_local; + double *C = c_local; + double *alpha_cast; + double beta_cast = *beta; + double one_local = 1.0; + alpha_cast = (double *)alpha; + /** + * Set blocking and micro tile parameters before computing + */ + const dim_t MC = 144; + const dim_t KC = 480; + const dim_t MR_ = 24; + const dim_t NR_ = 8; + /** + * MC must be in multiple of MR_. + * if not return early. + */ + if( MC % MR_ != 0 ) + { + return BLIS_FAILURE; + } + dim_t n_rem = N % NR_; + dim_t m_part_rem = M % MC; + dim_t k_rem = K % KC; + dim_t n_cur = 0; + dim_t m_cur = 0; + dim_t k_cur = 0; + dim_t k_iter = 0; + auxinfo_t aux; + inc_t ps_a_use = (MR_ * rs_a); + bli_auxinfo_set_ps_a( ps_a_use, &aux ); + + /** + * JC Loop is eliminated as it iterates only once, So computation + * can start from K loop. + * Here K loop is divided into two parts to avoid repetitive check for Beta. + * For first iteration, it will use Beta to scale C matrix. + * Subsequent iterations will scale C matrix by 1. + */ + k_iter = 0; //1st k loop, scale C matrix by beta + k_cur = (KC <= K ? KC : k_rem); + for ( dim_t m_iter = 0; m_iter < M; m_iter += MC) + { + m_cur = (MC <= (M - m_iter) ? MC : m_part_rem); + for ( dim_t jr_iter = 0; jr_iter < N; jr_iter += NR_ ) + { + n_cur = (NR_ <= (N - jr_iter) ? NR_ : n_rem); + bli_dgemmsup_rv_zen4_asm_24x8m(conja, + conjb, + m_cur, + n_cur, + k_cur, + alpha_cast, + (A + (m_iter * rs_a) + (k_iter * cs_a)), /*A matrix offset*/ + rs_a, + cs_a, + (B + (jr_iter * cs_b) + (k_iter * rs_b)), /*B matrix offset*/ + rs_b, + cs_b, + &beta_cast, + (C + jr_iter * cs_c + m_iter * rs_c), /*C matrix offset*/ + rs_c, + cs_c, + &aux, + NULL); + } + } + // k_iter = KC loop where C matrix is scaled by one. Beta is one. + for (k_iter = KC ; k_iter < K; k_iter += KC ) + { + k_cur = (KC <= (K - k_iter) ? KC : k_rem); + for ( dim_t m_iter = 0; m_iter < M; m_iter += MC) + { + m_cur = (MC <= (M - m_iter) ? MC : m_part_rem); + for ( dim_t jr_iter = 0; jr_iter < N; jr_iter += NR_ ) + { + n_cur = (NR_ <= (N - jr_iter) ? 
NR_ : n_rem); + bli_dgemmsup_rv_zen4_asm_24x8m(conja, + conjb, + m_cur, + n_cur, + k_cur, + alpha_cast, + (A + (m_iter * rs_a) + (k_iter * cs_a)), /*A matrix offset*/ + rs_a, + cs_a, + (B + (jr_iter * cs_b) + (k_iter * rs_b)), /*B matrix offset*/ + rs_b, + cs_b, + &one_local, + (C + jr_iter * cs_c + m_iter * rs_c), /*C matrix offset*/ + rs_c, + cs_c, + &aux, + NULL); + } + } + } + + return BLIS_SUCCESS; +} +#endif + +static err_t bli_dgemm_tiny_6x8_kernel + ( + conj_t conja, + conj_t conjb, + trans_t transa, + trans_t transb, + dim_t m, + dim_t n, + dim_t k, + const double* alpha, + const double* a, const inc_t rs_a0, const inc_t cs_a0, + const double* b, const inc_t rs_b0, const inc_t cs_b0, + const double* beta, + double* c, const inc_t rs_c0, const inc_t cs_c0 + ) +{ + double *a_local = (double *)a; + double *b_local = (double *)b; + double *c_local = (double *)c; + guint_t cs_a = cs_a0; + guint_t rs_a = rs_a0; + guint_t cs_b = cs_b0; + guint_t rs_b = rs_b0; + guint_t cs_c = cs_c0; + guint_t rs_c = rs_c0; + inc_t rs_a_local = rs_a0; + inc_t cs_a_local = cs_a0; + inc_t rs_b_local = rs_b0; + inc_t cs_b_local = cs_b0; + inc_t rs_c_local = rs_c0; + inc_t cs_c_local = cs_c0; + + gint_t M = m; + gint_t N = n; + gint_t K = k; + + inc_t storage = 0; + if(transb == BLIS_NO_TRANSPOSE || transb == BLIS_CONJ_NO_TRANSPOSE) + { + storage = 1 * (rs_b == 1); //1st bit + } + else if(transb == BLIS_TRANSPOSE || transb == BLIS_CONJ_TRANSPOSE) + { + storage = 1 * (cs_b == 1); //1st bit + rs_b = cs_b0; + cs_b = rs_b0; + } + + if(transa == BLIS_NO_TRANSPOSE || transa == BLIS_CONJ_NO_TRANSPOSE) + { + storage |= ((1 * (rs_a == 1)) << 1); //2nd bit + } + else if(transa == BLIS_TRANSPOSE || transa == BLIS_CONJ_TRANSPOSE) + { + storage |= ((1 * (cs_a == 1)) << 1); //2nd bit + rs_a = cs_a0; + cs_a = rs_a0; + } + + storage |= ((1 * (rs_c == 1)) << 2); //3rd bit + + /** + * typecast storage into stor_idd, + * stores default storage scheme before we optimze + * for respective gemm kernel. */ + stor3_t stor_idd = (stor3_t) storage; + stor3_t stor_id = 0; + + stor_id = stor_idd; + + const bool is_rrr_rrc_rcr_crr = ( + stor_idd == BLIS_RRR || + stor_idd == BLIS_RRC || + stor_idd == BLIS_RCR || + stor_idd == BLIS_CRR + ); + + const bool is_rcc_crc_ccr_ccc = !is_rrr_rrc_rcr_crr; + const bool row_pref = true; + const bool col_pref = !row_pref; + + /** + * Based on matrix storage scheme and kernel preference, + * decision is made here that whether it is primary storage + * scheme or not. + */ + const bool is_primary = ( row_pref && is_rrr_rrc_rcr_crr ) || + ( col_pref && is_rcc_crc_ccr_ccc ); + + /** + * For non-primary storage scheme, we configure parameters, + * for kernel re-use. + */ + if ( !is_primary ) + { + a_local = (double *)b; + b_local = (double *)a; + rs_a_local = cs_b; + cs_a_local = rs_b; + rs_b_local = cs_a; + cs_b_local = rs_a; + rs_c_local = cs_c0; + cs_c_local = rs_c0; + M = n; + N = m; + + stor_id = bli_stor3_trans(stor_idd); + + rs_a = rs_a_local; + cs_a = cs_a_local; + cs_c = cs_c_local; + rs_b = rs_b_local; + cs_b = cs_b_local; + rs_c = rs_c_local; + } + + double *A = a_local; + double *B = b_local; + double *C = c_local; + double *alpha_cast; + double beta_cast = *beta; + double one_local = 1.0; + + alpha_cast = (double *)alpha; + /** + * Set blocking and micro tile parameters before computing + */ + const dim_t MC = 72; + const dim_t KC = 256; + const dim_t MR_ = 6; + const dim_t NR_ = 8; + + + /** + * MC must be in multiple of MR_. + * if not return early. 
+ */ + if( MC % MR_ != 0 ) + { + return BLIS_FAILURE; + } + dim_t n_rem = N % NR_; + dim_t m_part_rem = M % MC; + dim_t k_rem = K % KC; + dim_t n_cur = 0; + dim_t m_cur = 0; + dim_t k_cur = 0; + dim_t k_iter = 0; + + auxinfo_t aux; + inc_t ps_a_use = (MR_ * rs_a); + bli_auxinfo_set_ps_a( ps_a_use, &aux ); + dgemmsup_ker_ft kern_ptr = kern_fp[stor_id]; + + /** + * JC Loop is eliminated as it iterates only once, So computation + * can start from K loop. + * Here K loop is divided into parts to avoid repetitive check for Beta. + * For first iteration, it will use Beta to scale C matrix. + * Subsequent iterations will scale C matrix by 1. + */ + k_iter = 0; //1st k loop, scale C matrix by beta + k_cur = (KC <= K ? KC : k_rem); + for ( dim_t m_iter = 0; m_iter < M; m_iter += MC) + { + m_cur = (MC <= (M - m_iter) ? MC : m_part_rem); + for ( dim_t jr_iter = 0; jr_iter < N; jr_iter += NR_ ) + { + n_cur = (NR_ <= (N - jr_iter) ? NR_ : n_rem); + kern_ptr(conja, + conjb, + m_cur, + n_cur, + k_cur, + alpha_cast, + (A + (m_iter * rs_a) + (k_iter * cs_a)), /*A matrix offset*/ + rs_a, + cs_a, + (B + (jr_iter * cs_b) + (k_iter * rs_b)), /*B matrix offset*/ + rs_b, + cs_b, + &beta_cast, + (C + (jr_iter * cs_c) + (m_iter * rs_c)), /*C matrix offset*/ + rs_c, + cs_c, + &aux, + NULL); + } + } + // k_iter = KC loop where C matrix is scaled by one. Beta is one. + for (k_iter = KC; k_iter < K; k_iter += KC ) + { + k_cur = (KC <= (K - k_iter) ? KC : k_rem); + for ( dim_t m_iter = 0; m_iter < M; m_iter += MC) + { + m_cur = (MC <= (M - m_iter) ? MC : m_part_rem); + for ( dim_t jr_iter = 0; jr_iter < N; jr_iter += NR_ ) + { + n_cur = (NR_ <= (N - jr_iter) ? NR_ : n_rem); + kern_ptr(conja, + conjb, + m_cur, + n_cur, + k_cur, + alpha_cast, + (A + (m_iter * rs_a) + (k_iter * cs_a)), /*A matrix offset*/ + rs_a, + cs_a, + (B + (jr_iter * cs_b) + (k_iter * rs_b)), /*B matrix offset*/ + rs_b, + cs_b, + &one_local, + (C + (jr_iter * cs_c) + (m_iter * rs_c)), /*C matrix offset*/ + rs_c, + cs_c, + &aux, + NULL); + } + } + } + + return BLIS_SUCCESS; +} + +static arch_t get_arch_id(void) +{ + static arch_t arch_id = BLIS_NUM_ARCHS + 1; + if(arch_id == BLIS_NUM_ARCHS + 1) + { + arch_id = bli_cpuid_query_id(); + } + + return arch_id; +} + +err_t bli_dgemm_tiny +( + trans_t transa, + trans_t transb, + dim_t m, + dim_t n, + dim_t k, + const double* alpha, + const double* a, const inc_t rs_a0, const inc_t cs_a0, + const double* b, const inc_t rs_b0, const inc_t cs_b0, + const double* beta, + double* c, const inc_t rs_c0, const inc_t cs_c0 +) +{ + arch_t arch_id = get_arch_id(); + //for the below tiny sizes of matrix, we force it to be ST compute. 
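/*
 * A minimal scalar sketch of the loop split used by both tiny kernels
 * above: the first KC-deep panel is issued with the caller's beta, and
 * every later panel accumulates on top of it with beta = 1.0, so beta is
 * applied to C exactly once and the hot loops carry no "first panel?"
 * check.  Column-major strides are assumed; the names are illustrative,
 * not BLIS API.
 */
static void panel_update_sketch( int M, int N, int kc, int k0, double alpha,
                                 const double* A, int lda,
                                 const double* B, int ldb,
                                 double beta, double* C, int ldc )
{
    // C(0:M,0:N) = beta*C + alpha * A(0:M, k0:k0+kc) * B(k0:k0+kc, 0:N)
    for ( int j = 0; j < N; j++ )
        for ( int i = 0; i < M; i++ )
        {
            double acc = 0.0;
            for ( int k = 0; k < kc; k++ )
                acc += A[ i + ( k0 + k ) * lda ] * B[ ( k0 + k ) + j * ldb ];
            C[ i + j * ldc ] = beta * C[ i + j * ldc ] + alpha * acc;
        }
}

static void gemm_beta_once_sketch( int M, int N, int K, int KC, double alpha,
                                   const double* A, int lda,
                                   const double* B, int ldb,
                                   double beta, double* C, int ldc )
{
    int kc = ( KC <= K ) ? KC : K;
    // First panel: scale C by the user's beta.
    panel_update_sketch( M, N, kc, 0, alpha, A, lda, B, ldb, beta, C, ldc );
    // Remaining panels: accumulate only (beta = 1.0).
    for ( int k0 = KC; k0 < K; k0 += KC )
    {
        kc = ( KC <= K - k0 ) ? KC : ( K - k0 );
        panel_update_sketch( M, N, kc, k0, alpha, A, lda, B, ldb, 1.0, C, ldc );
    }
}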
+ if( + m <= 24 && n <= 24 && k <= 20 && + (BLIS_ARCH_ZEN == arch_id || + BLIS_ARCH_ZEN2 == arch_id || + BLIS_ARCH_ZEN3 == arch_id || + BLIS_ARCH_ZEN4 == arch_id) + ) + { + bool ret = bli_aocl_enable_instruction_query(); + if((ret == FALSE) || + (arch_id != BLIS_ARCH_ZEN4) + ) + { + return bli_dgemm_tiny_6x8_kernel + ( + 1 * (transa == BLIS_CONJ_NO_TRANSPOSE), + 1 * (transb == BLIS_CONJ_NO_TRANSPOSE), + transa, + transb, + m, + n, + k, + alpha, + a, rs_a0, cs_a0, + b, rs_b0, cs_b0, + beta, + c, rs_c0, cs_c0 + ); + } +#if defined(BLIS_FAMILY_ZEN4) || defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_X86_64) + else if(BLIS_ARCH_ZEN4 == arch_id) + { + return bli_dgemm_tiny_24x8_kernel + ( + 1 * (transa == BLIS_CONJ_NO_TRANSPOSE), + 1 * (transb == BLIS_CONJ_NO_TRANSPOSE), + transa, + transb, + m, + n, + k, + alpha, + a, rs_a0, cs_a0, + b, rs_b0, cs_b0, + beta, + c, rs_c0, cs_c0 + ); + } +#endif + } + if(FALSE == bli_thread_get_is_parallel()) + { + if( + BLIS_ARCH_ZEN == arch_id || + BLIS_ARCH_ZEN2 == arch_id || + BLIS_ARCH_ZEN3 == arch_id + ) + { + if( ( (m <= 8) || ( (m <= 1000) && (n <= 24) && (k >= 4) ) ) && (k <= 1500) ) + { + return bli_dgemm_tiny_6x8_kernel + ( + 1 * (transa == BLIS_CONJ_NO_TRANSPOSE), + 1 * (transb == BLIS_CONJ_NO_TRANSPOSE), + transa, + transb, + m, + n, + k, + alpha, + a, rs_a0, cs_a0, + b, rs_b0, cs_b0, + beta, + c, rs_c0, cs_c0 + ); + } + } +#if defined(BLIS_FAMILY_ZEN4) || defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_X86_64) + else if(BLIS_ARCH_ZEN4 == arch_id) + { + if(((m == n) && (m < 400) && (k < 1000)) || + ( (m != n) && (( ((m + n -k) < 1500) && + ((m + k-n) < 1500) && ((n + k-m) < 1500) ) || + ((n <= 100) && (k <=100))))) + { + return bli_dgemm_tiny_24x8_kernel + ( + 1 * (transa == BLIS_CONJ_NO_TRANSPOSE), + 1 * (transb == BLIS_CONJ_NO_TRANSPOSE), + transa, + transb, + m, + n, + k, + alpha, + a, rs_a0, cs_a0, + b, rs_b0, cs_b0, + beta, + c, rs_c0, cs_c0 + ); + } + } +#endif + else + { + ;//Return failure + } + } + + return BLIS_FAILURE; +} diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index d08dbb2279..0fd06c86f5 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -8308,12 +8308,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if( (d_nr * n * sizeof(double)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -8321,7 +8321,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB if (required_packing_A == 1) { // Get the buffer from the pool. 
- bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -10628,7 +10628,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm, + bli_pba_release(&rntm, &local_mem_buf_A_s); } @@ -10718,12 +10718,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if( (d_nr * n * sizeof(double)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -10731,7 +10731,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB if (required_packing_A == 1) { // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -12984,7 +12984,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm, + bli_pba_release(&rntm, &local_mem_buf_A_s); } return BLIS_SUCCESS; @@ -13064,12 +13064,12 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if((d_mr * m * sizeof(double)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -13077,7 +13077,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB if(required_packing_A == 1) { // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -14961,7 +14961,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm,&local_mem_buf_A_s); + bli_pba_release(&rntm,&local_mem_buf_A_s); } return BLIS_SUCCESS; } @@ -15076,12 +15076,12 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if ( (d_mr * m * sizeof(double)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -15089,7 +15089,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB if (required_packing_A == 1) { // Get the buffer from the pool. 
- bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -17062,7 +17062,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm, &local_mem_buf_A_s); + bli_pba_release(&rntm, &local_mem_buf_A_s); } return BLIS_SUCCESS; } @@ -17664,12 +17664,12 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if( (d_nr * n * sizeof(float)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -17677,7 +17677,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB if (required_packing_A == 1) { // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -21242,7 +21242,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm, + bli_pba_release(&rntm, &local_mem_buf_A_s); } return BLIS_SUCCESS; @@ -21333,12 +21333,12 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if( (d_nr * n * sizeof(float)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -21346,7 +21346,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB if (required_packing_A == 1) { // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -25083,7 +25083,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm, + bli_pba_release(&rntm, &local_mem_buf_A_s); } @@ -25203,12 +25203,12 @@ BLIS_INLINE err_t bli_strsm_small_AutXB_AlXB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if ( (d_mr * m * sizeof(float)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -25216,7 +25216,7 @@ BLIS_INLINE err_t bli_strsm_small_AutXB_AlXB if (required_packing_A == 1) { // Get the buffer from the pool. 
- bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -29501,7 +29501,7 @@ BLIS_INLINE err_t bli_strsm_small_AutXB_AlXB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm, &local_mem_buf_A_s); + bli_pba_release(&rntm, &local_mem_buf_A_s); } return BLIS_SUCCESS; } @@ -29583,12 +29583,12 @@ BLIS_INLINE err_t bli_strsm_small_AltXB_AuXB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if((d_mr * m * sizeof(float)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -29596,7 +29596,7 @@ BLIS_INLINE err_t bli_strsm_small_AltXB_AuXB if(required_packing_A == 1) { // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -33658,7 +33658,7 @@ BLIS_INLINE err_t bli_strsm_small_AltXB_AuXB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm,&local_mem_buf_A_s); + bli_pba_release(&rntm,&local_mem_buf_A_s); } return BLIS_SUCCESS; } @@ -33735,12 +33735,12 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if ( (d_mr * m * sizeof(dcomplex)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -33748,7 +33748,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB if (required_packing_A == 1) { // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -34900,7 +34900,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm, &local_mem_buf_A_s); + bli_pba_release(&rntm, &local_mem_buf_A_s); } return BLIS_SUCCESS; } @@ -34977,12 +34977,12 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if((d_mr * m * sizeof(dcomplex)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -34990,7 +34990,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB if(required_packing_A == 1) { // Get the buffer from the pool. 
- bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -36131,7 +36131,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm, &local_mem_buf_A_s); + bli_pba_release(&rntm, &local_mem_buf_A_s); } return BLIS_SUCCESS; @@ -36189,12 +36189,12 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if( (d_nr * n * sizeof(dcomplex)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -36202,7 +36202,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB if (required_packing_A == 1) { // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -37597,7 +37597,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm, &local_mem_buf_A_s); + bli_pba_release(&rntm, &local_mem_buf_A_s); } @@ -37656,12 +37656,12 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if( (d_nr * n * sizeof(dcomplex)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -37669,7 +37669,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB if (required_packing_A == 1) { // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -39039,7 +39039,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm, &local_mem_buf_A_s); + bli_pba_release(&rntm, &local_mem_buf_A_s); } @@ -42281,12 +42281,12 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if ( (d_mr * m * sizeof(scomplex)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -42294,7 +42294,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB if (required_packing_A == 1) { // Get the buffer from the pool. 
- bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -44735,7 +44735,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm, &local_mem_buf_A_s); + bli_pba_release(&rntm, &local_mem_buf_A_s); } return BLIS_SUCCESS; @@ -44819,12 +44819,12 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if ( (d_mr * m * sizeof(scomplex)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -44832,7 +44832,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB if (required_packing_A == 1) { // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -47521,7 +47521,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm, &local_mem_buf_A_s); + bli_pba_release(&rntm, &local_mem_buf_A_s); } @@ -47600,12 +47600,12 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if ( (d_mr * m * sizeof(scomplex)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -47613,7 +47613,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB if (required_packing_A == 1) { // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -49144,7 +49144,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm, &local_mem_buf_A_s); + bli_pba_release(&rntm, &local_mem_buf_A_s); } @@ -49223,12 +49223,12 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_pba_rntm_set_pba( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if ( (d_mr * m * sizeof(scomplex)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -49236,7 +49236,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB if (required_packing_A == 1) { // Get the buffer from the pool. 
- bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -50795,7 +50795,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm, &local_mem_buf_A_s); + bli_pba_release(&rntm, &local_mem_buf_A_s); } diff --git a/kernels/zen/3/bli_zgemm_avx2_k1.c b/kernels/zen/3/bli_zgemm_avx2_k1.c index a6a92f9a54..669afcfcfe 100644 --- a/kernels/zen/3/bli_zgemm_avx2_k1.c +++ b/kernels/zen/3/bli_zgemm_avx2_k1.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -33,81 +33,64 @@ */ #include -#include #include "blis.h" #include "immintrin.h" #define Z_MR 4 -#define Z_NR 6 - -// Macros for the main loop for M -#define SCALE_ALPHA_REAL_M_LOOP(rin_0,rin_1,r_bcast,real_val) \ - r_bcast = _mm256_broadcast_sd((double const *)(&real_val)); \ - rin_0 = _mm256_mul_pd(rin_0,r_bcast); \ - rin_1 = _mm256_mul_pd(rin_1,r_bcast); \ - -#define SCALE_ALPHA_IMAG_M_LOOP(rout_0,rout_1,rin_0,rin_1,r_bcast,r_perm,imag_val) \ - r_perm = _mm256_permute4x64_pd(rin_0,0b10110001); \ - r_bcast = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ - r_perm = _mm256_mul_pd(r_bcast, r_perm); \ - r_bcast = _mm256_broadcast_sd((double const *)(&imag_val)); \ - rout_0 = _mm256_fmadd_pd(r_perm,r_bcast,rout_0); \ - r_perm = _mm256_permute4x64_pd(rin_1,0b10110001); \ - r_bcast = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ - r_perm = _mm256_mul_pd(r_bcast, r_perm); \ - r_bcast = _mm256_broadcast_sd((double const *)(&imag_val)); \ - rout_1 = _mm256_fmadd_pd(r_perm,r_bcast,rout_1); \ - -#define NEG_PERM_M_LOOP(r0,r1,r2) \ - r0 = _mm256_permute4x64_pd(r0,0b10110001); \ - r1 = _mm256_permute4x64_pd(r1,0b10110001); \ - r2 = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ - r0 = _mm256_mul_pd(r2, r0); \ - r1 = _mm256_mul_pd(r2, r1); \ - -#define FMA_M_LOOP(rin_0,rin_1,rout_0,rout_1,rbc,loc) \ - rbc = _mm256_broadcast_sd(loc); \ - rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ - rout_1 = _mm256_fmadd_pd(rbc, rin_1, rout_1); \ - -#define SCALE_BETA_REAL_M_LOOP(rin_0,rin_1,rout_0,rout_1,rbc) \ - rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ - rout_1 = _mm256_fmadd_pd(rbc, rin_1, rout_1); \ - -#define SCALE_BETA_IMAG_M_LOOP(rin_0,rin_1,rout_0,rout_1,rbc,rn) \ - NEG_PERM_M_LOOP(rin_0,rin_1,rn); \ - rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ - rout_1 = _mm256_fmadd_pd(rbc, rin_1, rout_1); \ - -// Macros for fringe cases with M -#define SCALE_ALPHA_REAL_M_FRINGE(rin_0,r_bcast,real_val) \ - r_bcast = _mm256_broadcast_sd((double const *)(&real_val)); \ - rin_0 = _mm256_mul_pd(rin_0,r_bcast); \ - -#define SCALE_ALPHA_IMAG_M_FRINGE(rout_0,rin_0,r_bcast,r_perm,imag_val) \ - r_perm = _mm256_permute4x64_pd(rin_0,0b10110001); \ - r_bcast = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ - r_perm = _mm256_mul_pd(r_bcast, r_perm); \ - r_bcast = _mm256_broadcast_sd((double const *)(&imag_val)); \ - rout_0 = _mm256_fmadd_pd(r_perm,r_bcast,rout_0); \ - -#define NEG_PERM_M_FRINGE(r0,r2) \ - r0 = _mm256_permute4x64_pd(r0,0b10110001); \ - r2 = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ - r0 = _mm256_mul_pd(r2, r0); \ - -#define FMA_M_FRINGE(r_in,r_out,r_bc,loc) \ - r_bc = _mm256_broadcast_sd(loc); \ - r_out = 
_mm256_fmadd_pd(r_bc, r_in, r_out); \
-
-#define SCALE_BETA_REAL_M_FRINGE(rin_0,rout_0,rbc) \
-    rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \
-
-#define SCALE_BETA_IMAG_M_FRINGE(rin_0,rout_0,rbc,rn) \
-    NEG_PERM_M_FRINGE(rin_0,rn); \
-    rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \
-
-void bli_zgemm_4x6_avx2_k1_nn
+#define Z_NR 4
+
+// Macro to be used for beta scaling with 2 loads from C (main loop of m)
+#define BETA_SCALING_C_MAIN(reg_0, reg_1, loc) \
+\
+    /* Here, a_vec_0 and a_vec_1 are used to load columns of
+       length Z_MR from C, with bdcst_0 and bdcst_1 already
+       having the real and imaginary parts of beta broadcast
+       onto them. reg_0 and reg_1 are the intermediate registers
+       containing the result of alpha*A*B on them. The beta scaling
+       and final accumulation are done on these registers for
+       storing the corresponding column of C. */ \
+\
+    a_vec_0 = _mm256_loadu_pd((double const*)(loc)); \
+    a_vec_1 = _mm256_loadu_pd((double const*)(loc + 2)); \
+\
+    reg_0 = _mm256_fmadd_pd(a_vec_0, bdcst_0, reg_0); \
+    reg_1 = _mm256_fmadd_pd(a_vec_1, bdcst_0, reg_1); \
+\
+    a_vec_0 = _mm256_permute_pd(a_vec_0, 0x5); \
+    a_vec_1 = _mm256_permute_pd(a_vec_1, 0x5); \
+\
+    a_vec_0 = _mm256_mul_pd(a_vec_0, bdcst_1); \
+    a_vec_1 = _mm256_mul_pd(a_vec_1, bdcst_1); \
+\
+    reg_0 = _mm256_addsub_pd(reg_0, a_vec_0); \
+    reg_1 = _mm256_addsub_pd(reg_1, a_vec_1);
+
+// Macro to be used for beta scaling with 1 load from C (fringe case with m_rem == 1)
+#define BETA_SCALING_C_FRINGE(reg_0, loc) \
+\
+    /* Here, a_vec_0 is used to load a column of length 2
+       from C, with bdcst_0 and bdcst_1 already having the real
+       and imaginary parts of beta broadcast onto them. reg_0
+       is the intermediate register containing the result of
+       alpha*A*B on it. The beta scaling and final accumulation
+       are done on these registers for storing the corresponding
+       column of C. */ \
+\
+    a_vec_0 = _mm256_loadu_pd((double const*)(loc)); \
+\
+    reg_0 = _mm256_fmadd_pd(a_vec_0, bdcst_0, reg_0); \
+\
+    a_vec_0 = _mm256_permute_pd(a_vec_0, 0x5); \
+\
+    a_vec_0 = _mm256_mul_pd(a_vec_0, bdcst_1); \
+\
+    reg_0 = _mm256_addsub_pd(reg_0, a_vec_0);
+
+/* The following API implements the ZGEMM operation specifically for inputs A and B
+   with k == 1. It expects the inputs and output to support the column-major storage
+   scheme, without any requirement to conjugate/transpose any of the operands. */
+
+void bli_zgemm_4x4_avx2_k1_nn
     (
      dim_t  m,
      dim_t  n,
@@ -119,1709 +102,1026 @@ void bli_zgemm_4x6_avx2_k1_nn
      dcomplex*    c, const inc_t ldc
     )
 {
+    // Setting the required variables for choosing the right path
+    // to execute the required computation.
+    dim_t m_iter = ( m / Z_MR );
+    dim_t n_iter = ( n / Z_NR );

-    double alpha_real, beta_real;
-    double alpha_imag, beta_imag;
-
-    alpha_real = alpha->real;
-    beta_real = beta->real;
-    alpha_imag = alpha->imag;
-    beta_imag = beta->imag;
-
-    /* If m or n is zero, return immediately. */
-    if ( bli_zero_dim2( m, n ) ) return;
-    /* If alpha alone is zero, scale by beta and return. */
-    if (bli_zeq0(*(alpha)))
-    {
-        bli_zscalm(
-                    BLIS_NO_CONJUGATE,
-                    0,
-                    BLIS_NONUNIT_DIAG,
-                    BLIS_DENSE,
-                    m, n,
-                    beta,
-                    c, 1, ldc
-                  );
-        return;
-    }
-
-    dim_t m_remainder = (m % Z_MR);
-    dim_t n_remainder = (n % Z_NR);
-
-    //scratch registers
-    __m256d ymm0, ymm1, ymm2, ymm3;
-    __m256d ymm4, ymm5, ymm6, ymm7;
-    __m256d ymm8, ymm9, ymm10, ymm11;
-    __m256d ymm12, ymm13, ymm14, ymm15;
-    __m128d xmm5;
-
-    //gcc12 throws a unitialized warning,
-    //To avoid that these variable are set to zero.
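Before the vectorized paths below, it may help to state in plain C what the new kernel computes per element when k == 1: C := beta*C + alpha*A*B, where A is m x 1, B is 1 x n and C is m x n, all column-major. The reference below is illustrative only; zgemm_k1_ref is a hypothetical name, not part of the patch, and it assumes the BLIS types (dim_t, inc_t, dcomplex) from blis.h. It spells out the complex arithmetic exactly the way the alpha scaling of B and the BETA_SCALING_C_* macros above perform it.

// Illustrative scalar reference for the k == 1 update the kernel implements.
static void zgemm_k1_ref
     (
       dim_t m, dim_t n,
       dcomplex* alpha,
       dcomplex* a,              // m x 1
       dcomplex* b, inc_t ldb,   // 1 x n, column-major
       dcomplex* beta,
       dcomplex* c, inc_t ldc    // m x n, column-major
     )
{
    for ( dim_t j = 0; j < n; ++j )
    {
        // Scale the single element of column j of B by alpha once,
        // mirroring the alpha-scaled b_vec_* registers in the kernel.
        dcomplex bj  = b[ j * ldb ];
        dcomplex b_s = { alpha->real * bj.real - alpha->imag * bj.imag,
                         alpha->real * bj.imag + alpha->imag * bj.real };

        for ( dim_t i = 0; i < m; ++i )
        {
            // accum = (alpha * B[0][j]) * A[i][0]
            dcomplex ai    = a[ i ];
            dcomplex accum = { b_s.real * ai.real - b_s.imag * ai.imag,
                               b_s.real * ai.imag + b_s.imag * ai.real };

            // C[i][j] = accum + beta * C[i][j], i.e. the BETA_SCALING_C_* step.
            dcomplex cij = c[ i + j * ldc ];
            c[ i + j * ldc ].real = accum.real + beta->real * cij.real
                                               - beta->imag * cij.imag;
            c[ i + j * ldc ].imag = accum.imag + beta->real * cij.imag
                                               + beta->imag * cij.real;
        }
    }
}

The kernel reproduces exactly this arithmetic with AVX2 registers, splitting the beta step into the BLIS_MUL_ZERO, BLIS_MUL_ONE and default cases seen further below.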
- ymm0 = _mm256_setzero_pd(); - /* Form C = alpha*A*B + beta*c */ - // Main loop along N dimension - for(dim_t j = 0;j < (n-Z_NR+1);j=j+Z_NR) + dim_t m_remainder = ( m % Z_MR ); + dim_t n_remainder = ( n % Z_NR ); + + // Setting the alpha and beta scaling components(real and imaginary). + double alpha_real = alpha->real; + double alpha_imag = alpha->imag; + + double beta_real = beta->real; + double beta_imag = beta->imag; + + // Using the predefined enumerated constants to classify beta scaling + // into one of the below categories. + int beta_mul_type = BLIS_MUL_DEFAULT; + + // Setting the appropriate type for beta scaling + // based on any of the special cases. + if( beta_imag == 0.0 ) { - dcomplex* temp_b = b + j*ldb; - dcomplex* temp_a = a; - dcomplex* temp_c = c + j*ldc; + if( beta_real == 0.0 ) beta_mul_type = BLIS_MUL_ZERO; + else if( beta_real == 1.0 ) beta_mul_type = BLIS_MUL_ONE; + } + + // Implementing the GEMM operation, which is as follows : + // C := beta*C + alpha*A*B. + + // The code structure deals with fringe cases first, followed by the main loop + // both in the n and m direction. - //Main loop along M dimension - for(dim_t i = 0;i < (m-Z_MR+1);i=i+Z_MR) + // Local pointers for B and C, to be used along the n-loop + dcomplex* temp_b = b; + dcomplex* temp_c = c; + + if( ( n_remainder & 0x1 ) == 1 ) // In case of n_remainder being 1 or 3 + { + // Setting the panel addresses for A, B and C, to be used along m-loop + dcomplex *temp_ai = a; + dcomplex *temp_bj = temp_b; + dcomplex *temp_cij = temp_c; + + /* Multiple blocks of Z_MR x 1(main loop for m) and/or m_remainder x 1 block(s) + of A use the same 1 x 1 block of B in order to compute the associated Z_MR x 1 + and/or m_remainder x 1 block of C. This reusability has been exploited, wherein + the associated 1 x 1 block of B is scaled with alpha, and stored in + registers beforehand, to be reused in the main loop or fringe case of m. */ + + // Intermediate registers used for alpha scaling the block of B and storing. + __m256d a_vec_0, a_vec_1; + __m256d b_vec_0; + __m256d b_real_0; + __m256d b_imag_0; + __m256d bdcst_0, bdcst_1; + + /* Broadcasting real and imaginary components of elements from B + and unpacking them to set them in registers in the form : + { Real_part, Imag_part, Real_part, Imag_part }. + + A total of Z_NR registers are used to store the alpha-scaled B + for reuse. */ + + b_real_0 = _mm256_broadcast_sd((double const *)(temp_bj)); + b_imag_0 = _mm256_broadcast_sd((double const *)(temp_bj) + 1); + b_vec_0 = _mm256_unpacklo_pd(b_real_0, b_imag_0); + + // Broadcast elements from alpha, and exhibit the compute for complex scaling. + a_vec_0 = _mm256_broadcast_sd((double const *)(&alpha_real)); + a_vec_1 = _mm256_broadcast_sd((double const *)(&alpha_imag)); + + bdcst_0 = _mm256_unpacklo_pd(b_imag_0, b_real_0); + bdcst_0 = _mm256_mul_pd(a_vec_1, bdcst_0); + b_vec_0 = _mm256_fmaddsub_pd(a_vec_0, b_vec_0, bdcst_0); + + // Fringe cases in the m-direction. + dim_t m_rem = m_remainder; + if ( ( m_rem & 0x1 ) == 1 ) { - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - - /* - a. Perform alpha*A*B using temp_a, temp_b and alpha_real, alpha_imag - where alpha_real and/or alpha_imag is not zero. - b. 
This loop operates with 4x6 block size - along n dimension for every Z_NR columns of temp_b where - computing all Z_MR rows of temp_a. - c. Same approach is used in remaining fringe cases. - */ - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) - ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); - - ymm13 = ymm0; - ymm14 = ymm1; - _mm_prefetch((char*)(temp_a + 32), _MM_HINT_T0); - - SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_real); - SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15,ymm2,alpha_imag); - - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - /* - The result after scaling with alpha_real and/or alpha_imag is as follows: - For ymm0 : - R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) - I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] - R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) - I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) - - For ymm1 : - R(a[2][0]) = alpha_real*R(a[2][0])-alpha_imag*I(a[2][0]) - I(a[2][0]) = alpha_real*I(a[2][0])+alpha_imag*R[2][0] - R(a[3][0]) = alpha_real*R(a[3][0])-alpha_imag*I(a[3][0]) - I(a[3][0]) = alpha_real*I(a[3][0])+alpha_imag*(R[3][0]) - */ - - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - //ymm4+=R(b[0][0])*R(a[2][0]) R(b[0][0])*I(a[2][0]) - // R(b[0][0])*R(a[3][0]) R(b[0][0])*I(a[3][0]) - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); - //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - //ymm6+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)); - //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) - // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) - //ymm8+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) - // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)); - //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) - // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) - //ymm10+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) - // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)); - //ymm11+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) - // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) - //ymm12+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) - // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm2,(double const *)(temp_b+ldb*4)); - //ymm11+=R(b[0][5])*R(a[0][0]) R(b[0][5])*I(a[0][0]) - // R(b[0][5])*R(a[1][0]) R(b[0][5])*I(a[1][0]) - //ymm12+=R(b[0][5])*R(a[0][0]) R(b[0][5])*I(a[0][0]) - // R(b[0][5])*R(a[1][0]) R(b[0][5])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm2,(double const *)(temp_b+ldb*5)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 and ymm1 in accordance to the requirement - NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); - //ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - //ymm4+=R(b[0][0])*R(a[2][0]) I(b[0][0])*I(a[2][0]) - // I(b[0][0])*R(a[3][0]) I(b[0][0])*I(a[3][0]) - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); - //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - 
//ymm6+=R(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)+1); - //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) - // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) - //ymm8+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) - // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)+1); - //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) - // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) - //ymm10+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) - // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)+1); - //ymm11+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) - // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) - //ymm12+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) - // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm2,(double const *)(temp_b+ldb*4)+1); - //ymm13+=I(b[0][5])*R(a[0][0]) I(b[0][5])*I(a[0][0]) - // I(b[0][5])*R(a[1][0]) I(b[0][5])*I(a[1][0]) - //ymm14+=I(b[0][5])*R(a[0][0]) I(b[0][5])*I(a[0][0]) - // I(b[0][5])*R(a[1][0]) I(b[0][5])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm2,(double const *)(temp_b+ldb*5)+1); - - /* - a. Perform beta*C using temp_c, beta_real, - where beta_real is not zero. - b. This loop operates with 4x6 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. - */ - if(beta_real != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); - - //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) - // beta_real*R(c[1][0]) beta_real*I(c[1][0]) - //ymm4+=beta_real*R(c[2][0]) beta_real*I(c[2][0]) - // beta_real*R(c[3][0]) beta_real*I(c[3][0]) - SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); - - //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); - //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) - // beta_real*R(c[1][1]) beta_real*I(c[1][1]) - //ymm6+=beta_real*R(c[2][1]) beta_real*I(c[2][1]) - // beta_real*R(c[3][1]) beta_real*I(c[3][1]) - SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15); - - //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); - //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); - //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) - // beta_real*R(c[1][2]) beta_real*I(c[1][2]) - //ymm8+=beta_real*R(c[2][2]) beta_real*I(c[2][2]) - //beta_real*R(c[3][2]) beta_real*I(c[3][2]) - SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15); - - //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); - //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); - //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) - // beta_real*R(c[1][3]) beta_real*I(c[1][3]) - //ymm10+=beta_real*R(c[2][3]) beta_real*I(c[2][3]) - // beta_real*R(c[3][3]) beta_real*I(c[3][3]) - 
SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15); - - //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); - //R(c[2][4]) I(c[2][4]) R(c[3][4]) I(c[3][4]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*4 + 2)); - //ymm11+=beta_real*R(c[0][4]) beta_real*I(c[0][4]) - // beta_real*R(c[1][4]) beta_real*I(c[1][4]) - //ymm12+=beta_real*R(c[2][4]) beta_real*I(c[2][4]) - // beta_real*R(c[3][4]) beta_real*I(c[3][4]) - SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm15); - - //R(c[0][5]) I(c[0][5]) R(c[1][5]) I(c[1][5]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*5)); - //R(c[2][5]) I(c[2][5]) R(c[3][5]) I(c[3][5]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*5 + 2)); - //ymm13+=beta_real*R(c[0][5]) beta_real*I(c[0][5]) - // beta_real*R(c[1][5]) beta_real*I(c[1][5]) - //ymm14+=beta_real*R(c[2][5]) beta_real*I(c[2][5]) - // beta_real*R(c[3][5]) beta_real*I(c[3][5]) - SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15); - } + // Scratch registers. + __m256d b_scaled_0, b_perm_0, a_real, a_imag; - /* - a. Perform beta*C using temp_c, beta_imag, - where beta_imag is not zero. - b. This loop operates with 4x6 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. - */ - - if(beta_imag != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); - - //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) - // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) - //ymm4+=beta_imag*(-I(c[2][0])) beta_imag*R(c[2][0]) - // beta_imag*(-I(c[3][0])) beta_imag*R(c[3][0]) - SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); - - //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); - //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) - // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) - //ymm6+=beta_imag*(-I(c[2][1])) beta_imag*R(c[2][1]) - // beta_imag*(-I(c[3][1])) beta_imag*R(c[3][1]) - SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15,ymm2); - - //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); - //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); - //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) - // beta_imag*(-I(c[1][2])) beta_imag*R(c[1][2]) - //ymm8+=beta_imag*(-I(c[2][2])) beta_imag*R(c[2][2]) - // beta_imag*(-I(c[3][2])) beta_imag*R(c[3][2]) - SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15,ymm2); - - //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); - //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); - //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) - // beta_imag*(-I(c[1][3])) beta_imag*R(c[1][3]) - //ymm10+=beta_imag*(-I(c[2][3])) beta_imag*R(c[2][3]) - // beta_imag*(-I(c[3][3])) beta_imag*R(c[3][3]) - SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15,ymm2); - - //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) - ymm0 = _mm256_loadu_pd((double const 
*)(temp_c + ldc*4)); - //R(c[2][4]) I(c[2][4]) R(c[3][4]) I(c[3][4]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*4 + 2)); - //ymm11+=beta_imag*(-I(c[0][4])) beta_imag*R(c[0][4]) - // beta_imag*(-I(c[1][4])) beta_imag*R(c[1][4]) - //ymm12+=beta_imag*(-I(c[2][4])) beta_imag*R(c[2][4]) - // beta_imag*(-I(c[3][4])) beta_imag*R(c[3][4]) - SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm15,ymm2); - - //R(c[0][5]) I(c[0][5]) R(c[1][5]) I(c[1][5]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*5)); - //R(c[2][5]) I(c[2][5]) R(c[3][5]) I(c[3][5]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*5 + 2)); - //ymm13+=beta_imag*(-I(c[0][5])) beta_imag*R(c[0][5]) - // beta_imag*(-I(c[1][5])) beta_imag*R(c[1][5]) - //ymm14+=beta_imag*(-I(c[2][5])) beta_imag*R(c[2][5]) - // beta_imag*(-I(c[3][5])) beta_imag*R(c[3][5]) - SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15,ymm2); - } - /* - The scaling has been done sequentially as follows: - - If alpha_real is not 0, it is used for scaling A - - If alpha_imag is not 0, it is used for scaling A using permutation - and selective negation, after loading - - If beta_real is not 0, is is used for scaling C - - If beta_imag is not 0, it is used for scaling C using permutation - and selective negation, after loading + __m128d b_element_0, c_element_0; + __m128d beta_real_reg, beta_imag_reg, c_perm_0; - The results are accumalated in accordance to the non zero scalar values, - and similar approach is followed in fringe cases - */ + b_scaled_0 = _mm256_setzero_pd(); + b_perm_0 = _mm256_setzero_pd(); - _mm256_storeu_pd((double *)(temp_c), ymm3); - _mm256_storeu_pd((double *)(temp_c + 2), ymm4); + /* Here, only a single element from A is of concern. + Also, we already have alpha-scaled B available in + b_vec_0 and b_vec_1. Thus, we could scale these + registers with the element from A using AVX2 ISA */ - _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); - _mm256_storeu_pd((double *)(temp_c + ldc + 2), ymm6); + // Broadcasting real and imaginary components from A. 
- _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); - _mm256_storeu_pd((double *)(temp_c + ldc*2 + 2), ymm8); + a_real = _mm256_broadcast_sd((double const *)(temp_ai)); + a_imag = _mm256_broadcast_sd((double const *)(temp_ai) + 1); - _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); - _mm256_storeu_pd((double *)(temp_c + ldc*3 + 2), ymm10); + // Obtaining the alpha-scaled B matrix + b_scaled_0 = b_vec_0; + b_perm_0 = _mm256_permute_pd(b_scaled_0, 0x5); - _mm256_storeu_pd((double *)(temp_c + ldc*4), ymm11); - _mm256_storeu_pd((double *)(temp_c + ldc*4 + 2), ymm12); + b_perm_0 = _mm256_mul_pd(b_perm_0, a_imag); + b_scaled_0 = _mm256_fmaddsub_pd(b_scaled_0, a_real, b_perm_0); - _mm256_storeu_pd((double *)(temp_c + ldc*5), ymm13); - _mm256_storeu_pd((double *)(temp_c + ldc*5 + 2), ymm14); + c_element_0 = _mm256_castpd256_pd128(b_scaled_0); - temp_c+=Z_MR; - temp_a+=Z_MR; - } + // Clearing out the upper lanes of 256 bit registers to avoid + // the transition penalty + _mm256_zeroupper(); - // Fringe cases for M - dim_t m_rem=m_remainder; - if(m_rem>=2) - { - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - - ymm13 = ymm0; - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); - - ymm13 = _mm256_setzero_pd(); - - /* - The result after scaling with alpha_real and/or alpha_imag is as follows: - For ymm0 : - R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) - I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] - R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) - I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) - */ - - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); - //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); - //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) - // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); - //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) - // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); - //ymm11+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) - // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)); - //ymm13+=R(b[0][5])*R(a[0][0]) R(b[0][5])*I(a[0][0]) - // R(b[0][5])*R(a[1][0]) R(b[0][5])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm13,ymm2,(double const *)(temp_b+ldb*5)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); - - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); - //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); - //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) - // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); - 
//ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) - // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); - //ymm11+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) - // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)+1); - //ymm13+=I(b[0][5])*R(a[0][0]) I(b[0][5])*I(a[0][0]) - // I(b[0][5])*R(a[1][0]) I(b[0][5])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm13,ymm2,(double const *)(temp_b+ldb*5)+1); - - - if(beta_real != 0.0) + // Scaling with beta, according to its type. + switch( beta_mul_type ) { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); - - //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) - // beta_real*R(c[1][0]) beta_real*I(c[1][0]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); - - //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) - // beta_real*R(c[1][1]) beta_real*I(c[1][1]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); - - //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); - //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) - // beta_real*R(c[1][2]) beta_real*I(c[1][2]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); - - //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); - //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) - // beta_real*R(c[1][3]) beta_real*I(c[1][3]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); - - //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); - //ymm11+=beta_real*R(c[0][4]) beta_real*I(c[0][4]) - // beta_real*R(c[1][4]) beta_real*I(c[1][4]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm11,ymm15); - - //R(c[0][5]) I(c[0][5]) R(c[1][5]) I(c[1][5]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*5)); - //ymm13+=beta_real*R(c[0][5]) beta_real*I(c[0][5]) - // beta_real*R(c[1][5]) beta_real*I(c[1][5]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm13,ymm15); - } + case BLIS_MUL_ZERO : + break; + case BLIS_MUL_ONE : + // Load C and add with the corresponding scratch register. 
+ b_element_0 = _mm_loadu_pd((double const*)(temp_cij)); + c_element_0 = _mm_add_pd(c_element_0, b_element_0); + break; - if(beta_imag != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); - - //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) - // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); - - //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) - // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); - - //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); - //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) - // beta_imag*(-I(c[1][2])) beta_imag*R(c[1][2]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); - - //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); - //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) - // beta_imag*(-I(c[1][3])) beta_imag*R(c[1][3]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); - - //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); - //ymm11+=beta_imag*(-I(c[0][4])) beta_imag*R(c[0][4]) - // beta_imag*(-I(c[1][4])) beta_imag*R(c[1][4]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm11,ymm15,ymm2); - - //R(c[0][5]) I(c[0][5]) R(c[1][5]) I(c[1][5]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*5)); - //ymm13+=beta_imag*(-I(c[0][5])) beta_imag*R(c[0][5]) - // beta_imag*(-I(c[1][5])) beta_imag*R(c[1][5]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2); - } + default : + // Broadcast beta real and imaginary part and scale with C. + beta_real_reg = _mm_loaddup_pd((double const*)beta); + beta_imag_reg = _mm_loaddup_pd((double const*)beta + 1); + + // Load C onto registers + b_element_0 = _mm_loadu_pd((double const*)(temp_cij)); - /* - The scaling has been done sequentially as follows: - - If alpha_real is not 0, it is used for scaling A - - If alpha_imag is not 0, it is used for scaling A using permutation - and selective negation, after loading - - If beta_real is not 0, is is used for scaling C - - If beta_imag is not 0, it is used for scaling C using permutation - and selective negation, after loading + // Shuffle for the compute with imgarinary part scaling + c_perm_0 = _mm_shuffle_pd(b_element_0, b_element_0, 0x01); - The results are accumalated in accordance to the non zero scalar values. - */ + c_perm_0 = _mm_mul_pd(beta_imag_reg, c_perm_0); - _mm256_storeu_pd((double *)(temp_c), ymm3); - _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); - _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); - _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); - _mm256_storeu_pd((double *)(temp_c + ldc*4), ymm11); - _mm256_storeu_pd((double *)(temp_c + ldc*5), ymm13); + b_element_0 = _mm_mul_pd(beta_real_reg, b_element_0); + // Compute beta-scaled C + b_element_0 = _mm_addsub_pd(b_element_0, c_perm_0); + // Add to intermediate reg storing alpha*A*B + c_element_0 = _mm_add_pd(b_element_0, c_element_0); + } - temp_c+=2; - temp_a+=2; + // Storing the result in C. 
+ _mm_storeu_pd((double *)(temp_cij), c_element_0); - m_rem -= 2; + // We need to restore the upper lanes of the registers b_vec_0, b_vec_1, + // b_vec_2 and b_vec_3 + // They need to contain the alpha scaled B, to be reused in the main loop for m + b_element_0 = _mm256_castpd256_pd128(b_vec_0); + b_vec_0 = _mm256_insertf128_pd(b_vec_0, b_element_0, 0x01); + + // Adjusting the addresses of A and C for the next block. + temp_cij += 1; + temp_ai += 1; + + m_rem -= 1; } - if(m_rem==1) + if( m_rem == 2 ) { + // Scratch registers. + __m256d c_vec_0; - xmm5 = _mm_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - xmm5 = _mm_loadu_pd((double const*)(temp_a));//R(a[0][0]) I(a[0][0]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(a[0][0]) I(a[0][0]) - - ymm13 = ymm0; - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); - - ymm13 = _mm256_setzero_pd(); - - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); - //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); - //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) - // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); - //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) - // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); - //ymm11+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) - // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)); - //ymm13+=R(b[0][5])*R(a[0][0]) R(b[0][5])*I(a[0][0]) - // R(b[0][5])*R(a[1][0]) R(b[0][5])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm13,ymm2,(double const *)(temp_b+ldb*5)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); - - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); - //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); - //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) - // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); - //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) - // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); - //ymm11+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) - // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)+1); - //ymm13+=I(b[0][5])*R(a[0][0]) I(b[0][5])*I(a[0][0]) - // I(b[0][5])*R(a[1][0]) I(b[0][5])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm13,ymm2,(double const *)(temp_b+ldb*5)+1); - - if(beta_real != 0.0) + a_vec_0 = _mm256_setzero_pd(); + a_vec_1 = _mm256_setzero_pd(); + bdcst_0 = _mm256_setzero_pd(); + bdcst_1 = _mm256_setzero_pd(); + c_vec_0 = _mm256_setzero_pd(); + + // Loading a vector from A with 2 elements. 
+ a_vec_0 = _mm256_loadu_pd((double const *)(temp_ai)); + + a_vec_0 = _mm256_permute_pd(a_vec_0, 0x5); + + // Scaling with imaginary components of elements from B. + bdcst_0 = _mm256_unpackhi_pd(b_vec_0, b_vec_0); + c_vec_0 = _mm256_mul_pd(a_vec_0, bdcst_0); + + a_vec_0 = _mm256_permute_pd(a_vec_0, 0x5); + + // Scaling with real components of elements from B. + bdcst_0 = _mm256_unpacklo_pd(b_vec_0, b_vec_0); + c_vec_0 = _mm256_fmaddsub_pd(a_vec_0, bdcst_0, c_vec_0); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); - - xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) - //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) - //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) - //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) - //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 4));//R(c[0][4]) I(c[0][4]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][4]) I(c[0][4]) - //ymm11+=beta_real*R(c[0][4]) beta_real*I(c[0][4]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm11,ymm15); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 5));//R(c[0][5]) I(c[0][5]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][5]) I(c[0][5]) - //ymm13+=beta_real*R(c[0][5]) beta_real*I(c[0][5]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm13,ymm15); + case BLIS_MUL_ZERO : + break; + + case BLIS_MUL_ONE : + // Load C and add with the corresponding scratch register. + a_vec_0 = _mm256_loadu_pd((double const*)(temp_cij)); + c_vec_0 = _mm256_add_pd(c_vec_0, a_vec_0); + break; + + default : + // Broadcast beta and redirect to the beta scaling macro. + bdcst_0 = _mm256_broadcast_sd((double const*)(&beta_real)); + bdcst_1 = _mm256_broadcast_sd((double const*)(&beta_imag)); + + BETA_SCALING_C_FRINGE(c_vec_0, temp_cij); } - if(beta_imag != 0.0) + // Storing the result in C. + _mm256_storeu_pd((double *)(temp_cij), c_vec_0); + + // Adjusting the addresses of A and C for the next block. + temp_cij += 2; + temp_ai += 2; + + m_rem -= 2; + } + + // Main loop along M dimension. + for( dim_t i = 0; i < m_iter; i++ ) + { + // Scratch registers + __m256d c_vec_0, c_vec_1; + + a_vec_0 = _mm256_setzero_pd(); + a_vec_1 = _mm256_setzero_pd(); + bdcst_0 = _mm256_setzero_pd(); + bdcst_1 = _mm256_setzero_pd(); + c_vec_0 = _mm256_setzero_pd(); + c_vec_1 = _mm256_setzero_pd(); + + // Prefetching the block of C to be used for computation. + _mm_prefetch((char const*)(temp_cij), _MM_HINT_T0); + _mm_prefetch((char const*)(temp_cij + ldc), _MM_HINT_T0); + _mm_prefetch((char const*)(temp_cij + ldc*2), _MM_HINT_T0); + _mm_prefetch((char const*)(temp_cij + ldc*3), _MM_HINT_T0); + + // Loading vectors from A with Z_MR elements in total. 
+ a_vec_0 = _mm256_loadu_pd((double const *)(temp_ai)); + a_vec_1 = _mm256_loadu_pd((double const *)(temp_ai + 2)); + + a_vec_0 = _mm256_permute_pd(a_vec_0, 0x5); + a_vec_1 = _mm256_permute_pd(a_vec_1, 0x5); + + // Scaling with imaginary components of elements from B. + bdcst_0 = _mm256_unpackhi_pd(b_vec_0, b_vec_0); + c_vec_0 = _mm256_mul_pd(a_vec_0, bdcst_0); + c_vec_1 = _mm256_mul_pd(a_vec_1, bdcst_0); + + a_vec_0 = _mm256_permute_pd(a_vec_0, 0x5); + a_vec_1 = _mm256_permute_pd(a_vec_1, 0x5); + + // Scaling with real components of elements from B. + bdcst_0 = _mm256_unpacklo_pd(b_vec_0, b_vec_0); + c_vec_0 = _mm256_fmaddsub_pd(a_vec_0, bdcst_0, c_vec_0); + c_vec_1 = _mm256_fmaddsub_pd(a_vec_1, bdcst_0, c_vec_1); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); - - xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) - //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) - //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) - //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) - //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 4));//R(c[0][4]) I(c[0][4]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][4]) I(c[0][4]) - //ymm11+=beta_imag*(-I(c[0][4])) beta_imag*R(c[0][4]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm11,ymm15,ymm2); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 5));//R(c[0][5]) I(c[0][5]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][5]) I(c[0][5]) - //ymm13+=beta_imag*(-I(c[0][5])) beta_imag*R(c[0][5]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2); - } + case BLIS_MUL_ZERO : + break; - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(temp_c), xmm5); + case BLIS_MUL_ONE : + // Load C and add with the corresponding scratch register. + a_vec_0 = _mm256_loadu_pd((double const*)(temp_cij)); + a_vec_1 = _mm256_loadu_pd((double const*)(temp_cij + 2)); + c_vec_0 = _mm256_add_pd(c_vec_0, a_vec_0); + c_vec_1 = _mm256_add_pd(c_vec_1, a_vec_1); + break; - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(temp_c + ldc), xmm5); + default : + // Broadcast beta and redirect to the beta scaling macro. + bdcst_0 = _mm256_broadcast_sd((double const*)(&beta_real)); + bdcst_1 = _mm256_broadcast_sd((double const*)(&beta_imag)); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(temp_c + ldc*2), xmm5); + BETA_SCALING_C_MAIN(c_vec_0, c_vec_1, temp_cij); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(temp_c + ldc*3), xmm5); + } - xmm5 = _mm256_extractf128_pd(ymm11, 0); - _mm_storeu_pd((double *)(temp_c + ldc*4), xmm5); + // Storing the result in C. 
+ _mm256_storeu_pd((double *)(temp_cij), c_vec_0); + _mm256_storeu_pd((double *)(temp_cij + 2), c_vec_1); - xmm5 = _mm256_extractf128_pd(ymm13, 0); - _mm_storeu_pd((double *)(temp_c + ldc*5), xmm5); + // Adjusting the addresses of A and C for the next iteration. + temp_cij += Z_MR; + temp_ai += Z_MR; } + temp_b += ldb; + temp_c += ldc; + + n_remainder -= 1; } - //Fringe case for N - if(n_remainder>=4) + if( n_remainder == 2 ) { - dcomplex* temp_b = b + (n - n_remainder)*ldb; - dcomplex* temp_a = a; - dcomplex* temp_c = c + (n - n_remainder)*ldc; - - //Main loop for M - for(dim_t i = 0;i < (m-Z_MR+1);i=i+Z_MR) + // Setting the panel addresses for A B, and C, to be used along m-loop + dcomplex *temp_ai = a; + dcomplex *temp_bj = temp_b; + dcomplex *temp_cij = temp_c; + + /* Multiple blocks of Z_MR x 1(main loop for m) and/or m_remainder x 1 block(s) + of A use the same 1 x 2 block of B in order to compute the associated Z_MR x 2 + and/or m_remainder x 2 block of C. This reusability has been exploited, wherein + the associated 1 x 2 block of B is scaled with alpha, and stored in + registers beforehand, to be reused in the main loop or fringe case of m. */ + + // Intermediate registers used for alpha scaling the block of B and storing. + __m256d a_vec_0, a_vec_1; + __m256d b_vec_0, b_vec_1; + __m256d b_real_0, b_real_1; + __m256d b_imag_0, b_imag_1; + __m256d bdcst_0, bdcst_1; + + /* Broadcasting real and imaginary components of elements from B + and unpacking them to set them in registers in the form : + { Real_part, Imag_part, Real_part, Imag_part }. + + A total of Z_NR registers are used to store the alpha-scaled B + for reuse. */ + + b_real_0 = _mm256_broadcast_sd((double const *)(temp_bj)); + b_imag_0 = _mm256_broadcast_sd((double const *)(temp_bj) + 1); + b_vec_0 = _mm256_unpacklo_pd(b_real_0, b_imag_0); + + b_real_1 = _mm256_broadcast_sd((double const *)(temp_bj + ldb)); + b_imag_1 = _mm256_broadcast_sd((double const *)(temp_bj + ldb) + 1); + b_vec_1 = _mm256_unpacklo_pd(b_real_1, b_imag_1); + + // Broadcast elements from alpha, and exhibit the compute for complex scaling. + a_vec_0 = _mm256_broadcast_sd((double const *)(&alpha_real)); + a_vec_1 = _mm256_broadcast_sd((double const *)(&alpha_imag)); + + bdcst_0 = _mm256_unpacklo_pd(b_imag_0, b_real_0); + bdcst_1 = _mm256_unpacklo_pd(b_imag_1, b_real_1); + bdcst_0 = _mm256_mul_pd(a_vec_1, bdcst_0); + bdcst_1 = _mm256_mul_pd(a_vec_1, bdcst_1); + b_vec_0 = _mm256_fmaddsub_pd(a_vec_0, b_vec_0, bdcst_0); + b_vec_1 = _mm256_fmaddsub_pd(a_vec_0, b_vec_1, bdcst_1); + + // Fringe cases in the m-direction. + dim_t m_rem = m_remainder; + if ( ( m_rem & 0x1 ) == 1 ) { - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - /* - a. Perform alpha*A*B using temp_a, temp_b and alpha_real, alpha_imag - where alpha_real and/or alpha_imag is not zero. - b. This loop operates with 4x6 block size - along n dimension for every Z_NR columns of temp_b where - computing all Z_MR rows of temp_a. - c. Same approach is used in remaining fringe cases. 
- */ - - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) - ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); - - ymm13 = ymm0; - ymm14 = ymm1; - SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_real); - SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15,ymm2,alpha_imag); - - /* - The result after scaling with alpha_real and/or alpha_imag is as follows: - For ymm0 : - R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) - I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] - R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) - I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) - - For ymm1 : - R(a[2][0]) = alpha_real*R(a[2][0])-alpha_imag*I(a[2][0]) - I(a[2][0]) = alpha_real*I(a[2][0])+alpha_imag*R[2][0] - R(a[3][0]) = alpha_real*R(a[3][0])-alpha_imag*I(a[3][0]) - I(a[3][0]) = alpha_real*I(a[3][0])+alpha_imag*(R[3][0]) - */ - - //Calculating using real part of complex number in B matrix - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); - FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)); - FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)); - FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 and ymm1 in accordance to the requirement - NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); - FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)+1); - FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)+1); - FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)+1); - - /* - a. Perform beta*C using temp_c, beta_real, - where beta_real is not zero. - b. This loop operates with 4x6 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. 
- */ - if(beta_real != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); - - //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) - // beta_real*R(c[1][0]) beta_real*I(c[1][0]) - //ymm4+=beta_real*R(c[2][0]) beta_real*I(c[2][0]) - // beta_real*R(c[3][0]) beta_real*I(c[3][0]) - SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); - - //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); - //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) - // beta_real*R(c[1][1]) beta_real*I(c[1][1]) - //ymm6+=beta_real*R(c[2][1]) beta_real*I(c[2][1]) - // beta_real*R(c[3][1]) beta_real*I(c[3][1]) - SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15); - - //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); - //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); - //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) - // beta_real*R(c[1][2]) beta_real*I(c[1][2]) - //ymm8+=beta_real*R(c[2][2]) beta_real*I(c[2][2]) - // beta_real*R(c[3][2]) beta_real*I(c[3][2]) - SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15); - - //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); - //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); - //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) - // beta_real*R(c[1][3]) beta_real*I(c[1][3]) - //ymm10+=beta_real*R(c[2][3]) beta_real*I(c[2][3]) - // beta_real*R(c[3][3]) beta_real*I(c[3][3]) - SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15); - } - /* - a. Perform beta*C using temp_c, beta_imag, - where beta_imag is not zero. - b. This loop operates with 4x6 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. 
- */ - - if(beta_imag != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); - - //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) - // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) - //ymm4+=beta_imag*(-I(c[2][0])) beta_imag*R(c[2][0]) - // beta_imag*(-I(c[3][0])) beta_imag*R(c[3][0]) - SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); - - //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); - //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) - // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) - //ymm6+=beta_imag*(-I(c[2][1])) beta_imag*R(c[2][1]) - // beta_imag*(-I(c[3][1])) beta_imag*R(c[3][1]) - SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15,ymm2); - - //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); - //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); - //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) - // beta_imag*(-I(c[1][2])) beta_imag*R(c[1][2]) - //ymm8+=beta_imag*(-I(c[2][2])) beta_imag*R(c[2][2]) - // beta_imag*(-I(c[3][2])) beta_imag*R(c[3][2]) - SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15,ymm2); - - //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); - //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); - //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) - // beta_imag*(-I(c[1][3])) beta_imag*R(c[1][3]) - //ymm10+=beta_imag*(-I(c[2][3])) beta_imag*R(c[2][3]) - // beta_imag*(-I(c[3][3])) beta_imag*R(c[3][3]) - SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15,ymm2); - } - /* - The scaling has been done sequentially as follows: - - If alpha_real is not 0, it is used for scaling A - - If alpha_imag is not 0, it is used for scaling A using permutation - and selective negation, after loading - - If beta_real is not 0, is is used for scaling C - - If beta_imag is not 0, it is used for scaling C using permutation - and selective negation, after loading + // Scratch registers. + __m256d b_scaled_0, b_perm_0, a_real, a_imag; - The results are accumalated in accordance to the non zero scalar values, - and similar approach is followed in fringe cases - */ + __m128d b_element_0, b_element_1, c_element_0, c_element_1; + __m128d beta_real_reg, beta_imag_reg, c_perm_0, c_perm_1; - _mm256_storeu_pd((double *)(temp_c), ymm3); - _mm256_storeu_pd((double *)(temp_c + 2), ymm4); + b_scaled_0 = _mm256_setzero_pd(); + b_perm_0 = _mm256_setzero_pd(); - _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); - _mm256_storeu_pd((double *)(temp_c + ldc + 2), ymm6); + /* Here, only a single element from A is of concern. + Also, we already have alpha-scaled B available in + b_vec_0 and b_vec_1. Thus, we could scale these + registers with the element from A using AVX2 ISA */ - _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); - _mm256_storeu_pd((double *)(temp_c + ldc*2 + 2), ymm8); + // Broadcasting real and imaginary components from A. 
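+ // A scalar sketch of the permute/mul/fmaddsub sequence below (illustrative
+ // only): with a = ar + i*ai being the lone element of A, and each
+ // alpha-scaled element of B held as the pair { br, bi },
+ //     result.real = ar*br - ai*bi
+ //     result.imag = ar*bi + ai*br
+ // i.e. the standard complex product, computed for two elements of C per
+ // 256-bit register.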
- _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); - _mm256_storeu_pd((double *)(temp_c + ldc*3 + 2), ymm10); + a_real = _mm256_broadcast_sd((double const *)(temp_ai)); + a_imag = _mm256_broadcast_sd((double const *)(temp_ai) + 1); - temp_c+=Z_MR; - temp_a+=Z_MR; - } + // Obtaining the alpha-scaled B matrix - // Fringe cases for M - dim_t m_rem=m_remainder; - if(m_rem>=2) - { - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - - ymm13 = ymm0; - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); - /* - The result after scaling with alpha_real and/or alpha_imag is as follows: - For ymm0 : - R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) - I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] - R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) - I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) - */ - - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); - //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); - //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) - // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); - //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) - // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); - - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); - //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); - //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) - // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); - //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) - // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); - - - if(beta_real != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); - - //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) - // beta_real*R(c[1][0]) beta_real*I(c[1][0]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); - - //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) - // beta_real*R(c[1][1]) beta_real*I(c[1][1]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); - - //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); - //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) - // beta_real*R(c[1][2]) beta_real*I(c[1][2]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); - - //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c 
+ ldc*3)); - //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) - // beta_real*R(c[1][3]) beta_real*I(c[1][3]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); - } + b_scaled_0 = _mm256_permute2f128_pd(b_vec_0, b_vec_1, 0x20); + b_perm_0 = _mm256_permute_pd(b_scaled_0, 0x5); + + b_perm_0 = _mm256_mul_pd(b_perm_0, a_imag); + b_scaled_0 = _mm256_fmaddsub_pd(b_scaled_0, a_real, b_perm_0); - if(beta_imag != 0.0) + c_element_0 = _mm256_castpd256_pd128(b_scaled_0); + c_element_1 = _mm256_extractf128_pd(b_scaled_0, 0x01); + + // Clearing out the upper lanes of 256 bit registers to avoid + // the transition penalty + _mm256_zeroupper(); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); - - //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) - // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); - - //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) - // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); - - //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); - //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) - // beta_imag*(-I(c[1][2])) beta_imag*R(c[1][2]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); - - //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); - //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) - // beta_imag*(-I(c[1][3])) beta_imag*R(c[1][3]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); - } + case BLIS_MUL_ZERO : + break; - /* - The scaling has been done sequentially as follows: - - If alpha_real is not 0, it is used for scaling A - - If alpha_imag is not 0, it is used for scaling A using permutation - and selective negation, after loading - - If beta_real is not 0, is is used for scaling C - - If beta_imag is not 0, it is used for scaling C using permutation - and selective negation, after loading + case BLIS_MUL_ONE : + // Load C and add with the corresponding scratch register. + b_element_0 = _mm_loadu_pd((double const*)(temp_cij)); + c_element_0 = _mm_add_pd(c_element_0, b_element_0); - The results are accumalated in accordance to the non zero scalar values, - and similar approach is followed in fringe cases - */ + b_element_1 = _mm_loadu_pd((double const*)(temp_cij + ldc)); + c_element_1 = _mm_add_pd(c_element_1, b_element_1); + break; - _mm256_storeu_pd((double *)(temp_c), ymm3); - _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); - _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); - _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); + default : + // Broadcast beta real and imaginary part and scale with C. 
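+ // Scalar view of the 128-bit sequence below (illustrative only): for each
+ // element c of C,
+ //     beta*c = { beta_r*c_r - beta_i*c_i, beta_r*c_i + beta_i*c_r }
+ // realised with loaddup + shuffle + mul + addsub, and then added to the
+ // alpha*A*B accumulator.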
+ beta_real_reg = _mm_loaddup_pd((double const*)beta); + beta_imag_reg = _mm_loaddup_pd((double const*)beta + 1); - temp_c+=2; - temp_a+=2; + // Load C onto registers + b_element_0 = _mm_loadu_pd((double const*)(temp_cij)); + b_element_1 = _mm_loadu_pd((double const*)(temp_cij + ldc)); - m_rem -= 2; - } + // Shuffle for the compute with imgarinary part scaling + c_perm_0 = _mm_shuffle_pd(b_element_0, b_element_0, 0x01); + c_perm_1 = _mm_shuffle_pd(b_element_1, b_element_1, 0x01); - if(m_rem==1) - { + c_perm_0 = _mm_mul_pd(beta_imag_reg, c_perm_0); + c_perm_1 = _mm_mul_pd(beta_imag_reg, c_perm_1); - xmm5 = _mm_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - xmm5 = _mm_loadu_pd((double const*)(temp_a));//R(a[0][0]) I(a[0][0]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(a[0][0]) I(a[0][0]) - - ymm13 = ymm0; - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); - - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); - //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); - //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) - // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); - //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) - // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); - - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); - //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); - //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) - // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); - //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) - // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); - - if(beta_real != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); - - xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) - //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) - //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) - //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) - 
//ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); - } + b_element_0 = _mm_mul_pd(beta_real_reg, b_element_0); + b_element_1 = _mm_mul_pd(beta_real_reg, b_element_1); - if(beta_imag != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); - - xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) - //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) - //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) - //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); - - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) - //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); - } + // Compute beta-scaled C + b_element_0 = _mm_addsub_pd(b_element_0, c_perm_0); + b_element_1 = _mm_addsub_pd(b_element_1, c_perm_1); - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(temp_c), xmm5); + // Add to intermediate reg storing alpha*A*B + c_element_0 = _mm_add_pd(b_element_0, c_element_0); + c_element_1 = _mm_add_pd(b_element_1, c_element_1); + } - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(temp_c + ldc), xmm5); + // Storing the result in C. + _mm_storeu_pd((double *)(temp_cij), c_element_0); + _mm_storeu_pd((double *)(temp_cij + ldc), c_element_1); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(temp_c + ldc*2), xmm5); + // We need to restore the upper lanes of the registers b_vec_0, b_vec_1, + // b_vec_2 and b_vec_3 + // They need to contain the alpha scaled B, to be reused in the main loop for m + b_element_0 = _mm256_castpd256_pd128(b_vec_0); + b_element_1 = _mm256_extractf128_pd(b_vec_1, 0x00); + b_vec_0 = _mm256_insertf128_pd(b_vec_0, b_element_0, 0x01); + b_vec_1 = _mm256_insertf128_pd(b_vec_1, b_element_1, 0x01); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(temp_c + ldc*3), xmm5); + // Adjusting the addresses of A and C for the next block. + temp_cij += 1; + temp_ai += 1; + m_rem -= 1; } - n_remainder -= 4; - } - if(n_remainder>=2) - { - dcomplex* temp_b = b + (n - n_remainder)*ldb; - dcomplex* temp_a = a; - dcomplex* temp_c = c + (n - n_remainder)*ldc; - for(dim_t i = 0;i < (m-Z_MR+1);i=i+Z_MR) + if( m_rem == 2 ) { - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - - /* - a. Perform alpha*A*B using temp_a, temp_b and alpha_real, alpha_imag - where alpha_real and/or alpha_imag is not zero. - b. This loop operates with 4x6 block size - along n dimension for every Z_NR columns of temp_b where - computing all Z_MR rows of temp_a. - c. Same approach is used in remaining fringe cases. 
- */ - - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) - ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); - - ymm13 = ymm0; - ymm14 = ymm1; - SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_real); - SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15,ymm2,alpha_imag); - - /* - The result after scaling with alpha_real and/or alpha_imag is as follows: - For ymm0 : - R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) - I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] - R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) - I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) - - For ymm1 : - R(a[2][0]) = alpha_real*R(a[2][0])-alpha_imag*I(a[2][0]) - I(a[2][0]) = alpha_real*I(a[2][0])+alpha_imag*R[2][0] - R(a[3][0]) = alpha_real*R(a[3][0])-alpha_imag*I(a[3][0]) - I(a[3][0]) = alpha_real*I(a[3][0])+alpha_imag*(R[3][0]) - */ - - //Calculating using real part of complex number in B matrix - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); - FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 and ymm1 in accordance to the requirement - NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); - FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)+1); - - /* - a. Perform beta*C using temp_c, beta_real, - where beta_real is not zero. - b. This loop operates with 4x6 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. - */ - if(beta_real != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); - - //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) - // beta_real*R(c[1][0]) beta_real*I(c[1][0]) - //ymm4+=beta_real*R(c[2][0]) beta_real*I(c[2][0]) - // beta_real*R(c[3][0]) beta_real*I(c[3][0]) - SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); - - //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); - //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) - // beta_real*R(c[1][1]) beta_real*I(c[1][1]) - //ymm6+=beta_real*R(c[2][1]) beta_real*I(c[2][1]) - // beta_real*R(c[3][1]) beta_real*I(c[3][1]) - SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15); - } + // Scratch registers. + __m256d c_vec_0, c_vec_2; - /* - a. Perform beta*C using temp_c, beta_imag, - where beta_imag is not zero. - b. This loop operates with 4x6 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. 
- */ - - if(beta_imag != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); - - //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) - // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) - //ymm4+=beta_imag*(-I(c[2][0])) beta_imag*R(c[2][0]) - // beta_imag*(-I(c[3][0])) beta_imag*R(c[3][0]) - SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); - - //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); - //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) - // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) - //ymm6+=beta_imag*(-I(c[2][1])) beta_imag*R(c[2][1]) - // beta_imag*(-I(c[3][1])) beta_imag*R(c[3][1]) - SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15,ymm2); - } - /* - The scaling has been done sequentially as follows: - - If alpha_real is not 0, it is used for scaling A - - If alpha_imag is not 0, it is used for scaling A using permutation - and selective negation, after loading - - If beta_real is not 0, is is used for scaling C - - If beta_imag is not 0, it is used for scaling C using permutation - and selective negation, after loading - - The results are accumalated in accordance to the non zero scalar values, - and similar approach is followed in fringe cases - */ - - _mm256_storeu_pd((double *)(temp_c), ymm3); - _mm256_storeu_pd((double *)(temp_c + 2), ymm4); - - _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); - _mm256_storeu_pd((double *)(temp_c + ldc + 2), ymm6); - - temp_c+=Z_MR; - temp_a+=Z_MR; - } + a_vec_0 = _mm256_setzero_pd(); + a_vec_1 = _mm256_setzero_pd(); + bdcst_0 = _mm256_setzero_pd(); + bdcst_1 = _mm256_setzero_pd(); + c_vec_0 = _mm256_setzero_pd(); + c_vec_2 = _mm256_setzero_pd(); - dim_t m_rem=m_remainder; - if(m_rem>=2) - { - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - - ymm13 = ymm0; - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); - /* - The result after scaling with alpha_real and/or alpha_imag is as follows: - For ymm0 : - R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) - I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] - R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) - I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) - */ - - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); - //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); - - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); - //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const 
*)(temp_b+ldb)+1); - - - if(beta_real != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); - - //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) - // beta_real*R(c[1][0]) beta_real*I(c[1][0]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); - - //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) - // beta_real*R(c[1][1]) beta_real*I(c[1][1]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); - } + // Loading a vector from A with 2 elements. + a_vec_0 = _mm256_loadu_pd((double const *)(temp_ai)); + + a_vec_0 = _mm256_permute_pd(a_vec_0, 0x5); + + // Scaling with imaginary components of elements from B. + bdcst_0 = _mm256_unpackhi_pd(b_vec_0, b_vec_0); + bdcst_1 = _mm256_unpackhi_pd(b_vec_1, b_vec_1); + c_vec_0 = _mm256_mul_pd(a_vec_0, bdcst_0); + c_vec_2 = _mm256_mul_pd(a_vec_0, bdcst_1); + + a_vec_0 = _mm256_permute_pd(a_vec_0, 0x5); - if(beta_imag != 0.0) + // Scaling with real components of elements from B. + bdcst_0 = _mm256_unpacklo_pd(b_vec_0, b_vec_0); + bdcst_1 = _mm256_unpacklo_pd(b_vec_1, b_vec_1); + c_vec_0 = _mm256_fmaddsub_pd(a_vec_0, bdcst_0, c_vec_0); + c_vec_2 = _mm256_fmaddsub_pd(a_vec_0, bdcst_1, c_vec_2); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); - - //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) - // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); - - //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) - // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); - } + case BLIS_MUL_ZERO : + break; + + case BLIS_MUL_ONE : + // Load C and add with the corresponding scratch register. + a_vec_0 = _mm256_loadu_pd((double const*)(temp_cij)); + c_vec_0 = _mm256_add_pd(c_vec_0, a_vec_0); + + a_vec_0 = _mm256_loadu_pd((double const*)(temp_cij + ldc)); + c_vec_2 = _mm256_add_pd(c_vec_2, a_vec_0); + break; - /* - The scaling has been done sequentially as follows: - - If alpha_real is not 0, it is used for scaling A - - If alpha_imag is not 0, it is used for scaling A using permutation - and selective negation, after loading - - If beta_real is not 0, is is used for scaling C - - If beta_imag is not 0, it is used for scaling C using permutation - and selective negation, after loading + default : + // Broadcast beta and redirect to the beta scaling macro. + bdcst_0 = _mm256_broadcast_sd((double const*)(&beta_real)); + bdcst_1 = _mm256_broadcast_sd((double const*)(&beta_imag)); - The results are accumalated in accordance to the non zero scalar values, - and similar approach is followed in fringe cases - */ + BETA_SCALING_C_FRINGE(c_vec_0, temp_cij); + BETA_SCALING_C_FRINGE(c_vec_2, temp_cij + ldc); - _mm256_storeu_pd((double *)(temp_c), ymm3); - _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); + } + + // Storing the result in C. + _mm256_storeu_pd((double *)(temp_cij), c_vec_0); + _mm256_storeu_pd((double *)(temp_cij + ldc), c_vec_2); - temp_c+=2; - temp_a+=2; + // Adjusting the addresses of A and C for the next block. 
+ temp_cij += 2; + temp_ai += 2; m_rem -= 2; } - if(m_rem==1) + // Main loop along M dimension. + for( dim_t i = 0; i < m_iter; i++ ) { - - xmm5 = _mm_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - xmm5 = _mm_loadu_pd((double const*)(temp_a));//R(a[0][0]) I(a[0][0]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(a[0][0]) I(a[0][0]) - - ymm13 = ymm0; - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); - - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); - //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); - - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); - //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); - - if(beta_real != 0.0) + // Scratch registers + __m256d c_vec_0, c_vec_1, c_vec_2, c_vec_3; + + a_vec_0 = _mm256_setzero_pd(); + a_vec_1 = _mm256_setzero_pd(); + bdcst_0 = _mm256_setzero_pd(); + bdcst_1 = _mm256_setzero_pd(); + c_vec_0 = _mm256_setzero_pd(); + c_vec_1 = _mm256_setzero_pd(); + c_vec_2 = _mm256_setzero_pd(); + c_vec_3 = _mm256_setzero_pd(); + + // Prefetching the block of C to be used for computation. + _mm_prefetch((char const*)(temp_cij), _MM_HINT_T0); + _mm_prefetch((char const*)(temp_cij + ldc), _MM_HINT_T0); + _mm_prefetch((char const*)(temp_cij + ldc*2), _MM_HINT_T0); + _mm_prefetch((char const*)(temp_cij + ldc*3), _MM_HINT_T0); + + // Loading vectors from A with Z_MR elements in total. + a_vec_0 = _mm256_loadu_pd((double const *)(temp_ai)); + a_vec_1 = _mm256_loadu_pd((double const *)(temp_ai + 2)); + + a_vec_0 = _mm256_permute_pd(a_vec_0, 0x5); + a_vec_1 = _mm256_permute_pd(a_vec_1, 0x5); + + // Scaling with imaginary components of elements from B. + bdcst_0 = _mm256_unpackhi_pd(b_vec_0, b_vec_0); + bdcst_1 = _mm256_unpackhi_pd(b_vec_1, b_vec_1); + c_vec_0 = _mm256_mul_pd(a_vec_0, bdcst_0); + c_vec_1 = _mm256_mul_pd(a_vec_1, bdcst_0); + c_vec_2 = _mm256_mul_pd(a_vec_0, bdcst_1); + c_vec_3 = _mm256_mul_pd(a_vec_1, bdcst_1); + + a_vec_0 = _mm256_permute_pd(a_vec_0, 0x5); + a_vec_1 = _mm256_permute_pd(a_vec_1, 0x5); + + // Scaling with real components of elements from B. + bdcst_0 = _mm256_unpacklo_pd(b_vec_0, b_vec_0); + bdcst_1 = _mm256_unpacklo_pd(b_vec_1, b_vec_1); + c_vec_0 = _mm256_fmaddsub_pd(a_vec_0, bdcst_0, c_vec_0); + c_vec_1 = _mm256_fmaddsub_pd(a_vec_1, bdcst_0, c_vec_1); + c_vec_2 = _mm256_fmaddsub_pd(a_vec_0, bdcst_1, c_vec_2); + c_vec_3 = _mm256_fmaddsub_pd(a_vec_1, bdcst_1, c_vec_3); + + // Scaling with beta, according to its type. 
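+ // beta_mul_type picks the cheapest update: BLIS_MUL_ZERO skips loading C
+ // altogether, BLIS_MUL_ONE reduces beta*C to a plain addition, and the
+ // default case performs the full complex scaling via BETA_SCALING_C_MAIN.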
+ switch( beta_mul_type ) { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); - - xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) - //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + case BLIS_MUL_ZERO : + break; + + case BLIS_MUL_ONE : + // Load C and add with the corresponding scratch register. + a_vec_0 = _mm256_loadu_pd((double const*)(temp_cij)); + a_vec_1 = _mm256_loadu_pd((double const*)(temp_cij + 2)); + c_vec_0 = _mm256_add_pd(c_vec_0, a_vec_0); + c_vec_1 = _mm256_add_pd(c_vec_1, a_vec_1); + + a_vec_0 = _mm256_loadu_pd((double const*)(temp_cij + ldc)); + a_vec_1 = _mm256_loadu_pd((double const*)(temp_cij + ldc + 2)); + c_vec_2 = _mm256_add_pd(c_vec_2, a_vec_0); + c_vec_3 = _mm256_add_pd(c_vec_3, a_vec_1); + break; + + default : + // Broadcast beta and redirect to the beta scaling macro. + bdcst_0 = _mm256_broadcast_sd((double const*)(&beta_real)); + bdcst_1 = _mm256_broadcast_sd((double const*)(&beta_imag)); + + BETA_SCALING_C_MAIN(c_vec_0, c_vec_1, temp_cij); + BETA_SCALING_C_MAIN(c_vec_2, c_vec_3, temp_cij + ldc); - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) - //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); } - if(beta_imag != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); + // Storing the result in C. + _mm256_storeu_pd((double *)(temp_cij), c_vec_0); + _mm256_storeu_pd((double *)(temp_cij + 2), c_vec_1); - xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) - //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + _mm256_storeu_pd((double *)(temp_cij + ldc), c_vec_2); + _mm256_storeu_pd((double *)(temp_cij + ldc + 2), c_vec_3); - xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) - //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); - } + // Adjusting the addresses of A and C for the next iteration. + temp_cij += Z_MR; + temp_ai += Z_MR; - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(temp_c), xmm5); + } - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(temp_c + ldc), xmm5); + temp_b += ldb*2; + temp_c += ldc*2; - } n_remainder -= 2; } - if(n_remainder==1) - { - dcomplex* temp_b = b + (n - n_remainder)*ldb; - dcomplex* temp_a = a; - dcomplex* temp_c = c + (n - n_remainder)*ldc; - // Main loop for M - for(dim_t i = 0;i < (m-Z_MR+1);i=i+Z_MR) + // Main loop along N dimension + for( dim_t j = 0; j < n_iter; j++ ) + { + dcomplex* temp_bj = temp_b + j * ldb * Z_NR; + dcomplex* temp_ai = a; + dcomplex* temp_cij = temp_c + j * ldc * Z_NR; + + /* Multiple blocks of Z_MR x 1(main loop for m) and/or m_remainder x 1 block(s) + of A use the same 1 x Z_NR block of B in order to compute the associated + Z_MR x Z_NR and/or m_remainder x Z_NR block(s) of C. This reusability has been + exploited, wherein the associated 1 x Z_NR block of B is scaled with alpha, + and stored in registers beforehand, to be reused in the main loop or fringe + case of m. */ + + // Intermediate registers used for alpha scaling the block of B and storing. 
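+ // (Each b_vec_k ends up holding the alpha-scaled element of the k-th column
+ // of this 1 x Z_NR block of B, duplicated across both 128-bit lanes as
+ // { b_r, b_i, b_r, b_i }; the unpacklo/unpackhi splits in the m-loop below
+ // rely on this layout.)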
+ __m256d a_vec_0, a_vec_1; + __m256d b_vec_0, b_vec_1, b_vec_2, b_vec_3; + __m256d b_real_0, b_real_1, b_real_2, b_real_3; + __m256d b_imag_0, b_imag_1, b_imag_2, b_imag_3; + __m256d bdcst_0, bdcst_1; + + /* Broadcasting real and imaginary components of elements from B + and unpacking them to set them in registers in the form : + { Real_part, Imag_part, Real_part, Imag_part }. + + A total of Z_NR registers are used to store the alpha-scaled B + for reuse. */ + + b_real_0 = _mm256_broadcast_sd((double const *)(temp_bj)); + b_imag_0 = _mm256_broadcast_sd((double const *)(temp_bj) + 1); + b_vec_0 = _mm256_unpacklo_pd(b_real_0, b_imag_0); + + b_real_1 = _mm256_broadcast_sd((double const *)(temp_bj + ldb)); + b_imag_1 = _mm256_broadcast_sd((double const *)(temp_bj + ldb) + 1); + b_vec_1 = _mm256_unpacklo_pd(b_real_1, b_imag_1); + + b_real_2 = _mm256_broadcast_sd((double const *)(temp_bj + ldb*2)); + b_imag_2 = _mm256_broadcast_sd((double const *)(temp_bj + ldb*2) + 1); + b_vec_2 = _mm256_unpacklo_pd(b_real_2, b_imag_2); + + b_real_3 = _mm256_broadcast_sd((double const *)(temp_bj + ldb*3)); + b_imag_3 = _mm256_broadcast_sd((double const *)(temp_bj + ldb*3) + 1); + b_vec_3 = _mm256_unpacklo_pd(b_real_3, b_imag_3); + + // Broadcast elements from alpha, and exhibit the compute for complex scaling. + a_vec_0 = _mm256_broadcast_sd((double const *)(&alpha_real)); + a_vec_1 = _mm256_broadcast_sd((double const *)(&alpha_imag)); + + bdcst_0 = _mm256_unpacklo_pd(b_imag_0, b_real_0); + bdcst_1 = _mm256_unpacklo_pd(b_imag_1, b_real_1); + bdcst_0 = _mm256_mul_pd(a_vec_1, bdcst_0); + bdcst_1 = _mm256_mul_pd(a_vec_1, bdcst_1); + b_vec_0 = _mm256_fmaddsub_pd(a_vec_0, b_vec_0, bdcst_0); + b_vec_1 = _mm256_fmaddsub_pd(a_vec_0, b_vec_1, bdcst_1); + + bdcst_0 = _mm256_unpacklo_pd(b_imag_2, b_real_2); + bdcst_1 = _mm256_unpacklo_pd(b_imag_3, b_real_3); + bdcst_0 = _mm256_mul_pd(a_vec_1, bdcst_0); + bdcst_1 = _mm256_mul_pd(a_vec_1, bdcst_1); + b_vec_2 = _mm256_fmaddsub_pd(a_vec_0, b_vec_2, bdcst_0); + b_vec_3 = _mm256_fmaddsub_pd(a_vec_0, b_vec_3, bdcst_1); + + // Fringe cases in the m-direction. + dim_t m_rem = m_remainder; + if ( ( m_rem & 0x1 ) == 1 ) { - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - - - /* - a. Perform alpha*A*B using temp_a, temp_b and alpha_real, aplha_vali - where alpha_real and/or alpha_imag is not zero. - b. This loop operates with 4x6 block size - along n dimension for every Z_NR columns of temp_b where - computing all Z_MR rows of temp_a. - c. Same approach is used in remaining fringe cases. 
- */ - - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) - ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); - - ymm13 = ymm0; - ymm14 = ymm1; - SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_real); - SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15,ymm2,alpha_imag); - - /* - The result after scaling with alpha_real and/or alpha_imag is as follows: - For ymm0 : - R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) - I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] - R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) - I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) - - For ymm1 : - R(a[2][0]) = alpha_real*R(a[2][0])-alpha_imag*I(a[2][0]) - I(a[2][0]) = alpha_real*I(a[2][0])+alpha_imag*R[2][0] - R(a[3][0]) = alpha_real*R(a[3][0])-alpha_imag*I(a[3][0]) - I(a[3][0]) = alpha_real*I(a[3][0])+alpha_imag*(R[3][0]) - */ - - //Calculating using real part of complex number in B matrix - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 and ymm1 in accordance to the requirement - NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); - - /* - a. Perform beta*C using temp_c, beta_real, - where beta_real is not zero. - b. This loop operates with 4x6 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. - */ - if(beta_real != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); - - //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) - ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) - // beta_real*R(c[1][0]) beta_real*I(c[1][0]) - //ymm4+=beta_real*R(c[2][0]) beta_real*I(c[2][0]) - // beta_real*R(c[3][0]) beta_real*I(c[3][0]) - SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); - } + // Scratch registers. + __m256d b_scaled_0, b_perm_0, a_real, a_imag; - /* - a. Perform beta*C using temp_c, beta_imag, - where beta_imag is not zero. - b. This loop operates with 4x6 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. 
- */ - - if(beta_imag != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); - - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) - // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) - //ymm4+=beta_imag*(-I(c[2][0])) beta_imag*R(c[2][0]) - // beta_imag*(-I(c[3][0])) beta_imag*R(c[3][0]) - SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); - } - /* - The scaling has been done sequentially as follows: - - If alpha_real is not 0, it is used for scaling A - - If alpha_imag is not 0, it is used for scaling A using permutation - and selective negation, after loading - - If beta_real is not 0, is is used for scaling C - - If beta_imag is not 0, it is used for scaling C using permutation - and selective negation, after loading - - The results are accumalated in accordance to the non zero scalar values, - and similar approach is followed in fringe cases - */ - - //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) - _mm256_storeu_pd((double *)(temp_c), ymm3); - //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) - _mm256_storeu_pd((double *)(temp_c + 2), ymm4); - - temp_c+=Z_MR; - temp_a+=Z_MR; - } + __m128d b_element_0, b_element_1; + __m128d c_element_0, c_element_1, c_element_2, c_element_3; + __m128d beta_real_reg, beta_imag_reg, c_perm_0, c_perm_1; - // Fringe cases for M - dim_t m_rem=m_remainder; - if(m_rem>=2) - { - ymm3 = _mm256_setzero_pd(); + b_scaled_0 = _mm256_setzero_pd(); + b_perm_0 = _mm256_setzero_pd(); + /* Here, only a single element from A is of concern. + Also, we already have alpha-scaled B available in + b_vec_0 and b_vec_1. Thus, we could scale these + registers with the element from A using AVX2 ISA */ - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + // Broadcasting real and imaginary components from A. 
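+ // The two permute2f128(..., 0x20) packs below gather the low lanes of
+ // (b_vec_0, b_vec_1) and of (b_vec_2, b_vec_3), so each 256-bit register
+ // holds the alpha-scaled B elements of two adjacent columns; one fmaddsub
+ // per pack then yields the 1 x 2 slice of alpha*A*B for those columns.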
- ymm13 = ymm0; - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + a_real = _mm256_broadcast_sd((double const *)(temp_ai)); + a_imag = _mm256_broadcast_sd((double const *)(temp_ai) + 1); - /* - The result after scaling with alpha_real and/or alpha_imag is as follows: - For ymm0 : - R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) - I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] - R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) - I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) - */ + // Obtaining the alpha-scaled B matrix - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + b_scaled_0 = _mm256_permute2f128_pd(b_vec_0, b_vec_1, 0x20); + b_perm_0 = _mm256_permute_pd(b_scaled_0, 0x5); - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); + b_perm_0 = _mm256_mul_pd(b_perm_0, a_imag); + b_scaled_0 = _mm256_fmaddsub_pd(b_scaled_0, a_real, b_perm_0); - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + c_element_0 = _mm256_castpd256_pd128(b_scaled_0); + c_element_1 = _mm256_extractf128_pd(b_scaled_0, 0x01); + b_scaled_0 = _mm256_permute2f128_pd(b_vec_2, b_vec_3, 0x20); + b_perm_0 = _mm256_permute_pd(b_scaled_0, 0x5); - if(beta_real != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); + b_perm_0 = _mm256_mul_pd(b_perm_0, a_imag); + b_scaled_0 = _mm256_fmaddsub_pd(b_scaled_0, a_real, b_perm_0); - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) - // beta_real*R(c[1][0]) beta_real*I(c[1][0]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); - } + c_element_2 = _mm256_castpd256_pd128(b_scaled_0); + c_element_3 = _mm256_extractf128_pd(b_scaled_0, 0x01); + + // Clearing out the upper lanes of 256 bit registers to avoid + // the transition penalty + _mm256_zeroupper(); - if(beta_imag != 0.0) + // Scaling with beta, according to its type. + switch( beta_mul_type ) { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); + case BLIS_MUL_ZERO : + break; - ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) - // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); - } - /* - The scaling has been done sequentially as follows: - - If alpha_real is not 0, it is used for scaling A - - If alpha_imag is not 0, it is used for scaling A using permutation - and selective negation, after loading - - If beta_real is not 0, is is used for scaling C - - If beta_imag is not 0, it is used for scaling C using permutation - and selective negation, after loading + case BLIS_MUL_ONE : + // Load C and add with the corresponding scratch register. 
+ b_element_0 = _mm_loadu_pd((double const*)(temp_cij)); + c_element_0 = _mm_add_pd(c_element_0, b_element_0); - The results are accumalated in accordance to the non zero scalar values, - and similar approach is followed in fringe cases - */ + b_element_1 = _mm_loadu_pd((double const*)(temp_cij + ldc)); + c_element_1 = _mm_add_pd(c_element_1, b_element_1); - _mm256_storeu_pd((double *)(temp_c), ymm3); + b_element_0 = _mm_loadu_pd((double const*)(temp_cij + ldc*2)); + c_element_2 = _mm_add_pd(c_element_2, b_element_0); - temp_c+=2; - temp_a+=2; + b_element_1 = _mm_loadu_pd((double const*)(temp_cij + ldc*3)); + c_element_3 = _mm_add_pd(c_element_3, b_element_1); + break; - m_rem -= 2; - } + default : + // Broadcast beta real and imaginary part and scale with C. + beta_real_reg = _mm_loaddup_pd((double const*)beta); + beta_imag_reg = _mm_loaddup_pd((double const*)beta + 1); - if(m_rem==1) - { + // Load C onto registers + b_element_0 = _mm_loadu_pd((double const*)(temp_cij)); + b_element_1 = _mm_loadu_pd((double const*)(temp_cij + ldc)); - xmm5 = _mm_setzero_pd(); - ymm3 = _mm256_setzero_pd(); + // Shuffle for the compute with imgarinary part scaling + c_perm_0 = _mm_shuffle_pd(b_element_0, b_element_0, 0x01); + c_perm_1 = _mm_shuffle_pd(b_element_1, b_element_1, 0x01); - xmm5 = _mm_loadu_pd((double const*)(temp_a)); - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0); + c_perm_0 = _mm_mul_pd(beta_imag_reg, c_perm_0); + c_perm_1 = _mm_mul_pd(beta_imag_reg, c_perm_1); - ymm13 = ymm0; - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + b_element_0 = _mm_mul_pd(beta_real_reg, b_element_0); + b_element_1 = _mm_mul_pd(beta_real_reg, b_element_1); - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + // Compute beta-scaled C + b_element_0 = _mm_addsub_pd(b_element_0, c_perm_0); + b_element_1 = _mm_addsub_pd(b_element_1, c_perm_1); - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); + // Add to intermediate reg storing alpha*A*B + c_element_0 = _mm_add_pd(b_element_0, c_element_0); + c_element_1 = _mm_add_pd(b_element_1, c_element_1); - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + // Load C onto registers + b_element_0 = _mm_loadu_pd((double const*)(temp_cij + ldc*2)); + b_element_1 = _mm_loadu_pd((double const*)(temp_cij + ldc*3)); + + // Shuffle for the compute with imgarinary part scaling + c_perm_0 = _mm_shuffle_pd(b_element_0, b_element_0, 0x01); + c_perm_1 = _mm_shuffle_pd(b_element_1, b_element_1, 0x01); + + c_perm_0 = _mm_mul_pd(beta_imag_reg, c_perm_0); + c_perm_1 = _mm_mul_pd(beta_imag_reg, c_perm_1); + + b_element_0 = _mm_mul_pd(beta_real_reg, b_element_0); + b_element_1 = _mm_mul_pd(beta_real_reg, b_element_1); + + // Compute beta-scaled C + b_element_0 = _mm_addsub_pd(b_element_0, c_perm_0); + b_element_1 = _mm_addsub_pd(b_element_1, c_perm_1); + + // Add to intermediate reg storing alpha*A*B + c_element_2 = _mm_add_pd(b_element_0, c_element_2); + c_element_3 = _mm_add_pd(b_element_1, c_element_3); + } + + // Storing the result in C. 
+ _mm_storeu_pd((double *)(temp_cij), c_element_0); + _mm_storeu_pd((double *)(temp_cij + ldc), c_element_1); + _mm_storeu_pd((double *)(temp_cij + ldc*2), c_element_2); + _mm_storeu_pd((double *)(temp_cij + ldc*3), c_element_3); + + // We need to restore the upper lanes of the registers b_vec_0, b_vec_1, + // b_vec_2 and b_vec_3 + // They need to contain the alpha scaled B, to be reused in the main loop for m + b_element_0 = _mm256_castpd256_pd128(b_vec_0); + b_element_1 = _mm256_castpd256_pd128(b_vec_1); + b_vec_0 = _mm256_insertf128_pd(b_vec_0, b_element_0, 0x01); + b_vec_1 = _mm256_insertf128_pd(b_vec_1, b_element_1, 0x01); + + b_element_0 = _mm256_castpd256_pd128(b_vec_2); + b_element_1 = _mm256_castpd256_pd128(b_vec_3); + b_vec_2 = _mm256_insertf128_pd(b_vec_2, b_element_0, 0x01); + b_vec_3 = _mm256_insertf128_pd(b_vec_3, b_element_1, 0x01); + + // Adjusting the addresses of A and C for the next block. + temp_cij += 1; + temp_ai += 1; + + m_rem -= 1; + } - if(beta_real != 0.0) + if( m_rem >= 2 ) + { + // Scratch registers. + __m256d c_vec_0, c_vec_2, c_vec_4, c_vec_6; + + a_vec_0 = _mm256_setzero_pd(); + a_vec_1 = _mm256_setzero_pd(); + bdcst_0 = _mm256_setzero_pd(); + bdcst_1 = _mm256_setzero_pd(); + c_vec_0 = _mm256_setzero_pd(); + c_vec_2 = _mm256_setzero_pd(); + c_vec_4 = _mm256_setzero_pd(); + c_vec_6 = _mm256_setzero_pd(); + + // Loading a vector from A with 2 elements. + a_vec_0 = _mm256_loadu_pd((double const *)(temp_ai)); + + a_vec_0 = _mm256_permute_pd(a_vec_0, 0x5); + + // Scaling with imaginary components of elements from B. + bdcst_0 = _mm256_unpackhi_pd(b_vec_0, b_vec_0); + bdcst_1 = _mm256_unpackhi_pd(b_vec_1, b_vec_1); + c_vec_0 = _mm256_mul_pd(a_vec_0, bdcst_0); + c_vec_2 = _mm256_mul_pd(a_vec_0, bdcst_1); + + bdcst_0 = _mm256_unpackhi_pd(b_vec_2, b_vec_2); + bdcst_1 = _mm256_unpackhi_pd(b_vec_3, b_vec_3); + c_vec_4 = _mm256_mul_pd(a_vec_0, bdcst_0); + c_vec_6 = _mm256_mul_pd(a_vec_0, bdcst_1); + + a_vec_0 = _mm256_permute_pd(a_vec_0, 0x5); + + // Scaling with real components of elements from B. + bdcst_0 = _mm256_unpacklo_pd(b_vec_0, b_vec_0); + bdcst_1 = _mm256_unpacklo_pd(b_vec_1, b_vec_1); + c_vec_0 = _mm256_fmaddsub_pd(a_vec_0, bdcst_0, c_vec_0); + c_vec_2 = _mm256_fmaddsub_pd(a_vec_0, bdcst_1, c_vec_2); + + bdcst_0 = _mm256_unpacklo_pd(b_vec_2, b_vec_2); + bdcst_1 = _mm256_unpacklo_pd(b_vec_3, b_vec_3); + c_vec_4 = _mm256_fmaddsub_pd(a_vec_0, bdcst_0, c_vec_4); + c_vec_6 = _mm256_fmaddsub_pd(a_vec_0, bdcst_1, c_vec_6); + + // Scaling with beta, according to its type. + switch( beta_mul_type ) { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); + case BLIS_MUL_ZERO : + break; + + case BLIS_MUL_ONE : + // Load C and add with the corresponding scratch register. + a_vec_0 = _mm256_loadu_pd((double const*)(temp_cij)); + c_vec_0 = _mm256_add_pd(c_vec_0, a_vec_0); + + a_vec_0 = _mm256_loadu_pd((double const*)(temp_cij + ldc)); + c_vec_2 = _mm256_add_pd(c_vec_2, a_vec_0); + + a_vec_0 = _mm256_loadu_pd((double const*)(temp_cij + ldc*2)); + c_vec_4 = _mm256_add_pd(c_vec_4, a_vec_0); + + a_vec_0 = _mm256_loadu_pd((double const*)(temp_cij + ldc*3)); + c_vec_6 = _mm256_add_pd(c_vec_6, a_vec_0); + break; + + default : + // Broadcast beta and redirect to the beta scaling macro. 
+ bdcst_0 = _mm256_broadcast_sd((double const*)(&beta_real)); + bdcst_1 = _mm256_broadcast_sd((double const*)(&beta_imag)); + + BETA_SCALING_C_FRINGE(c_vec_0, temp_cij); + BETA_SCALING_C_FRINGE(c_vec_2, temp_cij + ldc); + BETA_SCALING_C_FRINGE(c_vec_4, temp_cij + ldc*2); + BETA_SCALING_C_FRINGE(c_vec_6, temp_cij + ldc*3); - xmm5 = _mm_loadu_pd((double const*)(temp_c)); - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0); - //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) - SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); } - if(beta_imag != 0.0) + // Storing the result in C. + _mm256_storeu_pd((double *)(temp_cij), c_vec_0); + _mm256_storeu_pd((double *)(temp_cij + ldc), c_vec_2); + _mm256_storeu_pd((double *)(temp_cij + ldc*2), c_vec_4); + _mm256_storeu_pd((double *)(temp_cij + ldc*3), c_vec_6); + + // Adjusting the addresses of A and C for the next block. + temp_cij += 2; + temp_ai += 2; + + m_rem -= 2; + } + + // Main loop along M dimension. + for( dim_t i = 0; i < m_iter; i++ ) + { + // Scratch registers. + __m256d c_vec_0, c_vec_1, c_vec_2, c_vec_3; + __m256d c_vec_4, c_vec_5, c_vec_6, c_vec_7; + + a_vec_0 = _mm256_setzero_pd(); + a_vec_1 = _mm256_setzero_pd(); + bdcst_0 = _mm256_setzero_pd(); + bdcst_1 = _mm256_setzero_pd(); + c_vec_0 = _mm256_setzero_pd(); + c_vec_1 = _mm256_setzero_pd(); + c_vec_2 = _mm256_setzero_pd(); + c_vec_3 = _mm256_setzero_pd(); + c_vec_4 = _mm256_setzero_pd(); + c_vec_5 = _mm256_setzero_pd(); + c_vec_6 = _mm256_setzero_pd(); + c_vec_7 = _mm256_setzero_pd(); + + _mm_prefetch((char const*)(temp_cij), _MM_HINT_T0); + _mm_prefetch((char const*)(temp_cij + ldc), _MM_HINT_T0); + _mm_prefetch((char const*)(temp_cij + ldc*2), _MM_HINT_T0); + _mm_prefetch((char const*)(temp_cij + ldc*3), _MM_HINT_T0); + + // Loading vectors from A with Z_MR elements in total. + a_vec_0 = _mm256_loadu_pd((double const *)(temp_ai)); + a_vec_1 = _mm256_loadu_pd((double const *)(temp_ai + 2)); + + a_vec_0 = _mm256_permute_pd(a_vec_0, 0x5); + a_vec_1 = _mm256_permute_pd(a_vec_1, 0x5); + + // Scaling with imaginary components of elements from B. + bdcst_0 = _mm256_unpackhi_pd(b_vec_0, b_vec_0); + bdcst_1 = _mm256_unpackhi_pd(b_vec_1, b_vec_1); + c_vec_0 = _mm256_mul_pd(a_vec_0, bdcst_0); + c_vec_1 = _mm256_mul_pd(a_vec_1, bdcst_0); + c_vec_2 = _mm256_mul_pd(a_vec_0, bdcst_1); + c_vec_3 = _mm256_mul_pd(a_vec_1, bdcst_1); + + bdcst_0 = _mm256_unpackhi_pd(b_vec_2, b_vec_2); + bdcst_1 = _mm256_unpackhi_pd(b_vec_3, b_vec_3); + c_vec_4 = _mm256_mul_pd(a_vec_0, bdcst_0); + c_vec_5 = _mm256_mul_pd(a_vec_1, bdcst_0); + c_vec_6 = _mm256_mul_pd(a_vec_0, bdcst_1); + c_vec_7 = _mm256_mul_pd(a_vec_1, bdcst_1); + + a_vec_0 = _mm256_permute_pd(a_vec_0, 0x5); + a_vec_1 = _mm256_permute_pd(a_vec_1, 0x5); + + // Scaling with real components of elements from B. + bdcst_0 = _mm256_unpacklo_pd(b_vec_0, b_vec_0); + bdcst_1 = _mm256_unpacklo_pd(b_vec_1, b_vec_1); + c_vec_0 = _mm256_fmaddsub_pd(a_vec_0, bdcst_0, c_vec_0); + c_vec_1 = _mm256_fmaddsub_pd(a_vec_1, bdcst_0, c_vec_1); + c_vec_2 = _mm256_fmaddsub_pd(a_vec_0, bdcst_1, c_vec_2); + c_vec_3 = _mm256_fmaddsub_pd(a_vec_1, bdcst_1, c_vec_3); + + bdcst_0 = _mm256_unpacklo_pd(b_vec_2, b_vec_2); + bdcst_1 = _mm256_unpacklo_pd(b_vec_3, b_vec_3); + c_vec_4 = _mm256_fmaddsub_pd(a_vec_0, bdcst_0, c_vec_4); + c_vec_5 = _mm256_fmaddsub_pd(a_vec_1, bdcst_0, c_vec_5); + c_vec_6 = _mm256_fmaddsub_pd(a_vec_0, bdcst_1, c_vec_6); + c_vec_7 = _mm256_fmaddsub_pd(a_vec_1, bdcst_1, c_vec_7); + + // Scaling with beta, according to its type. 
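+ // (At this point c_vec_0..c_vec_7 hold only the alpha*A*B contribution,
+ // two 256-bit registers per column of the Z_MR x Z_NR tile of C; the
+ // switch below applies the beta*C update before the results are stored.)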
+ switch( beta_mul_type ) { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); + case BLIS_MUL_ZERO : + break; + + case BLIS_MUL_ONE : + // Load C and add with the corresponding scratch register. + a_vec_0 = _mm256_loadu_pd((double const*)(temp_cij)); + a_vec_1 = _mm256_loadu_pd((double const*)(temp_cij + 2)); + c_vec_0 = _mm256_add_pd(c_vec_0, a_vec_0); + c_vec_1 = _mm256_add_pd(c_vec_1, a_vec_1); + + a_vec_0 = _mm256_loadu_pd((double const*)(temp_cij + ldc)); + a_vec_1 = _mm256_loadu_pd((double const*)(temp_cij + ldc + 2)); + c_vec_2 = _mm256_add_pd(c_vec_2, a_vec_0); + c_vec_3 = _mm256_add_pd(c_vec_3, a_vec_1); + + a_vec_0 = _mm256_loadu_pd((double const*)(temp_cij + ldc*2)); + a_vec_1 = _mm256_loadu_pd((double const*)(temp_cij + ldc*2 + 2)); + c_vec_4 = _mm256_add_pd(c_vec_4, a_vec_0); + c_vec_5 = _mm256_add_pd(c_vec_5, a_vec_1); + + a_vec_0 = _mm256_loadu_pd((double const*)(temp_cij + ldc*3)); + a_vec_1 = _mm256_loadu_pd((double const*)(temp_cij + ldc*3 + 2)); + c_vec_6 = _mm256_add_pd(c_vec_6, a_vec_0); + c_vec_7 = _mm256_add_pd(c_vec_7, a_vec_1); + break; + + default : + // Broadcast beta and redirect to the beta scaling macro. + bdcst_0 = _mm256_broadcast_sd((double const*)(&beta_real)); + bdcst_1 = _mm256_broadcast_sd((double const*)(&beta_imag)); + + BETA_SCALING_C_MAIN(c_vec_0, c_vec_1, temp_cij); + BETA_SCALING_C_MAIN(c_vec_2, c_vec_3, temp_cij + ldc); + BETA_SCALING_C_MAIN(c_vec_4, c_vec_5, temp_cij + ldc*2); + BETA_SCALING_C_MAIN(c_vec_6, c_vec_7, temp_cij + ldc*3); - xmm5 = _mm_loadu_pd((double const*)(temp_c)); - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0); - //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) - SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); } - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(temp_c), xmm5); + // Storing the result in C. + _mm256_storeu_pd((double *)(temp_cij), c_vec_0); + _mm256_storeu_pd((double *)(temp_cij + 2), c_vec_1); + + _mm256_storeu_pd((double *)(temp_cij + ldc), c_vec_2); + _mm256_storeu_pd((double *)(temp_cij + ldc + 2), c_vec_3); + + _mm256_storeu_pd((double *)(temp_cij + ldc*2), c_vec_4); + _mm256_storeu_pd((double *)(temp_cij + ldc*2 + 2), c_vec_5); + + _mm256_storeu_pd((double *)(temp_cij + ldc*3), c_vec_6); + _mm256_storeu_pd((double *)(temp_cij + ldc*3 + 2), c_vec_7); + // Adjusting the addresses of A and C for the next iteration. + temp_cij += Z_MR; + temp_ai += Z_MR; } } diff --git a/kernels/zen/3/bli_zgemm_zen_2x6.c b/kernels/zen/3/bli_zgemm_zen_2x6.c new file mode 100644 index 0000000000..e29537bda8 --- /dev/null +++ b/kernels/zen/3/bli_zgemm_zen_2x6.c @@ -0,0 +1,652 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "bli_x86_asm_macros.h" + +#define A_L1_PREFETCH_DIST 4 +#define B_L1_PREFETCH_DIST 4 +#define TAIL_NITER 4 +#define PREFETCH_A +// #define PREFETCH_B +// #define PREFETCH_A_NEXT +// #define PREFETCH_B_NEXT +#define PREFETCH_C // perfetch c in middle loop over 2 iterations of k +// #define PREFETCH_C_SLOW // prefetch c in middle loop over 4 iterations of k +// #define PREFETCH_C_SIMPL // prefetch c before k loop + + +#ifdef PREFETCH_A + #define PREFETCH_A_L1(n, k) \ + PREFETCH(0, MEM(RAX, A_L1_PREFETCH_DIST*2*16 + (2*n+k)*(16))) +#else + #define PREFETCH_A_L1(n, k) +#endif + +#ifdef PREFETCH_B + #define PREFETCH_B_L1(n, k) \ + PREFETCH(0, MEM(RBX, B_L1_PREFETCH_DIST*6*16 + (6*n+(2*k))*(16))) +#else + #define PREFETCH_B_L1(n, k) +#endif + + +/* + * A Registers: YMM3 + * B Registers: YMM0, YMM1, YMM2 + * C Registers: YMM[4-15] + */ + +#define LOOP_ALIGN ALIGN32 + +#define SUBITER(n) \ +\ + PREFETCH_A_L1(n, 0)\ + VBROADCASTSD(YMM(3), MEM(RAX,(4*n+0)*8)) \ + VFMADD231PD(YMM(4), YMM(0), YMM(3)) \ + VFMADD231PD(YMM(5), YMM(1), YMM(3)) \ + VFMADD231PD(YMM(6), YMM(2), YMM(3)) \ + VBROADCASTSD(YMM(3), MEM(RAX,(4*n+1)*8)) \ + VFMADD231PD(YMM(7), YMM(0), YMM(3)) \ + VFMADD231PD(YMM(8), YMM(1), YMM(3)) \ + VFMADD231PD(YMM(9), YMM(2), YMM(3)) \ + \ + PREFETCH_B_L1(n, 0)\ + VBROADCASTSD(YMM( 3), MEM(RAX,(4*n+2)*8)) \ + VFMADD231PD(YMM(10), YMM(0), YMM(3)) \ + VFMADD231PD(YMM(11), YMM(1), YMM(3)) \ + VFMADD231PD(YMM(12), YMM(2), YMM(3)) \ + VBROADCASTSD(YMM( 3), MEM(RAX,(4*n+3)*8)) \ + VFMADD231PD(YMM(13), YMM(0), YMM(3)) \ + VFMADD231PD(YMM(14), YMM(1), YMM(3)) \ + VFMADD231PD(YMM(15), YMM(2), YMM(3)) \ + \ + VMOVAPD(YMM(0), MEM(RBX,(6*n+0)*16)) \ + VMOVAPD(YMM(1), MEM(RBX,(6*n+2)*16)) \ + VMOVAPD(YMM(2), MEM(RBX,(6*n+4)*16)) \ + \ + + +/**********************************************************/ +/* Kernel : bli_zgemm_zen_asm_2x6 */ +/* It performs C = C * beta + alpha * A * B */ +/* It is row preferred kernel, A and B are packed */ +/* C could be Row/Col/Gen Stored Matrix */ +/* Registers are allocated as below */ +/* Broadcast A : YMM(3) */ +/* load B : YMM(0, 1, 2) */ +/* Accumulation of B(real,imag)*Areal : */ +/* YMM(4-6,10-12) */ +/* Accumulation of B(real,imag)*Aimag : */ +/* YMM(7-9,13-15) */ +/* Computation of A(real,imag)*B(real,imag): */ +/* YMM(4-6,10-12) */ +/**********************************************************/ +void bli_zgemm_zen_asm_2x6( + dim_t k_, + dcomplex* restrict alpha, + dcomplex* restrict a, + dcomplex* restrict b, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + const int64_t k = k_; + /*rowstride * size of one dcomplex element*/ + const int64_t rs_c = 
rs_c_*16; + /*colstride * size of one dcomplex element*/ + const int64_t cs_c = cs_c_*16; + + + char beta_mul_type = BLIS_MUL_DEFAULT; + if(beta->imag == 0.0 && beta->real == 0.0 ) + { + beta_mul_type = BLIS_MUL_ZERO; + } + + BEGIN_ASM() + + VXORPD(YMM( 4), YMM( 4), YMM( 4)) + VXORPD(YMM( 5), YMM( 5), YMM( 5)) + VMOVAPD(YMM(6) , YMM(4)) + VMOVAPD(YMM(7) , YMM(4)) + VMOVAPD(YMM(8) , YMM(4)) + VMOVAPD(YMM(9) , YMM(4)) + VMOVAPD(YMM(10), YMM(4)) + VMOVAPD(YMM(11), YMM(4)) + VMOVAPD(YMM(12), YMM(4)) + VMOVAPD(YMM(13), YMM(4)) + VMOVAPD(YMM(14), YMM(4)) + VMOVAPD(YMM(15), YMM(4)) + + MOV(RSI, VAR(k)) //loop index + MOV(RAX, VAR(a)) //load address of a + MOV(RBX, VAR(b)) //load address of b + MOV(RCX, VAR(c)) //load address of c + + #ifdef PREFETCH_C + LEA(R9, MEM(RCX, 63)) // c for prefetch, first cache line + LEA(R8, MEM(RCX, 95)) // c for prefetch, second cache line + #endif + + + VMOVAPD(YMM(0), MEM(RBX, 0*16)) //pre-load b + VMOVAPD(YMM(1), MEM(RBX, 2*16)) //pre-load b + VMOVAPD(YMM(2), MEM(RBX, 4*16)) //pre-load b + LEA(RBX, MEM(RBX,6*16)) //adjust b for pre-load + + MOV(R12, VAR(rs_c)) + MOV(R10, VAR(cs_c)) + + #if defined PREFETCH_A_NEXT || defined PREFETCH_B_NEXT + MOV(RDI, RSI) + IMUL(RDI, IMM(16*2)) // rdi = k * 16*2 + #endif + + #ifdef PREFETCH_A_NEXT + LEA(R14, MEM(RAX, RDI, 1)) // r14(a_next) = A + (k*16*2) + #endif + + #ifdef PREFETCH_B_NEXT + IMUL(RDI, IMM(3)) // rdi = k * 16*6 + LEA(R15, MEM(RBX, RDI, 1)) // r15(b_next) = B + (k*16*6) + #endif + + + MOV(RDI, RSI) + AND(RSI, IMM(3)) + SAR(RDI, IMM(2)) + + /************************************************************/ + /* Operation: */ + /* SUBITER = (Ar, Ai)*(Br, Bi) = Ar*(Br, Bi) , Ai*(Br, Bi) */ + /* Prefetch_C_SIMPLE: */ + /* LOOP1: k/4 - TAIL_NITER */ + /* LOOP2: 0 */ + /* LOOP3: 0 */ + /* LOOP4: TAIL_NITER */ + /* PREFETCH_C_SLOW: */ + /* LOOP1: k/4 - TAIL_NITER - 4 */ + /* LOOP2: 2 */ + /* LOOP3: 2 */ + /* LOOP4: TAIL_NITER */ + /* PREFETCH_C: */ + /* LOOP1: k/4 - TAIL_NITER - 2 */ + /* LOOP2: 2 */ + /* LOOP3: 0 */ + /* LOOP4: TAIL_NITER */ + /************************************************************/ + #ifdef PREFETCH_C + #ifdef PREFETCH_C_SIMPLE + /* prefetch c over 1 iteration of k*/ + SUB(RDI, IMM(0+TAIL_NITER)) + #elif defined PREFETCH_C_SLOW + /* prefetch c over 4 iterations of k*/ + SUB(RDI, IMM(4+TAIL_NITER)) + #else + /* prefetch c over 2 iterations of k*/ + SUB(RDI, IMM(2+TAIL_NITER)) + #endif + #endif + JLE(K_PREFETCH_C) + + LOOP_ALIGN + LABEL(LOOP1) + #ifdef PREFETCH_A_NEXT + PREFETCH(1, MEM(R14)) + #endif + SUBITER(0) + #ifdef PREFETCH_B_NEXT + PREFETCH(1, MEM(R15)) + #endif + SUBITER(1) + #ifdef PREFETCH_A_NEXT + PREFETCH(1, MEM(R14, 64)) + #endif + SUB(RDI, IMM(1)) + SUBITER(2) + #ifdef PREFETCH_B_NEXT + PREFETCH(1, MEM(R15, 64)) + #endif + SUBITER(3) + + LEA(RAX, MEM(RAX,4*2*16)) + LEA(RBX, MEM(RBX,4*6*16)) + #ifdef PREFETCH_A_NEXT + LEA(R14, MEM(R14,128)) + #endif + #ifdef PREFETCH_B_NEXT + LEA(R15, MEM(R15,64)) + #endif + + JNZ(LOOP1) + + LABEL(K_PREFETCH_C) + +#ifdef PREFETCH_C +#if defined PREFETCH_C_SIMPLE + /*****************************/ + /* prefetch 2x6 of C at once */ + /*****************************/ + PREFETCH(0, MEM(R9)) + PREFETCH(0, MEM(R9, 31)) + PREFETCH(0, MEM(R9,R12, 1)) + PREFETCH(0, MEM(R9,R12, 1, 31)) + PREFETCH(0, MEM(R9,R12, 2)) + PREFETCH(0, MEM(R9,R12, 2, 31)) +#else + ADD(RDI, IMM(2)) + JLE(K_TAIL_NITER) + + LOOP_ALIGN + LABEL(LOOP2) + #ifdef PREFETCH_C + PREFETCH(0, MEM(R9)) + #endif + SUBITER(0) + SUBITER(1) + SUB(RDI, IMM(1)) + #ifndef PREFETCH_C_SLOW + 
/************************************************/ + /* if prefetch is being done over 2 iterations, */ + /* prefetch 2 cache lines per iteration */ + /* prefetch one row of C per iteration of Loop2 */ + /************************************************/ + PREFETCH(0, MEM(R9,31)) + #endif + SUBITER(2) + SUBITER(3) + + LEA(RAX, MEM(RAX,4*2*16)) + LEA(RBX, MEM(RBX,4*6*16)) + #ifdef PREFETCH_C + LEA(R9, MEM(R9,R12,1)) + #endif + JNZ(LOOP2) + + LABEL(K_TAIL_NITER) + + #ifdef PREFETCH_C_SLOW + ADD(RDI, IMM(2)) + JLE(K_TAIL_NITER_2) + + LOOP_ALIGN + LABEL(LOOP3) + #ifdef PREFETCH_C + PREFETCH(0, MEM(R8)) + #endif + SUBITER(0) + SUBITER(1) + SUB(RDI, IMM(1)) + SUBITER(2) + SUBITER(3) + + LEA(RAX, MEM(RAX,4*2*16)) + LEA(RBX, MEM(RBX,4*6*16)) + #ifdef PREFETCH_C + LEA(R8, MEM(R8,R12,1)) + #endif + JNZ(LOOP3) + LABEL(K_TAIL_NITER_2) + + #endif //PREFETCH_C_SLOW + +#endif //PREFETCH_C_SIMPLE + ADD(RDI, IMM(0+TAIL_NITER)) + JLE(TAIL) + + LOOP_ALIGN + LABEL(LOOP4) + + SUBITER(0) + SUBITER(1) + SUB(RDI, IMM(1)) + SUBITER(2) + SUBITER(3) + + LEA(RAX, MEM(RAX,4*2*16)) + LEA(RBX, MEM(RBX,4*6*16)) + + JNZ(LOOP4) + +#endif //PREFETCH_C + + LABEL(TAIL) + + TEST(RSI, RSI) + JZ(POSTACCUM) + + LOOP_ALIGN + LABEL(TAIL_LOOP) + + SUB(RSI, IMM(1)) + SUBITER(0) + LEA(RAX, MEM(RAX,2*16)) + LEA(RBX, MEM(RBX,6*16)) + + JNZ(TAIL_LOOP) + + LABEL(POSTACCUM) + + VPERMILPD(YMM( 7), YMM( 7), IMM(0x5)) + VPERMILPD(YMM( 8), YMM( 8), IMM(0x5)) + VPERMILPD(YMM( 9), YMM( 9), IMM(0x5)) + VPERMILPD(YMM(13), YMM(13), IMM(0x5)) + VPERMILPD(YMM(14), YMM(14), IMM(0x5)) + VPERMILPD(YMM(15), YMM(15), IMM(0x5)) + + VADDSUBPD(YMM(4), YMM(4), YMM(7)) + VADDSUBPD(YMM(5), YMM(5), YMM(8)) + VADDSUBPD(YMM(6), YMM(6), YMM(9)) + + VADDSUBPD(YMM(10), YMM(10), YMM(13)) + VADDSUBPD(YMM(11), YMM(11), YMM(14)) + VADDSUBPD(YMM(12), YMM(12), YMM(15)) + + /******************/ + /* scale by alpha */ + /******************/ + MOV(RAX, VAR(alpha)) + VBROADCASTSD(YMM(0), MEM(RAX)) + VBROADCASTSD(YMM(1), MEM(RAX, 8)) + + VPERMILPD(YMM(3), YMM(4), IMM(0X5)) + VMULPD(YMM(4), YMM(4), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(4), YMM(4), YMM(3)) + + VPERMILPD(YMM(3), YMM(5), IMM(0X5)) + VMULPD(YMM(5), YMM(5), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(5), YMM(5), YMM(3)) + + VPERMILPD(YMM(3), YMM(6), IMM(0X5)) + VMULPD(YMM(6), YMM(6), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(6), YMM(6), YMM(3)) + + // ROW 2 + VPERMILPD(YMM(3), YMM(10), IMM(0X5)) + VMULPD(YMM(10), YMM(10), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(10), YMM(10), YMM(3)) + + VPERMILPD(YMM(3), YMM(11), IMM(0X5)) + VMULPD(YMM(11), YMM(11), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(11), YMM(11), YMM(3)) + + VPERMILPD(YMM(3), YMM(12), IMM(0X5)) + VMULPD(YMM(12), YMM(12), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(12), YMM(12), YMM(3)) + + + MOV(RBX, VAR(beta)) + VBROADCASTSD(YMM(1), MEM(RBX)) + VBROADCASTSD(YMM(2), MEM(RBX, 8)) + + + MOV(AL, VAR(beta_mul_type)) + CMP(AL, IMM(0)) + JE(.ZBETAZERO) + + CMP(R10, IMM(16)) //CS == 1 IMPLIES ROW STORED + JNZ(.ZCOLSTORED) + + LABEL(.ZROWSTORED) + LEA(RDX, MEM(RCX, R12, 1)) + + // ROW 1 + VMOVUPD(YMM(0), MEM(RCX)) + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(0), YMM(0), YMM(3)) + VADDPD(YMM(0), YMM(0), YMM(4)) + VMOVUPD(MEM(RCX), YMM(0)) + + VMOVUPD(YMM(0), MEM(RCX, R10, 2)) + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(0), 
YMM(0), YMM(3)) + VADDPD(YMM(0), YMM(0), YMM(5)) + VMOVUPD(MEM(RCX, R10, 2), YMM(0)) + + VMOVUPD(YMM(0), MEM(RCX, R10, 4)) + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(0), YMM(0), YMM(3)) + VADDPD(YMM(0), YMM(0), YMM(6)) + VMOVUPD(MEM(RCX, R10, 4), YMM(0)) + + //ROW 2 + VMOVUPD(YMM(0), MEM(RDX)) + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(0), YMM(0), YMM(3)) + VADDPD(YMM(0), YMM(0), YMM(10)) + VMOVUPD(MEM(RDX), YMM(0)) + + VMOVUPD(YMM(0), MEM(RDX, R10, 2)) + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(0), YMM(0), YMM(3)) + VADDPD(YMM(0), YMM(0), YMM(11)) + VMOVUPD(MEM(RDX, R10, 2), YMM(0)) + + VMOVUPD(YMM(0), MEM(RDX, R10, 4)) + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(0), YMM(0), YMM(3)) + VADDPD(YMM(0), YMM(0), YMM(12)) + VMOVUPD(MEM(RDX, R10, 4), YMM(0)) + + JMP(.ZDONE) + + LABEL(.ZCOLSTORED) + LEA(RDX, MEM(RCX, R12, 1)) + LEA(RDI, MEM(, R10, 2)) + + VMOVUPD(XMM(0), MEM(RCX )) + VMOVUPD(XMM(3), MEM(RCX, R10, 1)) + VINSERTF128(YMM(0), YMM(0), XMM(3), IMM(0x1)) + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(0), YMM(0), YMM(3)) + VADDPD(YMM(0), YMM(0), YMM(4)) + VEXTRACTF128(XMM(3), YMM(0), IMM(0x1)) + VMOVUPD(MEM(RCX ), XMM(0)) + VMOVUPD(MEM(RCX, R10, 1), XMM(3)) + ADD(RCX, RDI) + + VMOVUPD(XMM(0), MEM(RCX )) + VMOVUPD(XMM(3), MEM(RCX, R10, 1)) + VINSERTF128(YMM(0), YMM(0), XMM(3), IMM(0x1)) + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(0), YMM(0), YMM(3)) + VADDPD(YMM(0), YMM(0), YMM(5)) + VEXTRACTF128(XMM(3), YMM(0), IMM(0x1)) + VMOVUPD(MEM(RCX ), XMM(0)) + VMOVUPD(MEM(RCX, R10, 1), XMM(3)) + ADD(RCX, RDI) + + VMOVUPD(XMM(0), MEM(RCX )) + VMOVUPD(XMM(3), MEM(RCX, R10, 1)) + VINSERTF128(YMM(0), YMM(0), XMM(3), IMM(0x1)) + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(0), YMM(0), YMM(3)) + VADDPD(YMM(0), YMM(0), YMM(6)) + VEXTRACTF128(XMM(3), YMM(0), IMM(0x1)) + VMOVUPD(MEM(RCX ), XMM(0)) + VMOVUPD(MEM(RCX, R10, 1), XMM(3)) + + + VMOVUPD(XMM(0), MEM(RDX )) + VMOVUPD(XMM(3), MEM(RDX, R10, 1)) + VINSERTF128(YMM(0), YMM(0), XMM(3), IMM(0x1)) + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(0), YMM(0), YMM(3)) + VADDPD(YMM(0), YMM(0), YMM(10)) + VEXTRACTF128(XMM(3), YMM(0), IMM(0x1)) + VMOVUPD(MEM(RDX ), XMM(0)) + VMOVUPD(MEM(RDX, R10, 1), XMM(3)) + ADD(RDX, RDI) + + VMOVUPD(XMM(0), MEM(RDX )) + VMOVUPD(XMM(3), MEM(RDX, R10, 1)) + VINSERTF128(YMM(0), YMM(0), XMM(3), IMM(0x1)) + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(0), YMM(0), YMM(3)) + VADDPD(YMM(0), YMM(0), YMM(11)) + VEXTRACTF128(XMM(3), YMM(0), IMM(0x1)) + VMOVUPD(MEM(RDX ), XMM(0)) + VMOVUPD(MEM(RDX, R10, 1), XMM(3)) + ADD(RDX, RDI) + + VMOVUPD(XMM(0), MEM(RDX )) + VMOVUPD(XMM(3), MEM(RDX, R10, 1)) + VINSERTF128(YMM(0), YMM(0), XMM(3), IMM(0x1)) + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(0), YMM(0), YMM(3)) + VADDPD(YMM(0), YMM(0), YMM(12)) + VEXTRACTF128(XMM(3), YMM(0), IMM(0x1)) + VMOVUPD(MEM(RDX ), 
XMM(0)) + VMOVUPD(MEM(RDX, R10, 1), XMM(3)) + ADD(RDX, RDI) + + + JMP(.ZDONE) + + LABEL(.ZBETAZERO) + CMP(R12, IMM(16)) + JNZ(.ZROWSTORBZ) + + LABEL(.ZCOLSTORBZ) + LEA(RDX, MEM(RCX, R12, 1)) + LEA(RDI, MEM(, R10, 2)) + + VEXTRACTF128(XMM(3), YMM(4), IMM(0x1)) + VMOVUPD(MEM(RCX ), XMM(4)) + VMOVUPD(MEM(RCX, R10, 1), XMM(3)) + ADD(RCX, RDI) + + VEXTRACTF128(XMM(3), YMM(5), IMM(0x1)) + VMOVUPD(MEM(RCX ), XMM(5)) + VMOVUPD(MEM(RCX, R10, 1), XMM(3)) + ADD(RCX, RDI) + + VEXTRACTF128(XMM(3), YMM(6), IMM(0x1)) + VMOVUPD(MEM(RCX ), XMM(6)) + VMOVUPD(MEM(RCX, R10, 1), XMM(3)) + + + VEXTRACTF128(XMM(3), YMM(10), IMM(0x1)) + VMOVUPD(MEM(RDX ), XMM(10)) + VMOVUPD(MEM(RDX, R10, 1), XMM(3)) + ADD(RDX, RDI) + + VEXTRACTF128(XMM(3), YMM(11), IMM(0x1)) + VMOVUPD(MEM(RDX ), XMM(11)) + VMOVUPD(MEM(RDX, R10, 1), XMM(3)) + ADD(RDX, RDI) + + VEXTRACTF128(XMM(3), YMM(12), IMM(0x1)) + VMOVUPD(MEM(RDX ), XMM(12)) + VMOVUPD(MEM(RDX, R10, 1), XMM(3)) + JMP(.ZDONE) + + + LABEL(.ZROWSTORBZ) + LEA(RDX, MEM(RCX, R12, 1)) + + VMOVUPD(MEM(RCX), YMM(4)) + VMOVUPD(MEM(RCX, R10, 2), YMM(5)) + VMOVUPD(MEM(RCX, R10, 4), YMM(6)) + + VMOVUPD(MEM(RDX), YMM(10)) + VMOVUPD(MEM(RDX, R10, 2), YMM(11)) + VMOVUPD(MEM(RDX, R10, 4), YMM(12)) + + + + LABEL(.ZDONE) + + + VZEROUPPER() + + END_ASM + ( + : // output operands (none) + : // input operands + [beta_mul_type] "m" (beta_mul_type), + [k] "m" (k), + [a] "m" (a), + [b] "m" (b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", + "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", + "ymm13", "ymm14", "ymm15", + "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", + "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", + "xmm13", "xmm14", "xmm15", + "memory" + ) +} diff --git a/kernels/zen/3/bli_zgemmtrsm_l_2x6.c b/kernels/zen/3/bli_zgemmtrsm_l_2x6.c new file mode 100644 index 0000000000..4a8d7c1b1d --- /dev/null +++ b/kernels/zen/3/bli_zgemmtrsm_l_2x6.c @@ -0,0 +1,559 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE UNIVERSITY + OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "bli_x86_asm_macros.h" + +#define A_L1_PREFETCH_DIST 4 +#define B_L1_PREFETCH_DIST 4 +#define TAIL_NITER 6 + +#define PREFETCH_A_L1(n, k) \ + PREFETCH(0, MEM(RAX, A_L1_PREFETCH_DIST*2*16 + (2*n+k)*(16))) +#define PREFETCH_B_L1(n, k) \ + PREFETCH(0, MEM(RBX, B_L1_PREFETCH_DIST*6*16 + (2*n+k)*(48))) + +/* + * A Registers: YMM3 + * B Registers: YMM0, YMM1, YMM2 + * C Registers: YMM[4-15] + */ + +#define LOOP_ALIGN ALIGN32 + +#define SUBITER(n) \ +\ + PREFETCH_A_L1(n, 0) \ + VBROADCASTSD(YMM( 3), MEM(RAX,(4*n+ 0)*8)) \ + VFMADD231PD(YMM( 4), YMM(0), YMM(3)) \ + VFMADD231PD(YMM( 5), YMM(1), YMM(3)) \ + VFMADD231PD(YMM( 6), YMM(2), YMM(3)) \ + VBROADCASTSD(YMM( 3), MEM(RAX,(4*n+ 1)*8)) \ + VFMADD231PD(YMM( 7), YMM(0), YMM(3)) \ + VFMADD231PD(YMM( 8), YMM(1), YMM(3)) \ + VFMADD231PD(YMM( 9), YMM(2), YMM(3)) \ + \ + PREFETCH_B_L1(n, 0) \ + VBROADCASTSD(YMM( 3), MEM(RAX,(4*n+ 2)*8)) \ + VFMADD231PD(YMM(10), YMM(0), YMM(3)) \ + VFMADD231PD(YMM(11), YMM(1), YMM(3)) \ + VFMADD231PD(YMM(12), YMM(2), YMM(3)) \ + VBROADCASTSD(YMM( 3), MEM(RAX,(4*n+ 3)*8)) \ + VFMADD231PD(YMM(13), YMM(0), YMM(3)) \ + VFMADD231PD(YMM(14), YMM(1), YMM(3)) \ + VFMADD231PD(YMM(15), YMM(2), YMM(3)) \ + \ + VMOVAPD(YMM(0), MEM(RBX,(6*n+0)*16)) \ + VMOVAPD(YMM(1), MEM(RBX,(6*n+2)*16)) \ + VMOVAPD(YMM(2), MEM(RBX,(6*n+4)*16)) \ + +// used for division of complex number if TRSM_PREINV is disabled +static double negative[4] __attribute__((aligned(64))) + = {-1, -1, -1, -1}; + +/**********************************************************/ +/* Kernel : bli_zgemmtrsm_l_zen_asm_2x6 */ +/* It performs A * X = alpha * B */ +/* It is row preferred kernel, A and B are packed */ +/* C could be Row/Col/Gen Stored Matrix */ +/* Registers are allocated as below */ +/* Broadcast A : YMM(3) */ +/* load B : YMM(0, 1, 2) */ +/* Accumulation of B(real,imag)*Areal : */ +/* YMM(4-6,10-12) */ +/* Accumulation of B(real,imag)*Aimag : */ +/* YMM(7-9,13-15) */ +/* Computation of A(real,imag)*B(real,imag): */ +/* YMM(4-6,10-12) */ +/**********************************************************/ +void bli_zgemmtrsm_l_zen_asm_2x6 + ( + dim_t k_, + dcomplex* restrict alpha, + dcomplex* restrict a10, + dcomplex* restrict a11, + dcomplex* restrict b01, + dcomplex* restrict b11, + dcomplex* restrict c11, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + const int64_t k = k_; + /*rowstride * size of one dcomplex element*/ + const int64_t rs_c = rs_c_*16; + /*colstride * size of one dcomplex element*/ + const int64_t cs_c = cs_c_*16; + const double* negPtr = &negative[0]; + + + BEGIN_ASM() + + VXORPD(YMM( 4), YMM( 4), YMM( 4)) + VXORPD(YMM( 5), YMM( 5), YMM( 5)) + VMOVAPD(YMM(6) , YMM(4)) + VMOVAPD(YMM(7) , YMM(4)) + VMOVAPD(YMM(8) , YMM(4)) + VMOVAPD(YMM(9) , YMM(4)) + VXORPD(YMM(10), YMM(10), YMM(10)) + VXORPD(YMM(11), YMM(11), YMM(11)) + VMOVAPD(YMM(12), YMM(4)) + VMOVAPD(YMM(13), YMM(4)) + VMOVAPD(YMM(14), YMM(4)) + VMOVAPD(YMM(15), YMM(4)) + + MOV(RSI, VAR(k)) //loop index + MOV(RAX, 
VAR(a10)) //load address of a + MOV(RBX, VAR(b01)) //load address of b + MOV(RCX, VAR(b11)) //load address of c + MOV(R9, VAR(c11)) // load C for prefetch + MOV(R11, VAR(negPtr)) + + VMOVAPD(YMM(0), MEM(RBX, 0*16)) //pre-load b + VMOVAPD(YMM(1), MEM(RBX, 2*16)) //pre-load b + VMOVAPD(YMM(2), MEM(RBX, 4*16)) //pre-load b + LEA(RBX, MEM(RBX,6*16)) //adjust b for pre-load + + MOV(R12, VAR(rs_c)) + MOV(R10, VAR(cs_c)) + + MOV(RDI, RSI) + AND(RSI, IMM(3)) + SAR(RDI, IMM(2)) + + /************************************************************/ + /* Operation: */ + /* SUBITER = (Ar, Ai)*(Br, Bi) = Ar*(Br, Bi) , Ai*(Br, Bi) */ + /* Loop counts: */ + /* LOOP1: k/4 - TAIL_NITER - 2 */ + /* LOOP2: 2 <--prefetch_c */ + /* LOOP4: TAIL_NITER */ + /************************************************************/ + SUB(RDI, IMM(2+TAIL_NITER)) + JLE(K_PREFETCH_C) + + LOOP_ALIGN + LABEL(LOOP1) + + SUBITER(0) + SUBITER(1) + SUB(RDI, IMM(1)) + SUBITER(2) + SUBITER(3) + + LEA(RAX, MEM(RAX,4*2*16)) + LEA(RBX, MEM(RBX,4*6*16)) + + + JNZ(LOOP1) + + LABEL(K_PREFETCH_C) + + ADD(RDI, IMM(2)) + JLE(K_TAIL_NITER) + + LOOP_ALIGN + LABEL(LOOP2) + + PREFETCH(0, MEM(R9)) + SUBITER(0) + SUBITER(1) + SUB(RDI, IMM(1)) + PREFETCH(0, MEM(R9,64)) + SUBITER(2) + SUBITER(3) + + LEA(RAX, MEM(RAX,4*2*16)) + LEA(RBX, MEM(RBX,4*6*16)) + LEA(R9, MEM(R9,R12,1)) + + JNZ(LOOP2) + + LABEL(K_TAIL_NITER) + + ADD(RDI, IMM(0+TAIL_NITER)) + JLE(TAIL) + + LOOP_ALIGN + LABEL(LOOP3) + + SUBITER(0) + SUBITER(1) + SUB(RDI, IMM(1)) + SUBITER(2) + SUBITER(3) + + LEA(RAX, MEM(RAX,4*2*16)) + LEA(RBX, MEM(RBX,4*6*16)) + + JNZ(LOOP3) + + LABEL(TAIL) + + TEST(RSI, RSI) + JZ(POSTACCUM) + + LOOP_ALIGN + LABEL(TAIL_LOOP) + + SUB(RSI, IMM(1)) + SUBITER(0) + LEA(RAX, MEM(RAX,2*16)) + LEA(RBX, MEM(RBX,6*16)) + + JNZ(TAIL_LOOP) + + LABEL(POSTACCUM) + + /**************************************************/ + /* Permute imag component register. 
Shuffle even */ + /* and odd components */ + /* SRC: YMM7 =(Ai0*Br0, Ai0*Bi0, Ai0*Br1, Ai0*Bi1)*/ + /* DST: YMM7 =(Ai0*Bi0, Ai0*Br0, Ai0*Bi1, Ai0*Br1)*/ + /**************************************************/ + VPERMILPD(YMM( 7), YMM( 7), IMM(0x5)) + VPERMILPD(YMM( 8), YMM( 8), IMM(0x5)) + VPERMILPD(YMM( 9), YMM( 9), IMM(0x5)) + VPERMILPD(YMM(13), YMM(13), IMM(0x5)) + VPERMILPD(YMM(14), YMM(14), IMM(0x5)) + VPERMILPD(YMM(15), YMM(15), IMM(0x5)) + + /***************************************************/ + /* SRC: YMM4 = (Ar0*Br0, Ar0*Bi0, Ar0*Br1, Ar0*Bi1)*/ + /* SRC: YMM7 = (Ai0*Bi0, Ai0*Br0, Ai0*Bi1, Ai0*Br1)*/ + /* DST: YMM4 =(Ar0*Br0-Ai0*Bi0, Ai0*Br0+Ar0*Bi0, */ + /* Ar0*Br1-Ai0*Bi1, Ai0*Br1+Ar0*Bi1) */ + /***************************************************/ + VADDSUBPD(YMM(4), YMM(4), YMM(7)) + VADDSUBPD(YMM(5), YMM(5), YMM(8)) + VADDSUBPD(YMM(6), YMM(6), YMM(9)) + VADDSUBPD(YMM(10), YMM(10), YMM(13)) + VADDSUBPD(YMM(11), YMM(11), YMM(14)) + VADDSUBPD(YMM(12), YMM(12), YMM(15)) + + /*Load alpha*/ + MOV(R9, VAR(alpha)) + VBROADCASTSD(YMM(7), MEM(R9)) + VBROADCASTSD(YMM(8), MEM(R9, 8)) + MOV(RDX, RCX) + MOV(RDI, IMM(6*16)) + + VMOVUPD(YMM(0), MEM(RDX, 0*16)) + VMOVUPD(YMM(1), MEM(RDX, 2*16)) + VMOVUPD(YMM(2), MEM(RDX, 4*16)) + ADD(RDX, RDI) + + /************************************************************************/ + /* gemm_output -= C * alpha */ + /* */ + /* Let C * alpha = (a + ib) * (c + id) */ + /* (a + ib) * (c + id) = (ac - bd) + i(ad + bc) */ + /* */ + /*Steps: */ + /* YMM(0) = a0, b0, a1, b1 */ + /* YMM(3) = b0, a0, b1, a1 */ + /* YMM(0) = a0*c0, b0*c0, a1*c1, b1*c1 */ + /* YMM(3) = b0*d0, a0*d0, b1*d1, a1*d1 */ + /* YMM(0) = (a0c0 - b0d0), (b0c0 + a0d0), (a1c1 - b1d1), (b1c1 + a1d1) */ + /************************************************************************/ + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(7)) // a*c, b*c + VMULPD(YMM(3), YMM(3), YMM(8)) // b*d, a*d + VADDSUBPD(YMM(0), YMM(0), YMM(3)) // ac - bd, bc + ad + VSUBPD(YMM(4), YMM(0), YMM(4)) // gemm_output - c * alpha + + VMOVUPD(YMM(0), MEM(RDX, 0*16)) + VPERMILPD(YMM(3), YMM(1), IMM(0x5)) + VMULPD(YMM(1), YMM(1), YMM(7)) + VMULPD(YMM(3), YMM(3), YMM(8)) + VADDSUBPD(YMM(1), YMM(1), YMM(3)) + VSUBPD(YMM(5), YMM(1), YMM(5)) + + VMOVUPD(YMM(1), MEM(RDX, 2*16)) + VPERMILPD(YMM(3), YMM(2), IMM(0x5)) + VMULPD(YMM(2), YMM(2), YMM(7)) + VMULPD(YMM(3), YMM(3), YMM(8)) + VADDSUBPD(YMM(2), YMM(2), YMM(3)) + VSUBPD(YMM(6), YMM(2), YMM(6)) + + VMOVUPD(YMM(2), MEM(RDX, 4*16)) + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(7)) + VMULPD(YMM(3), YMM(3), YMM(8)) + VADDSUBPD(YMM(0), YMM(0), YMM(3)) + VSUBPD(YMM(10), YMM(0), YMM(10)) + + VPERMILPD(YMM(3), YMM(1), IMM(0x5)) + VMULPD(YMM(1), YMM(1), YMM(7)) + VMULPD(YMM(3), YMM(3), YMM(8)) + VADDSUBPD(YMM(1), YMM(1), YMM(3)) + VSUBPD(YMM(11), YMM(1), YMM(11)) + + VPERMILPD(YMM(3), YMM(2), IMM(0x5)) + VMULPD(YMM(2), YMM(2), YMM(7)) + VMULPD(YMM(3), YMM(3), YMM(8)) + VADDSUBPD(YMM(2), YMM(2), YMM(3)) + VSUBPD(YMM(12), YMM(2), YMM(12)) + + + // REGION - TRSM + MOV(RAX, VAR(a11)) + //iteration 0 ------------------------------------- + VBROADCASTSD(YMM(0), MEM(RAX, (0+0*2)*16+0)) + VBROADCASTSD(YMM(1), MEM(RAX, (0+0*2)*16+8)) + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + /****************************************************/ + /* C = C * A11 */ + /* (a + ib) * (c + id) = (ac - bd) + i(ad + bc) */ + /****************************************************/ + VPERMILPD(YMM(3), YMM(4), IMM(0x5)) + VMULPD(YMM(4), YMM(4), YMM(0)) //a*c, b*c + VMULPD(YMM(3), 
YMM(3), YMM(1)) //b*d, a*d + VADDSUBPD(YMM(4), YMM(4), YMM(3)) // (ac - bd), (bc + ad) + + VPERMILPD(YMM(3), YMM(5), IMM(0x5)) + VMULPD(YMM(5), YMM(5), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(5), YMM(5), YMM(3)) + + VPERMILPD(YMM(3), YMM(6), IMM(0x5)) + VMULPD(YMM(6), YMM(6), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(6), YMM(6), YMM(3)) + #else + /************************************************************************/ + /* C = C / A11 */ + /* */ + /* Let C / A11 = (a + ib) / (c + id) = */ + /* ((ac + bd) / (c^2 + d^2)) + i ((bc - ad) / (c^2+d^2)) */ + /* */ + /*Steps: */ + /* YMM(4) = a0, b0, a1, b1 */ + /* YMM(3) = b0, a0, b1, a1 */ + /* YMM(4) = a0*c0, b0*c0, a1*c1, b1*c1 */ + /* YMM(3) = b0*d0, a0*d0, b1*d1, a1*d1 */ + /* YMM(3) = -b0*d0, -a0*d0, -b1*d1, -a1*d1 */ + /* YMM(4) = (a0c0 - b0d0), (b0c0 + a0d0), (a1c1 - b1d1), (b1c1 + a1d1) */ + /* YMM(4) = (a0c0 - b0d0) / (c^2 + d^2), (b0c0 + a0d0) / (c^2 + d^2), */ + /* (a1c1 - b1d1) / (c^2 + d^2), (b1c1 + a1d1 / (c^2 + d^2) */ + /************************************************************************/ + VMOVUPD(YMM(2), MEM(R11)) // -1 + VMULPD(YMM(9), YMM(0), YMM(0)) + VFMADD231PD(YMM(9), YMM(1), YMM(1)) + + VPERMILPD(YMM(3), YMM(4), IMM(0x5)) + VMULPD(YMM(4), YMM(4), YMM(0)) // a*c, b*c + VMULPD(YMM(3), YMM(3), YMM(1)) // b*d, a*d + VMULPD(YMM(3), YMM(3), YMM(2)) // -bd, -ad + VADDSUBPD(YMM(4), YMM(4), YMM(3)) // ac + bd, bc - ad + VDIVPD(YMM(4), YMM(4), YMM(9)) // (ac + bd) / (c^2 + d^2), (bc - ad) / (c^2 + d^2) + + VPERMILPD(YMM(3), YMM(5), IMM(0x5)) + VMULPD(YMM(5), YMM(5), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(5), YMM(5), YMM(3)) + VDIVPD(YMM(5), YMM(5), YMM(9)) + + VPERMILPD(YMM(3), YMM(6), IMM(0x5)) + VMULPD(YMM(6), YMM(6), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(6), YMM(6), YMM(3)) + VDIVPD(YMM(6), YMM(6), YMM(9)) + #endif + VMOVUPD(MEM(RCX, 0*16), YMM(4)) + VMOVUPD(MEM(RCX, 2*16), YMM(5)) + VMOVUPD(MEM(RCX, 4*16), YMM(6)) + ADD(RCX, RDI) + + //iteration 1 ------------------------------------- + + VBROADCASTSD(YMM(0), MEM(RAX, (1+0*2)*16+0)) + VBROADCASTSD(YMM(1), MEM(RAX, (1+0*2)*16+8)) + + VPERMILPD(YMM(3), YMM(4), IMM(0x5)) + VMULPD(YMM(2), YMM(4), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(7), YMM(2), YMM(3)) + + VPERMILPD(YMM(3), YMM(5), IMM(0x5)) + VMULPD(YMM(2), YMM(5), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(8), YMM(2), YMM(3)) + + VPERMILPD(YMM(3), YMM(6), IMM(0x5)) + VMULPD(YMM(2), YMM(6), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(9), YMM(2), YMM(3)) + + VSUBPD(YMM(10), YMM(10), YMM(7)) + VSUBPD(YMM(11), YMM(11), YMM(8)) + VSUBPD(YMM(12), YMM(12), YMM(9)) + + VBROADCASTSD(YMM(0), MEM(RAX, (1+1*2)*16+0)) + VBROADCASTSD(YMM(1), MEM(RAX, (1+1*2)*16+8)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VPERMILPD(YMM(3), YMM(10), IMM(0x5)) + VMULPD(YMM(10), YMM(10), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(10), YMM(10), YMM(3)) + + VPERMILPD(YMM(3), YMM(11), IMM(0x5)) + VMULPD(YMM(11), YMM(11), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(11), YMM(11), YMM(3)) + + VPERMILPD(YMM(3), YMM(12), IMM(0x5)) + VMULPD(YMM(12), YMM(12), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(12), YMM(12), YMM(3)) + #else + VMOVUPD(YMM(2), MEM(R11)) + VMULPD(YMM(9), YMM(0), YMM(0)) + VFMADD231PD(YMM(9), YMM(1), YMM(1)) + + VPERMILPD(YMM(3), YMM(10), IMM(0x5)) + VMULPD(YMM(10), YMM(10), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + 
VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(10), YMM(10), YMM(3)) + VDIVPD(YMM(10), YMM(10), YMM(9)) + + VPERMILPD(YMM(3), YMM(11), IMM(0x5)) + VMULPD(YMM(11), YMM(11), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(11), YMM(11), YMM(3)) + VDIVPD(YMM(11), YMM(11), YMM(9)) + + VPERMILPD(YMM(3), YMM(12), IMM(0x5)) + VMULPD(YMM(12), YMM(12), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(12), YMM(12), YMM(3)) + VDIVPD(YMM(12), YMM(12), YMM(9)) + #endif + VMOVUPD(MEM(RCX, 0*16), YMM(10)) + VMOVUPD(MEM(RCX, 2*16), YMM(11)) + VMOVUPD(MEM(RCX, 4*16), YMM(12)) + +// ENDREGION - TRSM + + MOV(RAX, R12) + MOV(RBX, R10) + MOV(RCX, VAR(c11)) + + CMP(RBX, IMM(16)) + JE(ROWUPDATE) + + LABEL(COLUPDATE) + LEA(RDX, MEM(RCX, R12, 1)) + LEA(RDI, MEM(, R10, 2)) + + VEXTRACTF128(XMM(3), YMM(4), IMM(0x1)) + VMOVUPD(MEM(RCX ), XMM(4)) + VMOVUPD(MEM(RCX, R10, 1), XMM(3)) + ADD(RCX, RDI) + + VEXTRACTF128(XMM(3), YMM(5), IMM(0x1)) + VMOVUPD(MEM(RCX ), XMM(5)) + VMOVUPD(MEM(RCX, R10, 1), XMM(3)) + ADD(RCX, RDI) + + VEXTRACTF128(XMM(3), YMM(6), IMM(0x1)) + VMOVUPD(MEM(RCX ), XMM(6)) + VMOVUPD(MEM(RCX, R10, 1), XMM(3)) + + + VEXTRACTF128(XMM(3), YMM(10), IMM(0x1)) + VMOVUPD(MEM(RDX ), XMM(10)) + VMOVUPD(MEM(RDX, R10, 1), XMM(3)) + ADD(RDX, RDI) + + VEXTRACTF128(XMM(3), YMM(11), IMM(0x1)) + VMOVUPD(MEM(RDX ), XMM(11)) + VMOVUPD(MEM(RDX, R10, 1), XMM(3)) + ADD(RDX, RDI) + + VEXTRACTF128(XMM(3), YMM(12), IMM(0x1)) + VMOVUPD(MEM(RDX ), XMM(12)) + VMOVUPD(MEM(RDX, R10, 1), XMM(3)) + JMP(END) + + + LABEL(ROWUPDATE) + LEA(RDX, MEM(RCX, R12, 1)) + + VMOVUPD(MEM(RCX ), YMM(4)) + VMOVUPD(MEM(RCX, R10, 2), YMM(5)) + VMOVUPD(MEM(RCX, R10, 4), YMM(6)) + + VMOVUPD(MEM(RDX ), YMM(10)) + VMOVUPD(MEM(RDX, R10, 2), YMM(11)) + VMOVUPD(MEM(RDX, R10, 4), YMM(12)) + JMP(END) + + LABEL(END) + + VZEROUPPER() + + + END_ASM + ( + : // output operands (none) + : // input operands + [a10] "m" (a10), + [k] "m" (k), + [b01] "m" (b01), + [a11] "m" (a11), + [b11] "m" (b11), + [c11] "m" (c11), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [alpha] "m" (alpha), + [negPtr] "m" (negPtr) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", + "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", + "ymm13", "ymm14", "ymm15", + "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", + "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", + "xmm13", "xmm14", "xmm15", + "memory" + ) +} diff --git a/kernels/zen/3/bli_zgemmtrsm_u_2x6.c b/kernels/zen/3/bli_zgemmtrsm_u_2x6.c new file mode 100644 index 0000000000..12b5a61d99 --- /dev/null +++ b/kernels/zen/3/bli_zgemmtrsm_u_2x6.c @@ -0,0 +1,561 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "bli_x86_asm_macros.h" + +#define A_L1_PREFETCH_DIST 4 +#define B_L1_PREFETCH_DIST 4 +#define TAIL_NITER 6 + +#define PREFETCH_A_L1(n, k) \ + PREFETCH(0, MEM(RAX, A_L1_PREFETCH_DIST*2*16 + (2*n+k)*(16))) +#define PREFETCH_B_L1(n, k) \ + PREFETCH(0, MEM(RBX, B_L1_PREFETCH_DIST*6*16 + (2*n+k)*(48))) + + +/* + * A Registers: YMM3 + * B Registers: YMM0, YMM1, YMM2 + * C Registers: YMM[4-15] + */ + +#define LOOP_ALIGN ALIGN32 + +#define SUBITER(n) \ +\ + PREFETCH_A_L1(n, 0) \ + VBROADCASTSD(YMM( 3), MEM(RAX,(4*n+ 0)*8)) \ + VFMADD231PD(YMM( 4), YMM(0), YMM(3)) \ + VFMADD231PD(YMM( 5), YMM(1), YMM(3)) \ + VFMADD231PD(YMM( 6), YMM(2), YMM(3)) \ + VBROADCASTSD(YMM( 3), MEM(RAX,(4*n+ 1)*8)) \ + VFMADD231PD(YMM( 7), YMM(0), YMM(3)) \ + VFMADD231PD(YMM( 8), YMM(1), YMM(3)) \ + VFMADD231PD(YMM( 9), YMM(2), YMM(3)) \ + \ + PREFETCH_B_L1(n, 0) \ + VBROADCASTSD(YMM( 3), MEM(RAX,(4*n+ 2)*8)) \ + VFMADD231PD(YMM(10), YMM(0), YMM(3)) \ + VFMADD231PD(YMM(11), YMM(1), YMM(3)) \ + VFMADD231PD(YMM(12), YMM(2), YMM(3)) \ + VBROADCASTSD(YMM( 3), MEM(RAX,(4*n+ 3)*8)) \ + VFMADD231PD(YMM(13), YMM(0), YMM(3)) \ + VFMADD231PD(YMM(14), YMM(1), YMM(3)) \ + VFMADD231PD(YMM(15), YMM(2), YMM(3)) \ + \ + VMOVAPD(YMM(0), MEM(RBX,(6*n+0)*16)) \ + VMOVAPD(YMM(1), MEM(RBX,(6*n+2)*16)) \ + VMOVAPD(YMM(2), MEM(RBX,(6*n+4)*16)) \ + +// used for division of complex number if TRSM_PREINV is disabled +static double negative[4] __attribute__((aligned(64))) + = {-1, -1, -1, -1}; + +/**********************************************************/ +/* Kernel : bli_zgemmtrsm_u_zen_asm_2x6 */ +/* It performs A * X = alpha * B */ +/* It is row preferred kernel, A and B are packed */ +/* C could be Row/Col/Gen Stored Matrix */ +/* Registers are allocated as below */ +/* Broadcast A : YMM(3) */ +/* load B : YMM(0, 1, 2) */ +/* Accumulation of B(real,imag)*Areal : */ +/* YMM(4-6,10-12) */ +/* Accumulation of B(real,imag)*Aimag : */ +/* YMM(7-9,13-15) */ +/* Computation of A(real,imag)*B(real,imag): */ +/* YMM(4-6,10-12) */ +/**********************************************************/ +void bli_zgemmtrsm_u_zen_asm_2x6 + ( + dim_t k_, + dcomplex* restrict alpha, + dcomplex* restrict a10, + dcomplex* restrict a11, + dcomplex* restrict b01, + dcomplex* restrict b11, + dcomplex* restrict c11, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + const int64_t k = k_; + /*rowstride * size of one dcomplex element*/ + const int64_t rs_c = rs_c_*16; + /*colstride * size of one dcomplex element*/ + const int64_t cs_c = cs_c_*16; + const 
double* negPtr = &negative[0]; + + + BEGIN_ASM() + + VXORPD(YMM( 4), YMM( 4), YMM( 4)) + VXORPD(YMM( 5), YMM( 5), YMM( 5)) + VMOVAPD(YMM(6) , YMM(4)) + VMOVAPD(YMM(7) , YMM(4)) + VMOVAPD(YMM(8) , YMM(4)) + VMOVAPD(YMM(9) , YMM(4)) + VXORPD(YMM(10), YMM(10), YMM(10)) + VXORPD(YMM(11), YMM(11), YMM(11)) + VMOVAPD(YMM(12), YMM(4)) + VMOVAPD(YMM(13), YMM(4)) + VMOVAPD(YMM(14), YMM(4)) + VMOVAPD(YMM(15), YMM(4)) + + MOV(RSI, VAR(k)) //loop index + MOV(RAX, VAR(a10)) //load address of a + MOV(RBX, VAR(b01)) //load address of b + MOV(RCX, VAR(b11)) //load address of c + MOV(R9, VAR(c11)) // laod C for prefetch + MOV(R11, VAR(negPtr)) + + // MOV(R9, RCX) + + VMOVAPD(YMM(0), MEM(RBX, 0*16)) //pre-load b + VMOVAPD(YMM(1), MEM(RBX, 2*16)) //pre-load b + VMOVAPD(YMM(2), MEM(RBX, 4*16)) //pre-load b + LEA(RBX, MEM(RBX,6*16)) //adjust b for pre-load + + MOV(R12, VAR(rs_c)) + MOV(R10, VAR(cs_c)) + + MOV(RDI, RSI) + AND(RSI, IMM(3)) + SAR(RDI, IMM(2)) + + /************************************************************/ + /* Operation: */ + /* SUBITER = (Ar, Ai)*(Br, Bi) = Ar*(Br, Bi) , Ai*(Br, Bi) */ + /* Loop counts: */ + /* LOOP1: k/4 - TAIL_NITER - 2 */ + /* LOOP2: 2 <--prefetch_c */ + /* LOOP4: TAIL_NITER */ + /************************************************************/ + SUB(RDI, IMM(2+TAIL_NITER)) + JLE(K_PREFETCH_C) + + LOOP_ALIGN + LABEL(LOOP1) + + SUBITER(0) + SUBITER(1) + SUB(RDI, IMM(1)) + SUBITER(2) + SUBITER(3) + + LEA(RAX, MEM(RAX,4*2*16)) + LEA(RBX, MEM(RBX,4*6*16)) + + + JNZ(LOOP1) + + LABEL(K_PREFETCH_C) + + ADD(RDI, IMM(2)) + JLE(K_TAIL_NITER) + + LOOP_ALIGN + LABEL(LOOP2) + + SUBITER(0) + SUBITER(1) + SUB(RDI, IMM(1)) + SUBITER(2) + SUBITER(3) + + LEA(RAX, MEM(RAX,4*2*16)) + LEA(RBX, MEM(RBX,4*6*16)) + + JNZ(LOOP2) + + LABEL(K_TAIL_NITER) + + ADD(RDI, IMM(0+TAIL_NITER)) + JLE(TAIL) + + LOOP_ALIGN + LABEL(LOOP3) + + SUBITER(0) + SUBITER(1) + SUB(RDI, IMM(1)) + SUBITER(2) + SUBITER(3) + + LEA(RAX, MEM(RAX,4*2*16)) + LEA(RBX, MEM(RBX,4*6*16)) + + JNZ(LOOP3) + + LABEL(TAIL) + + TEST(RSI, RSI) + JZ(POSTACCUM) + + LOOP_ALIGN + LABEL(TAIL_LOOP) + + SUB(RSI, IMM(1)) + SUBITER(0) + LEA(RAX, MEM(RAX,2*16)) + LEA(RBX, MEM(RBX,6*16)) + + JNZ(TAIL_LOOP) + + LABEL(POSTACCUM) + + /**************************************************/ + /* Permute imag component register. 
Shuffle even */ + /* and odd components */ + /* SRC: YMM7 =(Ai0*Br0, Ai0*Bi0, Ai0*Br1, Ai0*Bi1)*/ + /* DST: YMM7 =(Ai0*Bi0, Ai0*Br0, Ai0*Bi1, Ai0*Br1)*/ + /**************************************************/ + VPERMILPD(YMM( 7), YMM( 7), IMM(0x5)) + VPERMILPD(YMM( 8), YMM( 8), IMM(0x5)) + VPERMILPD(YMM( 9), YMM( 9), IMM(0x5)) + VPERMILPD(YMM(13), YMM(13), IMM(0x5)) + VPERMILPD(YMM(14), YMM(14), IMM(0x5)) + VPERMILPD(YMM(15), YMM(15), IMM(0x5)) + + /***************************************************/ + /* SRC: YMM4 = (Ar0*Br0, Ar0*Bi0, Ar0*Br1, Ar0*Bi1)*/ + /* SRC: YMM7 = (Ai0*Bi0, Ai0*Br0, Ai0*Bi1, Ai0*Br1)*/ + /* DST: YMM4 =(Ar0*Br0-Ai0*Bi0, Ai0*Br0+Ar0*Bi0, */ + /* Ar0*Br1-Ai0*Bi1, Ai0*Br1+Ar0*Bi1) */ + /***************************************************/ + VADDSUBPD(YMM(4), YMM(4), YMM(7)) + VADDSUBPD(YMM(5), YMM(5), YMM(8)) + VADDSUBPD(YMM(6), YMM(6), YMM(9)) + VADDSUBPD(YMM(10), YMM(10), YMM(13)) + VADDSUBPD(YMM(11), YMM(11), YMM(14)) + VADDSUBPD(YMM(12), YMM(12), YMM(15)) + + /*Load alpha*/ + MOV(R9, VAR(alpha)) + VBROADCASTSD(YMM(7), MEM(R9)) + VBROADCASTSD(YMM(8), MEM(R9, 8)) + MOV(RDX, RCX) + MOV(RDI, IMM(6*16)) + + VMOVUPD(YMM(0), MEM(RDX, 0*16)) + VMOVUPD(YMM(1), MEM(RDX, 2*16)) + VMOVUPD(YMM(2), MEM(RDX, 4*16)) + ADD(RDX, RDI) + + /************************************************************************/ + /* gemm_output -= C * alpha */ + /* */ + /* Let C * alpha = (a + ib) * (c + id) */ + /* (a + ib) * (c + id) = (ac - bd) + i(ad + bc) */ + /* */ + /*Steps: */ + /* YMM(0) = a0, b0, a1, b1 */ + /* YMM(3) = b0, a0, b1, a1 */ + /* YMM(0) = a0*c0, b0*c0, a1*c1, b1*c1 */ + /* YMM(3) = b0*d0, a0*d0, b1*d1, a1*d1 */ + /* YMM(0) = (a0c0 - b0d0), (b0c0 + a0d0), (a1c1 - b1d1), (b1c1 + a1d1) */ + /************************************************************************/ + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(7)) // a*c, b*c + VMULPD(YMM(3), YMM(3), YMM(8)) // b*d, a*d + VADDSUBPD(YMM(0), YMM(0), YMM(3)) // ac - bd, bc + ad + VSUBPD(YMM(4), YMM(0), YMM(4)) // gemm_output - c * alpha + + VMOVUPD(YMM(0), MEM(RDX, 0*16)) + VPERMILPD(YMM(3), YMM(1), IMM(0x5)) + VMULPD(YMM(1), YMM(1), YMM(7)) + VMULPD(YMM(3), YMM(3), YMM(8)) + VADDSUBPD(YMM(1), YMM(1), YMM(3)) + VSUBPD(YMM(5), YMM(1), YMM(5)) + + VMOVUPD(YMM(1), MEM(RDX, 2*16)) + VPERMILPD(YMM(3), YMM(2), IMM(0x5)) + VMULPD(YMM(2), YMM(2), YMM(7)) + VMULPD(YMM(3), YMM(3), YMM(8)) + VADDSUBPD(YMM(2), YMM(2), YMM(3)) + VSUBPD(YMM(6), YMM(2), YMM(6)) + + VMOVUPD(YMM(2), MEM(RDX, 4*16)) + VPERMILPD(YMM(3), YMM(0), IMM(0x5)) + VMULPD(YMM(0), YMM(0), YMM(7)) + VMULPD(YMM(3), YMM(3), YMM(8)) + VADDSUBPD(YMM(0), YMM(0), YMM(3)) + VSUBPD(YMM(10), YMM(0), YMM(10)) + + VPERMILPD(YMM(3), YMM(1), IMM(0x5)) + VMULPD(YMM(1), YMM(1), YMM(7)) + VMULPD(YMM(3), YMM(3), YMM(8)) + VADDSUBPD(YMM(1), YMM(1), YMM(3)) + VSUBPD(YMM(11), YMM(1), YMM(11)) + + VPERMILPD(YMM(3), YMM(2), IMM(0x5)) + VMULPD(YMM(2), YMM(2), YMM(7)) + VMULPD(YMM(3), YMM(3), YMM(8)) + VADDSUBPD(YMM(2), YMM(2), YMM(3)) + VSUBPD(YMM(12), YMM(2), YMM(12)) + + + MOV(RAX, VAR(a11)) + ADD(RCX, RDI) + // REGION - TRSM + //iteration 0 ------------------------------------- + VBROADCASTSD(YMM(0), MEM(RAX, (1+1*2)*16+0)) + VBROADCASTSD(YMM(1), MEM(RAX, (1+1*2)*16+8)) + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + /****************************************************/ + /* C = C * A11 */ + /* (a + ib) * (c + id) = (ac - bd) + i(ad + bc) */ + /****************************************************/ + VPERMILPD(YMM(3), YMM(10), IMM(0x5)) + VMULPD(YMM(10), YMM(10), YMM(0)) //a*c, b*c + 
VMULPD(YMM(3), YMM(3), YMM(1)) //b*d, a*d + VADDSUBPD(YMM(10), YMM(10), YMM(3)) // (ac - bd), (bc + ad) + + VPERMILPD(YMM(3), YMM(11), IMM(0x5)) + VMULPD(YMM(11), YMM(11), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(11), YMM(11), YMM(3)) + + VPERMILPD(YMM(3), YMM(12), IMM(0x5)) + VMULPD(YMM(12), YMM(12), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(12), YMM(12), YMM(3)) + #else + /************************************************************************/ + /* C = C / A11 */ + /* */ + /* Let C / A11 = (a + ib) / (c + id) = */ + /* ((ac + bd) / (c^2 + d^2)) + i ((bc - ad) / (c^2+d^2)) */ + /* */ + /*Steps: */ + /* YMM(10) = a0, b0, a1, b1 */ + /* YMM(3) = b0, a0, b1, a1 */ + /* YMM(10) = a0*c0, b0*c0, a1*c1, b1*c1 */ + /* YMM(3) = b0*d0, a0*d0, b1*d1, a1*d1 */ + /* YMM(3) = -b0*d0, -a0*d0, -b1*d1, -a1*d1 */ + /* YMM(10) = (a0c0 - b0d0), (b0c0 + a0d0), (a1c1 - b1d1), (b1c1 + a1d1) */ + /* YMM(10) = (a0c0 - b0d0) / (c^2 + d^2), (b0c0 + a0d0) / (c^2 + d^2), */ + /* (a1c1 - b1d1) / (c^2 + d^2), (b1c1 + a1d1 / (c^2 + d^2) */ + /************************************************************************/ + VMOVUPD(YMM(2), MEM(R11)) // -1 + VMULPD(YMM(9), YMM(0), YMM(0)) + VFMADD231PD(YMM(9), YMM(1), YMM(1)) + + VPERMILPD(YMM(3), YMM(10), IMM(0x5)) + VMULPD(YMM(10), YMM(10), YMM(0)) // a*c, b*c + VMULPD(YMM(3), YMM(3), YMM(1)) // b*d, a*d + VMULPD(YMM(3), YMM(3), YMM(2)) // -bd, -ad + VADDSUBPD(YMM(10), YMM(10), YMM(3)) // ac + bd, bc - ad + VDIVPD(YMM(10), YMM(10), YMM(9))//(ac + bd) / (c^2 + d^2),(bc - ad) / (c^2 + d^2) + + VPERMILPD(YMM(3), YMM(11), IMM(0x5)) + VMULPD(YMM(11), YMM(11), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(11), YMM(11), YMM(3)) + VDIVPD(YMM(11), YMM(11), YMM(9)) + + VPERMILPD(YMM(3), YMM(12), IMM(0x5)) + VMULPD(YMM(12), YMM(12), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(12), YMM(12), YMM(3)) + VDIVPD(YMM(12), YMM(12), YMM(9)) + + #endif + VMOVUPD(MEM(RCX, 0*16), YMM(10)) + VMOVUPD(MEM(RCX, 2*16), YMM(11)) + VMOVUPD(MEM(RCX, 4*16), YMM(12)) + SUB(RCX, RDI) + + //iteration 1 ------------------------------------- + + VBROADCASTSD(YMM(0), MEM(RAX, (0+1*2)*16+0)) + VBROADCASTSD(YMM(1), MEM(RAX, (0+1*2)*16+8)) + + VPERMILPD(YMM(3), YMM(10), IMM(0x5)) + VMULPD(YMM(2), YMM(10), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(7), YMM(2), YMM(3)) + + VPERMILPD(YMM(3), YMM(11), IMM(0x5)) + VMULPD(YMM(2), YMM(11), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(8), YMM(2), YMM(3)) + + VPERMILPD(YMM(3), YMM(12), IMM(0x5)) + VMULPD(YMM(2), YMM(12), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(9), YMM(2), YMM(3)) + + VSUBPD(YMM(4), YMM(4), YMM(7)) + VSUBPD(YMM(5), YMM(5), YMM(8)) + VSUBPD(YMM(6), YMM(6), YMM(9)) + + VBROADCASTSD(YMM(0), MEM(RAX, (0+0*2)*16+0)) + VBROADCASTSD(YMM(1), MEM(RAX, (0+0*2)*16+8)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VPERMILPD(YMM(3), YMM(4), IMM(0x5)) + VMULPD(YMM(4), YMM(4), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(4), YMM(4), YMM(3)) + + VPERMILPD(YMM(3), YMM(5), IMM(0x5)) + VMULPD(YMM(5), YMM(5), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(5), YMM(5), YMM(3)) + + VPERMILPD(YMM(3), YMM(6), IMM(0x5)) + VMULPD(YMM(6), YMM(6), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VADDSUBPD(YMM(6), YMM(6), YMM(3)) + #else + VMOVUPD(YMM(2), MEM(R11)) + VMULPD(YMM(9), YMM(0), YMM(0)) + VFMADD231PD(YMM(9), YMM(1), YMM(1)) + + VPERMILPD(YMM(3), YMM(4), IMM(0x5)) + VMULPD(YMM(4), YMM(4), YMM(0)) + 
VMULPD(YMM(3), YMM(3), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(4), YMM(4), YMM(3)) + VDIVPD(YMM(4), YMM(4), YMM(9)) + + VPERMILPD(YMM(3), YMM(5), IMM(0x5)) + VMULPD(YMM(5), YMM(5), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(5), YMM(5), YMM(3)) + VDIVPD(YMM(5), YMM(5), YMM(9)) + + VPERMILPD(YMM(3), YMM(6), IMM(0x5)) + VMULPD(YMM(6), YMM(6), YMM(0)) + VMULPD(YMM(3), YMM(3), YMM(1)) + VMULPD(YMM(3), YMM(3), YMM(2)) + VADDSUBPD(YMM(6), YMM(6), YMM(3)) + VDIVPD(YMM(6), YMM(6), YMM(9)) + #endif + VMOVUPD(MEM(RCX, 0*16), YMM(4)) + VMOVUPD(MEM(RCX, 2*16), YMM(5)) + VMOVUPD(MEM(RCX, 4*16), YMM(6)) + +// ENDREGION - TRSM + + MOV(RAX, R12) + MOV(RBX, R10) + MOV(RCX, VAR(c11)) + + CMP(RBX, IMM(16)) + JE(ROWUPDATE) + + LABEL(COLUPDATE) + LEA(RDX, MEM(RCX, R12, 1)) + LEA(RDI, MEM(, R10, 2)) + + VEXTRACTF128(XMM(3), YMM(4), IMM(0x1)) + VMOVUPD(MEM(RCX ), XMM(4)) + VMOVUPD(MEM(RCX, R10, 1), XMM(3)) + ADD(RCX, RDI) + + VEXTRACTF128(XMM(3), YMM(5), IMM(0x1)) + VMOVUPD(MEM(RCX ), XMM(5)) + VMOVUPD(MEM(RCX, R10, 1), XMM(3)) + ADD(RCX, RDI) + + VEXTRACTF128(XMM(3), YMM(6), IMM(0x1)) + VMOVUPD(MEM(RCX ), XMM(6)) + VMOVUPD(MEM(RCX, R10, 1), XMM(3)) + + + VEXTRACTF128(XMM(3), YMM(10), IMM(0x1)) + VMOVUPD(MEM(RDX ), XMM(10)) + VMOVUPD(MEM(RDX, R10, 1), XMM(3)) + ADD(RDX, RDI) + + VEXTRACTF128(XMM(3), YMM(11), IMM(0x1)) + VMOVUPD(MEM(RDX ), XMM(11)) + VMOVUPD(MEM(RDX, R10, 1), XMM(3)) + ADD(RDX, RDI) + + VEXTRACTF128(XMM(3), YMM(12), IMM(0x1)) + VMOVUPD(MEM(RDX ), XMM(12)) + VMOVUPD(MEM(RDX, R10, 1), XMM(3)) + JMP(END) + + + LABEL(ROWUPDATE) + LEA(RDX, MEM(RCX, R12, 1)) + + VMOVUPD(MEM(RCX ), YMM(4)) + VMOVUPD(MEM(RCX, R10, 2), YMM(5)) + VMOVUPD(MEM(RCX, R10, 4), YMM(6)) + + VMOVUPD(MEM(RDX ), YMM(10)) + VMOVUPD(MEM(RDX, R10, 2), YMM(11)) + VMOVUPD(MEM(RDX, R10, 4), YMM(12)) + JMP(END) + + LABEL(END) + + VZEROUPPER() + + + END_ASM + ( + : // output operands (none) + : // input operands + [a10] "m" (a10), + [k] "m" (k), + [b01] "m" (b01), + [a11] "m" (a11), + [b11] "m" (b11), + [c11] "m" (c11), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [alpha] "m" (alpha), + [negPtr] "m" (negPtr) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", + "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", + "ymm13", "ymm14", "ymm15", + "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", + "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", + "xmm13", "xmm14", "xmm15", + "memory" + ) +} diff --git a/kernels/zen/3/sup/CMakeLists.txt b/kernels/zen/3/sup/CMakeLists.txt deleted file mode 100644 index 57f3ee01ff..0000000000 --- a/kernels/zen/3/sup/CMakeLists.txt +++ /dev/null @@ -1,24 +0,0 @@ -##Copyright (C) 2020-2023, Advanced Micro Devices, Inc. 
All rights reserved.## - -add_library(zen_3_sup - OBJECT -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_zen_asm_s6x16.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_zen_asm_s6x16m.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_zen_asm_s6x16n.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_zen_asm_z3x4.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_zen_asm_z3x4m.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_zen_asm_z3x4n.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_zen_asm_c3x8.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_zen_asm_c3x8m.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_zen_asm_c3x8n.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_zen_asm_s6x16.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_zen_asm_s6x16m.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_zen_asm_s6x16n.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_zen_asm_z3x4.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_zen_asm_z3x4m.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_zen_asm_z3x4n.c - ) -target_compile_options(zen_3_sup PRIVATE /arch:AVX2) -if(BUILD_SHARED_LIBS) - target_compile_definitions(zen_3_sup PUBLIC -DBLIS_IS_BUILDING_LIBRARY) -endif() diff --git a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c index 3c47a910bb..a5dafcfcc3 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c +++ b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c @@ -4,7 +4,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16m.c b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16m.c index 6d1d001b50..8e84a93c59 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16m.c +++ b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16n.c b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16n.c index 6b84594e39..dbee5f43d3 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16n.c +++ b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16n.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4.c b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4.c index d07ee3ec07..6597742a9d 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4.c +++ b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4.c @@ -3,7 +3,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4m.c b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4m.c index b8243a04ed..87f6ad122a 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4m.c +++ b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4m.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4n.c b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4n.c index 8223e756f3..be8f493d36 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4n.c +++ b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_z3x4n.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8.c index 386c2ca8f0..3ba935e806 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8m.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8m.c index f92b1cc17b..a750143399 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8m.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c index 77f0348561..8911e97d2c 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c @@ -6,7 +6,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c index 2cb3a844cc..119ee626f6 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c index 19acd5a1b6..5b49833202 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,20 +40,20 @@ /* rrr: - -------- ------ -------- - -------- ------ -------- - -------- += ------ ... -------- - -------- ------ -------- - -------- ------ : - -------- ------ : + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : rcr: - -------- | | | | -------- - -------- | | | | -------- - -------- += | | | | ... -------- - -------- | | | | -------- - -------- | | | | : - -------- | | | | : + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : Assumptions: - B is row-stored; @@ -69,12 +69,12 @@ cost of the in-register transpose). crr: - | | | | | | | | ------ -------- - | | | | | | | | ------ -------- - | | | | | | | | += ------ ... -------- - | | | | | | | | ------ -------- - | | | | | | | | ------ : - | | | | | | | | ------ : + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : */ void bli_sgemmsup_rv_zen_asm_6x16m ( @@ -92,792 +92,957 @@ void bli_sgemmsup_rv_zen_asm_6x16m cntx_t* restrict cntx ) { - uint64_t n_left = n0 % 16; - - // First check whether this is a edge case in the n dimension. If so, - // dispatch other 6x?m kernels, as needed. 
- if (n_left ) - { - float* cij = c; - float* bj = b; - float* ai = a; - - if ( 8 <= n_left ) - { - const dim_t nr_cur = 8; - - bli_sgemmsup_rv_zen_asm_6x8m - ( - conja, conjb, m0, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; - } - - if ( 4 <= n_left ) - { - const dim_t nr_cur = 4; - - bli_sgemmsup_rv_zen_asm_6x4m - ( - conja, conjb, m0, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; - } - - if ( 2 <= n_left ) - { - const dim_t nr_cur = 2; - - bli_sgemmsup_rv_zen_asm_6x2m - ( - conja, conjb, m0, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; - } - - if ( 1 == n_left ) - { - dim_t ps_a0 = bli_auxinfo_ps_a( data ); - if ( ps_a0 == 6 * rs_a0 ) - { - bli_sgemv_ex - ( - BLIS_NO_TRANSPOSE, conjb, m0, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, - beta, cij, rs_c0, cntx, NULL - ); - } - else - { - const dim_t mr = 6; - - // Since A is packed into row panels, we must use a loop over - // gemv. - dim_t m_iter = ( m0 + mr - 1 ) / mr; - dim_t m_left = m0 % mr; - - float* restrict ai_ii = ai; - float* restrict cij_ii = cij; - - for ( dim_t ii = 0; ii < m_iter; ii += 1 ) - { - dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) - ? mr : m_left ); - - bli_sgemv_ex - ( - BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, - alpha, ai_ii, rs_a0, cs_a0, bj, rs_b0, - beta, cij_ii, rs_c0, cntx, NULL - ); - cij_ii += mr*rs_c0; ai_ii += ps_a0; - } - } - } - - return; - } - - //void* a_next = bli_auxinfo_next_a( data ); - //void* b_next = bli_auxinfo_next_b( data ); - - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. - - - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; - - uint64_t m_iter = m0 / 6; - uint64_t m_left = m0 % 6; - - uint64_t rs_a = rs_a0; - uint64_t cs_a = cs_a0; - uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - - // Query the panel stride of A and convert it to units of bytes. - uint64_t ps_a = bli_auxinfo_ps_a( data ); - uint64_t ps_a4 = ps_a * sizeof( float ); - - if ( m_iter == 0 ) goto consider_edge_cases; - - // ------------------------------------------------------------------------- - begin_asm() - - mov(var(a), r14) // load address of a. - mov(var(rs_a), r8) // load rs_a - mov(var(cs_a), r9) // load cs_a - lea(mem(, r8, 4), r8) // rs_a *= sizeof(dt) - lea(mem(, r9, 4), r9) // cs_a *= sizeof(dt) - - lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a - lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a - - mov(var(rs_b), r10) // load rs_b - lea(mem(, r10, 4), r10) // rs_b *= sizeof(dt) - // NOTE: We cannot pre-load elements of a or b - // because it could eventually, in the last - // unrolled iter or the cleanup loop, result - // in reading beyond the bounds allocated mem - // (the likely result: a segmentation fault). - - mov(var(c), r12) // load address of c - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(dt) - - // During preamble and loops: - // r12 = rcx = c - // r14 = rax = a - // read rbx from var(b) near beginning of loop - // r11 = m dim index ii - - mov(var(m_iter), r11) // ii = m_iter; - - label(.SLOOP6X16I) // LOOP OVER ii = [ m_iter ... 
1 0 ] - - vxorps(ymm4, ymm4, ymm4) - vxorps(ymm5, ymm5, ymm5) - vxorps(ymm6, ymm6, ymm6) - vxorps(ymm7, ymm7, ymm7) - vxorps(ymm8, ymm8, ymm8) - vxorps(ymm9, ymm9, ymm9) - vxorps(ymm10, ymm10, ymm10) - vxorps(ymm11, ymm11, ymm11) - vxorps(ymm12, ymm12, ymm12) - vxorps(ymm13, ymm13, ymm13) - vxorps(ymm14, ymm14, ymm14) - vxorps(ymm15, ymm15, ymm15) - - mov(var(b), rbx) // load address of b. - //mov(r12, rcx) // reset rcx to current utile of c. - mov(r14, rax) // reset rax to current upanel of a. - - cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. - jz(.SCOLPFETCH) // jump to column storage case - label(.SROWPFETCH) // row-stored pre-fetching on c // not used - - lea(mem(r12, rdi, 2), rdx) // - lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c - prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c - - jmp(.SPOSTPFETCH) // jump to end of pre-fetching c - label(.SCOLPFETCH) // column-stored pre-fetching c - - mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) - lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) - lea(mem(r12, rsi, 2), rdx) // - lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c - prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c - prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c - lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; - prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 6*cs_c - prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 7*cs_c - - label(.SPOSTPFETCH) // done prefetching c - - lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; - lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines - lea(mem(rdx, r8, 2), rdx) // from next upanel of a. - - mov(var(k_iter), rsi) // i = k_iter; - test(rsi, rsi) // check i via logical AND. - je(.SCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. 
- - label(.SLOOPKITER) // MAIN LOOP - - // ---------------------------------- iteration 0 - prefetch(0, mem(rdx, 5*8)) - - vmovups(mem(rbx, 0*32), ymm0) - vmovups(mem(rbx, 1*32), ymm1) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), ymm2) - vbroadcastss(mem(rax, r8, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm4) - vfmadd231ps(ymm1, ymm2, ymm5) - vfmadd231ps(ymm0, ymm3, ymm6) - vfmadd231ps(ymm1, ymm3, ymm7) - - vbroadcastss(mem(rax, r8, 2), ymm2) - vbroadcastss(mem(rax, r13, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm8) - vfmadd231ps(ymm1, ymm2, ymm9) - vfmadd231ps(ymm0, ymm3, ymm10) - vfmadd231ps(ymm1, ymm3, ymm11) - - vbroadcastss(mem(rax, r8, 4), ymm2) - vbroadcastss(mem(rax, r15, 1), ymm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(ymm0, ymm2, ymm12) - vfmadd231ps(ymm1, ymm2, ymm13) - vfmadd231ps(ymm0, ymm3, ymm14) - vfmadd231ps(ymm1, ymm3, ymm15) - - // ---------------------------------- iteration 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) - - vmovups(mem(rbx, 0*32), ymm0) - vmovups(mem(rbx, 1*32), ymm1) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), ymm2) - vbroadcastss(mem(rax, r8, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm4) - vfmadd231ps(ymm1, ymm2, ymm5) - vfmadd231ps(ymm0, ymm3, ymm6) - vfmadd231ps(ymm1, ymm3, ymm7) - - vbroadcastss(mem(rax, r8, 2), ymm2) - vbroadcastss(mem(rax, r13, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm8) - vfmadd231ps(ymm1, ymm2, ymm9) - vfmadd231ps(ymm0, ymm3, ymm10) - vfmadd231ps(ymm1, ymm3, ymm11) - - vbroadcastss(mem(rax, r8, 4), ymm2) - vbroadcastss(mem(rax, r15, 1), ymm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(ymm0, ymm2, ymm12) - vfmadd231ps(ymm1, ymm2, ymm13) - vfmadd231ps(ymm0, ymm3, ymm14) - vfmadd231ps(ymm1, ymm3, ymm15) - - // ---------------------------------- iteration 2 - prefetch(0, mem(rdx, r9, 2, 5*8)) - - vmovups(mem(rbx, 0*32), ymm0) - vmovups(mem(rbx, 1*32), ymm1) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), ymm2) - vbroadcastss(mem(rax, r8, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm4) - vfmadd231ps(ymm1, ymm2, ymm5) - vfmadd231ps(ymm0, ymm3, ymm6) - vfmadd231ps(ymm1, ymm3, ymm7) - - vbroadcastss(mem(rax, r8, 2), ymm2) - vbroadcastss(mem(rax, r13, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm8) - vfmadd231ps(ymm1, ymm2, ymm9) - vfmadd231ps(ymm0, ymm3, ymm10) - vfmadd231ps(ymm1, ymm3, ymm11) - - vbroadcastss(mem(rax, r8, 4), ymm2) - vbroadcastss(mem(rax, r15, 1), ymm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(ymm0, ymm2, ymm12) - vfmadd231ps(ymm1, ymm2, ymm13) - vfmadd231ps(ymm0, ymm3, ymm14) - vfmadd231ps(ymm1, ymm3, ymm15) - - // ---------------------------------- iteration 3 - prefetch(0, mem(rdx, rcx, 1, 5*8)) - lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; - - vmovups(mem(rbx, 0*32), ymm0) - vmovups(mem(rbx, 1*32), ymm1) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), ymm2) - vbroadcastss(mem(rax, r8, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm4) - vfmadd231ps(ymm1, ymm2, ymm5) - vfmadd231ps(ymm0, ymm3, ymm6) - vfmadd231ps(ymm1, ymm3, ymm7) - - vbroadcastss(mem(rax, r8, 2), ymm2) - vbroadcastss(mem(rax, r13, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm8) - vfmadd231ps(ymm1, ymm2, ymm9) - vfmadd231ps(ymm0, ymm3, ymm10) - vfmadd231ps(ymm1, ymm3, ymm11) - - vbroadcastss(mem(rax, r8, 4), ymm2) - vbroadcastss(mem(rax, r15, 1), ymm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(ymm0, ymm2, ymm12) - vfmadd231ps(ymm1, ymm2, ymm13) - vfmadd231ps(ymm0, ymm3, ymm14) - vfmadd231ps(ymm1, ymm3, ymm15) - - dec(rsi) // i -= 1; - jne(.SLOOPKITER) // iterate again if i != 0. 
- - label(.SCONSIDKLEFT) - - mov(var(k_left), rsi) // i = k_left; - test(rsi, rsi) // check i via logical AND. - je(.SPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. - - label(.SLOOPKLEFT) // EDGE LOOP - - vmovups(mem(rbx, 0*32), ymm0) - vmovups(mem(rbx, 1*32), ymm1) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), ymm2) - vbroadcastss(mem(rax, r8, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm4) - vfmadd231ps(ymm1, ymm2, ymm5) - vfmadd231ps(ymm0, ymm3, ymm6) - vfmadd231ps(ymm1, ymm3, ymm7) - - vbroadcastss(mem(rax, r8, 2), ymm2) - vbroadcastss(mem(rax, r13, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm8) - vfmadd231ps(ymm1, ymm2, ymm9) - vfmadd231ps(ymm0, ymm3, ymm10) - vfmadd231ps(ymm1, ymm3, ymm11) - - vbroadcastss(mem(rax, r8, 4), ymm2) - vbroadcastss(mem(rax, r15, 1), ymm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(ymm0, ymm2, ymm12) - vfmadd231ps(ymm1, ymm2, ymm13) - vfmadd231ps(ymm0, ymm3, ymm14) - vfmadd231ps(ymm1, ymm3, ymm15) - - dec(rsi) // i -= 1; - jne(.SLOOPKLEFT) // iterate again if i != 0. - - label(.SPOSTACCUM) - - mov(r12, rcx) // reset rcx to current utile of c. - mov(var(alpha), rax) // load address of alpha - mov(var(beta), rbx) // load address of beta - vbroadcastss(mem(rax), ymm0) // load alpha and duplicate - vbroadcastss(mem(rbx), ymm3) // load beta and duplicate - - vmulps(ymm0, ymm4, ymm4) // scale by alpha - vmulps(ymm0, ymm5, ymm5) - vmulps(ymm0, ymm6, ymm6) - vmulps(ymm0, ymm7, ymm7) - vmulps(ymm0, ymm8, ymm8) - vmulps(ymm0, ymm9, ymm9) - vmulps(ymm0, ymm10, ymm10) - vmulps(ymm0, ymm11, ymm11) - vmulps(ymm0, ymm12, ymm12) - vmulps(ymm0, ymm13, ymm13) - vmulps(ymm0, ymm14, ymm14) - vmulps(ymm0, ymm15, ymm15) - - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) - - lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; - lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - // now avoid loading C if beta == 0 - vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. - vucomiss(xmm0, xmm3) // set ZF if beta == 0. - je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. - jz(.SCOLSTORED) // jump to column storage case - - label(.SROWSTORED) - - vfmadd231ps(mem(rcx), ymm3, ymm4) - vmovups(ymm4, mem(rcx)) - - vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm5) - vmovups(ymm5, mem(rcx, rsi, 8)) - add(rdi, rcx) - - - vfmadd231ps(mem(rcx), ymm3, ymm6) - vmovups(ymm6, mem(rcx)) - - vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm7) - vmovups(ymm7, mem(rcx, rsi, 8)) - add(rdi, rcx) - - - vfmadd231ps(mem(rcx), ymm3, ymm8) - vmovups(ymm8, mem(rcx)) - - vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm9) - vmovups(ymm9, mem(rcx, rsi, 8)) - add(rdi, rcx) - - - vfmadd231ps(mem(rcx), ymm3, ymm10) - vmovups(ymm10, mem(rcx)) - - vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm11) - vmovups(ymm11, mem(rcx, rsi, 8)) - add(rdi, rcx) - - - vfmadd231ps(mem(rcx), ymm3, ymm12) - vmovups(ymm12, mem(rcx)) - - vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm13) - vmovups(ymm13, mem(rcx, rsi, 8)) - add(rdi, rcx) - - - vfmadd231ps(mem(rcx), ymm3, ymm14) - vmovups(ymm14, mem(rcx)) - - vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm15) - vmovups(ymm15, mem(rcx, rsi, 8)) - //add(rdi, rcx) - - - jmp(.SDONE) // jump to end. 
- - - label(.SCOLSTORED) + uint64_t n_left = n0 % 16; + + /* For row storage format, kernel is re-written to */ + /* use mask load/store instruction */ + if ( n_left && (rs_c0 != 1)) + { + float* restrict cij = c; + float* restrict bj = b; + float* restrict ai = a; + /**************************************************************************/ + /* Mask load and store support is added for fringe cases */ + /* Fringe cases are the numbers which not multiple of xmm or ymm register */ + /* n_left : 15,14,13,11,10,9,7,6,5,3 */ + /* When mask register values are set, load/store is performed */ + /* When mask register values are not set, load/store is not performed */ + /*Elements: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16*/ + /*n0=16 : -----------ymm--------------- -----------ymm---------------- */ + /*n0=15 : -----------ymm--------------- -1 -1 -1 -1 -1 -1 -1 0 */ + /*n0=14 : -----------ymm--------------- -1 -1 -1 -1 -1 -1 0 0 */ + /*n0=9 : -----------ymm--------------- -1 0 0 0 0 0 0 0 */ + /*n0=8 : -----------ymm--------------- -----------Not used--------- */ + /*n0=7 : -1 -1 -1 -1 -1 -1 -1 0 -----------Not used--------- */ + /*n0=3 : -1 -1 -1 0 0 0 0 0 -----------Not used--------- */ + /*Same code can be resued for multiple n_left by just varing mask register*/ + /*We will be able to perform complete operation of tile with this approach*/ + /**************************************************************************/ + switch(n_left) + { + /*Fringe cases*/ + case 15: case 14: case 13: + case 11: case 10: case 9: + { + const dim_t nr_cur = n_left; + /**********************************************/ + /* These case is executed when nleft - 9 to 15*/ + /* 16 Elements in col order */ + /* ---YMM REG----- ---YMM Mask Reg--- */ + /* 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 */ + /*15:0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 x */ + /*14:0 1 2 3 4 5 6 7 8 9 10 11 12 13 x x */ + /*11:0 1 2 3 4 5 6 7 8 9 10 x x x x x */ + /* and so on */ + /**********************************************/ + bli_sgemmsup_rv_zen_asm_6x16m_mask + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + case 7: case 6: case 5: + { + /***********************************************/ + /* These case is executed when nleft - 5 to 7 */ + /* 8 Elements in col order */ + /* YMM Mask REG */ + /* 0 1 2 3 4 5 6 7 */ + /*7: 0 1 2 3 4 5 6 x */ + /*6: 0 1 2 3 4 5 x x */ + /*5: 0 1 2 3 4 x x x */ + /**********************************************/ + const dim_t nr_cur = n_left; + + bli_sgemmsup_rv_zen_asm_6x8m_mask + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + case 3: case 1: + { + /***********************************************/ + /* These case is executed when nleft - 3/1 */ + /* 4 Elements in col order */ + /* XMM Mask REG */ + /* 0 1 2 3 */ + /*3: 0 1 2 x */ + /*1: 0 x x x */ + /**********************************************/ + const dim_t nr_cur = n_left; + bli_sgemmsup_rv_zen_asm_6x4m_mask + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + /*Non-Fringe cases*/ + case 12: + { + #if 0 + const dim_t nr_cur = 12; + bli_sgemmsup_rv_haswell_asm_6x12m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + #endif + + dim_t nr_cur = 8; + + bli_sgemmsup_rv_zen_asm_6x8m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, 
rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + + nr_cur = 4; + bli_sgemmsup_rv_zen_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + + } + case 8: + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rv_zen_asm_6x8m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + case 4: + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_zen_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + case 2: + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_zen_asm_6x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + break; + } + default: + break; + } + return; + } + + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 6x?m kernels, as needed. + if (n_left ) + { + float* cij = c; + float* bj = b; + float* ai = a; + + if ( 8 <= n_left ) + { + const dim_t nr_cur = 8; + + bli_sgemmsup_rv_zen_asm_6x8m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 4 <= n_left ) + { + const dim_t nr_cur = 4; + + bli_sgemmsup_rv_zen_asm_6x4m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_sgemmsup_rv_zen_asm_6x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + + if ( 1 == n_left ) + { + dim_t ps_a0 = bli_auxinfo_ps_a( data ); + if ( ps_a0 == 6 * rs_a0 ) + { + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + else + { + const dim_t mr = 6; + + // Since A is packed into row panels, we must use a loop over + // gemv. + dim_t m_iter = ( m0 + mr - 1 ) / mr; + dim_t m_left = m0 % mr; + + float* restrict ai_ii = ai; + float* restrict cij_ii = cij; + + for ( dim_t ii = 0; ii < m_iter; ii += 1 ) + { + dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) + ? mr : m_left ); + + bli_sgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, + alpha, ai_ii, rs_a0, cs_a0, bj, rs_b0, + beta, cij_ii, rs_c0, cntx, NULL + ); + cij_ii += mr*rs_c0; ai_ii += ps_a0; + } + } + } + + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. 
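The masked fringe handling dispatched above relies on per-lane masks whose sign bits gate vmaskmovps-style loads and stores, so only the valid tail columns of C are ever touched. A minimal standalone sketch of that idea follows, assuming AVX2 and hypothetical helper names; the shipped logic lives in the *_mask kernel variants called above, not in this sketch.

#include <immintrin.h>

/* Illustrative only: build a mask whose lanes 0..n_left-1 have their sign
   bit set, for 1 <= n_left <= 7. */
static inline __m256i sgemmsup_fringe_mask8( int n_left )
{
    const __m256i iota = _mm256_set_epi32( 7, 6, 5, 4, 3, 2, 1, 0 );
    /* lane i participates iff n_left > i (compare yields all-ones there) */
    return _mm256_cmpgt_epi32( _mm256_set1_epi32( n_left ), iota );
}

/* Illustrative only: tail-safe update of one row of C with n_left valid
   columns, given a packed 8-wide accumulator ab. */
static inline void sgemmsup_masked_row_update
     ( float* c, const float* ab, float alpha, float beta, int n_left )
{
    const __m256i m  = sgemmsup_fringe_mask8( n_left );
    const __m256  c0 = _mm256_maskload_ps( c, m );      /* load only valid lanes  */
    const __m256  t  = _mm256_loadu_ps( ab );           /* packed 8-wide result   */
    const __m256  r  = _mm256_add_ps( _mm256_mul_ps( _mm256_set1_ps( beta  ), c0 ),
                                      _mm256_mul_ps( _mm256_set1_ps( alpha ), t  ) );
    _mm256_maskstore_ps( c, m, r );                     /* store only valid lanes */
}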
+ uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(dt) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(dt) + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(dt) + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X16I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored pre-fetching on c // not used + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of pre-fetching c + label(.SCOLPFETCH) // column-stored pre-fetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(dt) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 7*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for pre-fetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
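Each pass of the MAIN LOOP below performs one rank-1 update of the 6x16 accumulator tile: two ymm loads pick up 16 consecutive elements of the current row of B, and six broadcasts of A each feed a pair of vfmadd231ps instructions. In scalar form the update is equivalent to the sketch below (illustrative only; B is assumed row-stored with unit column stride, as this rv kernel requires).

/* Illustrative scalar reference of one k iteration of the loop below;
   the asm keeps ab in ymm4..ymm15. */
static void ref_rank1_update_6x16
     (
       float ab[6][16],
       const float* restrict a, dim_t rs_a, dim_t cs_a,
       const float* restrict b, dim_t rs_b,
       dim_t k
     )
{
    for ( dim_t i = 0; i < 6; ++i )
        for ( dim_t j = 0; j < 16; ++j )
            ab[ i ][ j ] += a[ i*rs_a + k*cs_a ] * b[ k*rs_b + j ];
}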
+ + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rdx, 5*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. 
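// For orientation, the accumulator layout built by the loop above and
// consumed by the store code further below (per k iteration, ymm0 holds
// b[k, 0..7] and ymm1 holds b[k, 8..15]):
//   C row 0 -> ymm4  | ymm5       C row 3 -> ymm10 | ymm11
//   C row 1 -> ymm6  | ymm7       C row 4 -> ymm12 | ymm13
//   C row 2 -> ymm8  | ymm9       C row 5 -> ymm14 | ymm15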
+ + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx, 0*32), ymm0) + vmovups(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vfmadd231ps(ymm0, ymm3, ymm6) + vfmadd231ps(ymm1, ymm3, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vfmadd231ps(ymm0, ymm3, ymm10) + vfmadd231ps(ymm1, ymm3, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vfmadd231ps(ymm0, ymm3, ymm14) + vfmadd231ps(ymm1, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm13, ymm13) + vmulps(ymm0, ymm14, ymm14) + vmulps(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(dt) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm5) + vmovups(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm7) + vmovups(ymm7, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm9) + vmovups(ymm9, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm11) + vmovups(ymm11, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm13) + vmovups(ymm13, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm14) + vmovups(ymm14, mem(rcx)) + + vfmadd231ps(mem(rcx, rsi, 8), ymm3, ymm15) + vmovups(ymm15, mem(rcx, rsi, 8)) + //add(rdi, rcx) + + + jmp(.SDONE) // jump to end. 
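The epilogue above scales the accumulators by alpha (.SPOSTACCUM), uses vucomiss to branch to .SBETAZERO so that C is never loaded when beta == 0, and then updates C row by row in the row-stored branch just shown. A condensed scalar sketch of that row-stored case (cs_c == 1 in this branch; names are illustrative only):

/* Illustrative only: net effect of the alpha scaling plus the row-stored
   beta update performed by the vfmadd231ps/vmovups pairs above. */
static void ref_row_stored_epilogue
     ( float* c, dim_t rs_c, const float ab[6][16], float alpha, float beta )
{
    for ( dim_t i = 0; i < 6; ++i )
        for ( dim_t j = 0; j < 16; ++j )
        {
            const float t = alpha * ab[ i ][ j ];            /* vmulps by alpha        */
            c[ i*rs_c + j ] = ( beta == 0.0f )
                              ? t                            /* .SBETAZERO: C not read */
                              : beta * c[ i*rs_c + j ] + t;  /* vfmadd231ps with beta  */
        }
}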
+ + + label(.SCOLSTORED) /*|-----------------| |-----|----| - | | | | 8x4 | 8x2| - | 4x8 | 4x8 | | | | - | | | |-----|----| - |-----------------| | 8x4 | 8x2| - | 2x8 | 2x8 | | | | - |------------------ |----------|*/ - - /****6x16 tile is transposed and saved in col major as 6x16*****/ - /****top left tile 4x8 transposed to top left tile 8x4**********/ - vunpcklps(ymm6, ymm4, ymm0)//a0b0a1b1 a4b4a5b5 - vunpcklps(ymm10, ymm8, ymm1)//c0d0c1d1 c4d4c5d5 - vshufps(imm(0x4e), ymm1, ymm0, ymm2)//a1b1c0d0 a5b5c4d4 - vblendps(imm(0xcc), ymm2, ymm0, ymm0)//a0b0c0d0 a4b4c4d4 - vblendps(imm(0x33), ymm2, ymm1, ymm1)//a1b1c1d1 a5b5c5d5 - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm0) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm1) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c - - vunpckhps(ymm6, ymm4, ymm0) - vunpckhps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm0) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm1) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) - - lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c - lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c - - /***bottom left tile - 2x8 is transposed to top right tile 8x2**********/ - vunpcklps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(rdx), xmm1, xmm1) - vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) - lea(mem(rdx, rsi, 4), rax) // rax += 4*cs_c - - vmovlpd(mem(rax), xmm1, xmm1) - vmovhpd(mem(rax, rsi, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(rax)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(rax, rsi, 1)) // store ( gamma45..gamma55 ) - lea(mem(rdx, rsi, 2), rdx) // rdx += 2*cs_c - - vunpckhps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(rdx), xmm1, xmm1) - vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma43..gamma53 ) - lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c - vmovlpd(mem(rdx), xmm1, xmm1) - vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(rdx)) // store ( gamma46..gamma56 ) - vmovhpd(xmm2, mem(rdx, rsi, 1)) // store ( gamma47..gamma57 ) - - lea(mem(rdx, rsi, 2), rdx) // rdx += 2*cs_c - - /***top right tile 4x8 is transposed to bottom left tile 8x4**********/ - vunpcklps(ymm7, ymm5, ymm0) - vunpcklps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - 
vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm0) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm1) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c - - vunpckhps(ymm7, ymm5, ymm0) - vunpckhps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm0) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm1) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) - - //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - /*** bottom right 2x8 is transposed to bottom right tile 8x2*******/ - vunpcklps(ymm15, ymm13, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(rdx), xmm1, xmm1) - vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) - lea(mem(rdx, rsi, 4), rax) // rax += 4*cs_c - - vmovlpd(mem(rax), xmm1, xmm1) - vmovhpd(mem(rax, rsi, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(rax)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(rax, rsi, 1)) // store ( gamma45..gamma55 ) - lea(mem(rdx, rsi, 2), rdx) // rdx += 2*cs_c - - vunpckhps(ymm15, ymm13, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(rdx), xmm1, xmm1) - vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma43..gamma53 ) - lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c - vmovlpd(mem(rdx), xmm1, xmm1) - vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(rdx)) // store ( gamma46..gamma56 ) - vmovhpd(xmm2, mem(rdx, rsi, 1)) // store ( gamma47..gamma57 ) - - jmp(.SDONE) // jump to end. - - label(.SBETAZERO) - - cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. - jz(.SCOLSTORBZ) // jump to column storage case - - - label(.SROWSTORBZ) - - vmovups(ymm4, mem(rcx)) - vmovups(ymm5, mem(rcx, rsi, 8)) - add(rdi, rcx) - - - vmovups(ymm6, mem(rcx)) - vmovups(ymm7, mem(rcx, rsi, 8)) - add(rdi, rcx) - - - vmovups(ymm8, mem(rcx)) - vmovups(ymm9, mem(rcx, rsi, 8)) - add(rdi, rcx) - - - vmovups(ymm10, mem(rcx)) - vmovups(ymm11, mem(rcx, rsi, 8)) - add(rdi, rcx) - - - vmovups(ymm12, mem(rcx)) - vmovups(ymm13, mem(rcx, rsi, 8)) - add(rdi, rcx) - - - vmovups(ymm14, mem(rcx)) - vmovups(ymm15, mem(rcx, rsi, 8)) - //add(rdi, rcx) - - jmp(.SDONE) // jump to end. 
- - - label(.SCOLSTORBZ) - /****6x16 tile going to save into 16x6 tile in C*****/ - /******************top left tile 8x4***************************/ - vunpcklps(ymm6, ymm4, ymm0) - vunpcklps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) - - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vunpckhps(ymm6, ymm4, ymm0) - vunpckhps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) - - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c - /******************top right tile 8x2***************************/ - vunpcklps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) - vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) - lea(mem(rdx, rsi, 1), rdx) - vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) - vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) - lea(mem(rdx, rsi, 1), rdx) - - vunpckhps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) - vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) - lea(mem(rdx, rsi, 1), rdx) - vmovhpd(xmm0, mem(rdx)) // store ( gamma43..gamma53 ) - vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) - lea(mem(rdx, rsi, 1), rdx) - lea(mem(rdx, rsi, 4), rdx) // rdx += 8*cs_c - - /******************bottom left tile 8x4***************************/ - vunpcklps(ymm7, ymm5, ymm0) - vunpcklps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) - - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vunpckhps(ymm7, ymm5, ymm0) - vunpckhps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) - - /******************bottom right tile 8x2***************************/ - vunpcklps(ymm15, ymm13, ymm0) - 
vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) - vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) - lea(mem(rdx, rsi, 1), rdx) - vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) - vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) - lea(mem(rdx, rsi, 1), rdx) - - vunpckhps(ymm15, ymm13, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) - vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) - lea(mem(rdx, rsi, 1), rdx) - vmovhpd(xmm0, mem(rdx)) // store ( gamma43..gamma53 ) - vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) - - label(.SDONE) - - lea(mem(r12, rdi, 4), r12) // - lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c - - //lea(mem(r14, r8, 4), r14) // - //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a - mov(var(ps_a4), rax) // load ps_a4 - lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 - - dec(r11) // ii -= 1; - jne(.SLOOP6X16I) // iterate again if ii != 0. - - label(.SRETURN) - + | | | | 8x4 | 8x2| + | 4x8 | 4x8 | | | | + | | | |-----|----| + |-----------------| | 8x4 | 8x2| + | 2x8 | 2x8 | | | | + |------------------ |----------|*/ + + /****6x16 tile is transposed and saved in col major as 6x16*****/ + /****top left tile 4x8 transposed to top left tile 8x4**********/ + vunpcklps(ymm6, ymm4, ymm0)//a0b0a1b1 a4b4a5b5 + vunpcklps(ymm10, ymm8, ymm1)//c0d0c1d1 c4d4c5d5 + vshufps(imm(0x4e), ymm1, ymm0, ymm2)//a1b1c0d0 a5b5c4d4 + vblendps(imm(0xcc), ymm2, ymm0, ymm0)//a0b0c0d0 a4b4c4d4 + vblendps(imm(0x33), ymm2, ymm1, ymm1)//a1b1c1d1 a5b5c5d5 + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + + /***bottom left tile - 2x8 is transposed to top right tile 8x2**********/ + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + lea(mem(rdx, rsi, 4), rax) // rax += 4*cs_c + + vmovlpd(mem(rax), xmm1, xmm1) + vmovhpd(mem(rax, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rax)) // store ( 
gamma44..gamma54 ) + vmovhpd(xmm2, mem(rax, rsi, 1)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 2), rdx) // rdx += 2*cs_c + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma43..gamma53 ) + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rsi, 1)) // store ( gamma47..gamma57 ) + + lea(mem(rdx, rsi, 2), rdx) // rdx += 2*cs_c + + /***top right tile 4x8 is transposed to bottom left tile 8x4**********/ + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += cs_c + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + /*** bottom right 2x8 is transposed to bottom right tile 8x2*******/ + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma41..gamma51 ) + lea(mem(rdx, rsi, 4), rax) // rax += 4*cs_c + + vmovlpd(mem(rax), xmm1, xmm1) + vmovhpd(mem(rax, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rax)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(rax, rsi, 1)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 2), rdx) // rdx += 2*cs_c + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(rdx, rsi, 1)) // store ( gamma43..gamma53 ) + lea(mem(rdx, rsi, 4), rdx) // rdx += 4*cs_c + vmovlpd(mem(rdx), xmm1, xmm1) + vmovhpd(mem(rdx, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(rdx)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(rdx, rsi, 1)) // store ( 
gamma47..gamma57 ) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case + + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx)) + vmovups(ymm11, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx)) + vmovups(ymm13, mem(rcx, rsi, 8)) + add(rdi, rcx) + + + vmovups(ymm14, mem(rcx)) + vmovups(ymm15, mem(rcx, rsi, 8)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + + label(.SCOLSTORBZ) + /****6x16 tile going to save into 16x6 tile in C*****/ + /******************top left tile 8x4***************************/ + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + lea(mem(rcx, rsi, 4), rcx) // rcx += 4*cs_c + /******************top right tile 8x2***************************/ + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 1), rdx) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma43..gamma53 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) + lea(mem(rdx, rsi, 1), rdx) + lea(mem(rdx, rsi, 4), rdx) // rdx += 8*cs_c + + /******************bottom left tile 8x4***************************/ + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + 
+ lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + /******************bottom right tile 8x2***************************/ + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 1), rdx) + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma43..gamma53 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) + + label(.SDONE) + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X16I) // iterate again if ii != 0. + + label(.SRETURN) + end_asm( - : // output operands (none) - : // input operands + : // output operands (none) + : // input operands [m_iter] "m" (m_iter), [k_iter] "m" (k_iter), [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), - [ps_a4] "m" (ps_a4), + [ps_a4] "m" (ps_a4), [b] "m" (b), [rs_b] "m" (rs_b), [cs_b] "m" (cs_b), @@ -888,53 +1053,53 @@ void bli_sgemmsup_rv_zen_asm_6x16m [cs_c] "m" (cs_c)/*, [a_next] "m" (a_next), [b_next] "m" (b_next)*/ - : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", - "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", - "xmm0", "xmm1", "xmm2", "xmm3", - "xmm4", "xmm5", "xmm6", "xmm7", - "xmm8", "xmm9", "xmm10", "xmm11", - "xmm12", "xmm13", "xmm14", "xmm15", - "ymm0", "ymm1", "ymm2", "ymm3", - "ymm4", "ymm5", "ymm6", "ymm7", - "ymm8", "ymm9", "ymm10", "ymm11", - "ymm12", "ymm13", "ymm14", "ymm15", - "memory" - ) - - consider_edge_cases: - - // Handle edge cases in the m dimension, if they exist. 
- if ( m_left ) - { - const dim_t nr_cur = 16; - const dim_t i_edge = m0 - ( dim_t )m_left; - - float* restrict cij = c + i_edge*rs_c; - float* restrict ai = a + m_iter*ps_a; - float* restrict bj = b; - - sgemmsup_ker_ft ker_fps[6] = - { - NULL, - bli_sgemmsup_rv_zen_asm_1x16, - bli_sgemmsup_rv_zen_asm_2x16, - bli_sgemmsup_rv_zen_asm_3x16, - bli_sgemmsup_rv_zen_asm_4x16, - bli_sgemmsup_rv_zen_asm_5x16 - }; - - sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; - - ker_fp - ( - conja, conjb, m_left, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - return; - - } + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", + "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm13", "ymm14", "ymm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 16; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict ai = a + m_iter*ps_a; + float* restrict bj = b; + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_zen_asm_1x16, + bli_sgemmsup_rv_zen_asm_2x16, + bli_sgemmsup_rv_zen_asm_3x16, + bli_sgemmsup_rv_zen_asm_4x16, + bli_sgemmsup_rv_zen_asm_5x16 + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; + + } } void bli_sgemmsup_rv_zen_asm_6x8m @@ -953,479 +1118,479 @@ void bli_sgemmsup_rv_zen_asm_6x8m cntx_t* restrict cntx ) { - //void* a_next = bli_auxinfo_next_a( data ); - //void* b_next = bli_auxinfo_next_b( data ); - - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; - - uint64_t m_iter = m0 / 6; - uint64_t m_left = m0 % 6; - - uint64_t rs_a = rs_a0; - uint64_t cs_a = cs_a0; - uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - - // Query the panel stride of A and convert it to units of bytes. - uint64_t ps_a = bli_auxinfo_ps_a( data ); - uint64_t ps_a4 = ps_a * sizeof( float ); - - if ( m_iter == 0 ) goto consider_edge_cases; - - // ------------------------------------------------------------------------- - begin_asm() - - mov(var(a), r14) // load address of a. - mov(var(rs_a), r8) // load rs_a - mov(var(cs_a), r9) // load cs_a - lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) - lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) - - lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a - lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a - - mov(var(rs_b), r10) // load rs_b - lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) - - // NOTE: We cannot pre-load elements of a or b - // because it could eventually, in the last - // unrolled iter or the cleanup loop, result - // in reading beyond the bounds allocated mem - // (the likely result: a segmentation fault). 
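The consider_edge_cases block of bli_sgemmsup_rv_zen_asm_6x16m above picks a fixed-height edge kernel out of a small table indexed by the leftover row count m_left (index 0 is unused). A stripped-down sketch of that dispatch pattern, using a hypothetical simplified signature rather than the real sgemmsup_ker_ft:

/* Illustrative only: shape of the m_left dispatch above.  The real table
   holds bli_sgemmsup_rv_zen_asm_{1..5}x16 and the full gemmsup argument list. */
typedef void (*edge_ker_ft)( dim_t m, dim_t n, dim_t k );   /* hypothetical */

static void dispatch_m_left( dim_t m_left, dim_t n, dim_t k,
                             edge_ker_ft table[6] )
{
    /* table[0] is NULL by construction; m_left is in 1..5 when this runs. */
    if ( m_left ) table[ m_left ]( m_left, n, k );
}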
- - mov(var(c), r12) // load address of c - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) - - - // During preamble and loops: - // r12 = rcx = c - // r14 = rax = a - // read rbx from var(b) near beginning of loop - // r11 = m dim index ii - - mov(var(m_iter), r11) // ii = m_iter; - - label(.SLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] - - // skylake can execute 3 vxorpd ipc with - // a latency of 1 cycle, while vzeroall - // has a latency of 12 cycles. - vxorps(ymm1, ymm1, ymm1) // zero ymm1 since we only use the lower - vxorps(ymm4, ymm4, ymm4) // half (xmm1), and nans/infs may slow us down. - vxorps(ymm6, ymm6, ymm6) - vxorps(ymm8, ymm8, ymm8) - vxorps(ymm10, ymm10, ymm10) - vxorps(ymm12, ymm12, ymm12) - vxorps(ymm14, ymm14, ymm14) - - mov(var(b), rbx) // load address of b. - mov(r14, rax) // reset rax to current upanel of a. - - cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. - jz(.SCOLPFETCH) // jump to column storage case - label(.SROWPFETCH) // row-stored prefetching on c - - lea(mem(r12, rdi, 2), rdx) // - lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c - prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c - - jmp(.SPOSTPFETCH) // jump to end of prefetching c - label(.SCOLPFETCH) // column-stored prefetching c - - mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) - lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) - lea(mem(r12, rsi, 2), rdx) // - lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c - prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c - prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c - - label(.SPOSTPFETCH) // done prefetching c - - - lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; - lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines - lea(mem(rdx, r8, 2), rdx) // from next upanel of a. - - - mov(var(k_iter), rsi) // i = k_iter; - test(rsi, rsi) // check i via logical AND. - je(.SCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. 
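// For orientation in the 6x8 kernel: each k iteration below loads a single
// ymm of B (8 columns) into ymm0 and broadcasts the six A elements, so the
// 6x8 accumulator tile is held one register per row:
//   C rows 0..5 -> ymm4, ymm6, ymm8, ymm10, ymm12, ymm14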
- - - label(.SLOOPKITER) // MAIN LOOP - - - // ---------------------------------- iteration 0 - prefetch(0, mem(rdx, 5*8)) - - vmovups(mem(rbx), ymm0) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), ymm2) - vbroadcastss(mem(rax, r8, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm4) - vfmadd231ps(ymm0, ymm3, ymm6) - - vbroadcastss(mem(rax, r8, 2), ymm2) - vbroadcastss(mem(rax, r13, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm8) - vfmadd231ps(ymm0, ymm3, ymm10) - - vbroadcastss(mem(rax, r8, 4), ymm2) - vbroadcastss(mem(rax, r15, 1), ymm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(ymm0, ymm2, ymm12) - vfmadd231ps(ymm0, ymm3, ymm14) - - // ---------------------------------- iteration 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) - - vmovups(mem(rbx), ymm0) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), ymm2) - vbroadcastss(mem(rax, r8, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm4) - vfmadd231ps(ymm0, ymm3, ymm6) - - vbroadcastss(mem(rax, r8, 2), ymm2) - vbroadcastss(mem(rax, r13, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm8) - vfmadd231ps(ymm0, ymm3, ymm10) - - vbroadcastss(mem(rax, r8, 4), ymm2) - vbroadcastss(mem(rax, r15, 1), ymm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(ymm0, ymm2, ymm12) - vfmadd231ps(ymm0, ymm3, ymm14) - - // ---------------------------------- iteration 2 - prefetch(0, mem(rdx, r9, 2, 5*8)) - - vmovups(mem(rbx), ymm0) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), ymm2) - vbroadcastss(mem(rax, r8, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm4) - vfmadd231ps(ymm0, ymm3, ymm6) - - vbroadcastss(mem(rax, r8, 2), ymm2) - vbroadcastss(mem(rax, r13, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm8) - vfmadd231ps(ymm0, ymm3, ymm10) - - vbroadcastss(mem(rax, r8, 4), ymm2) - vbroadcastss(mem(rax, r15, 1), ymm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(ymm0, ymm2, ymm12) - vfmadd231ps(ymm0, ymm3, ymm14) - - // ---------------------------------- iteration 3 - prefetch(0, mem(rdx, rcx, 1, 5*8)) - - vmovups(mem(rbx), ymm0) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), ymm2) - vbroadcastss(mem(rax, r8, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm4) - vfmadd231ps(ymm0, ymm3, ymm6) - - vbroadcastss(mem(rax, r8, 2), ymm2) - vbroadcastss(mem(rax, r13, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm8) - vfmadd231ps(ymm0, ymm3, ymm10) - - vbroadcastss(mem(rax, r8, 4), ymm2) - vbroadcastss(mem(rax, r15, 1), ymm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(ymm0, ymm2, ymm12) - vfmadd231ps(ymm0, ymm3, ymm14) - - - dec(rsi) // i -= 1; - jne(.SLOOPKITER) // iterate again if i != 0. - - - label(.SCONSIDKLEFT) - - mov(var(k_left), rsi) // i = k_left; - test(rsi, rsi) // check i via logical AND. - je(.SPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. - - label(.SLOOPKLEFT) // EDGE LOOP - - vmovups(mem(rbx), ymm0) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), ymm2) - vbroadcastss(mem(rax, r8, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm4) - vfmadd231ps(ymm0, ymm3, ymm6) - - vbroadcastss(mem(rax, r8, 2), ymm2) - vbroadcastss(mem(rax, r13, 1), ymm3) - vfmadd231ps(ymm0, ymm2, ymm8) - vfmadd231ps(ymm0, ymm3, ymm10) - - vbroadcastss(mem(rax, r8, 4), ymm2) - vbroadcastss(mem(rax, r15, 1), ymm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(ymm0, ymm2, ymm12) - vfmadd231ps(ymm0, ymm3, ymm14) - - - dec(rsi) // i -= 1; - jne(.SLOOPKLEFT) // iterate again if i != 0. - - - label(.SPOSTACCUM) - - mov(r12, rcx) // reset rcx to current utile of c. 
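/* Editor's sketch (illustrative only): by .SPOSTACCUM the six accumulators
   (ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) hold the six rows of the current
   6x8 product of the A micro-panel with B. The remaining work is the usual
   sup update C := beta*C + alpha*(A*B); the kernel performs it by first
   scaling the accumulators by alpha and then branching on beta before the
   store-back. Scalar restatement, with hypothetical names: */
static void sketch_post_accumulate( float* c, long long rs_c, long long cs_c,
                                    const float t[6][8],  /* accumulated A*B */
                                    float alpha, float beta )
{
    for ( int i = 0; i < 6; ++i )
        for ( int j = 0; j < 8; ++j )
            c[ i*rs_c + j*cs_c ] = beta  * c[ i*rs_c + j*cs_c ]
                                 + alpha * t[ i ][ j ];
    /* (the beta == 0 fast path in the asm skips the read of C entirely) */
}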
- mov(var(alpha), rax) // load address of alpha - mov(var(beta), rbx) // load address of beta - vbroadcastss(mem(rax), ymm0) // load alpha and duplicate - vbroadcastss(mem(rbx), ymm3) // load beta and duplicate - - vmulps(ymm0, ymm4, ymm4) // scale by alpha - vmulps(ymm0, ymm6, ymm6) - vmulps(ymm0, ymm8, ymm8) - vmulps(ymm0, ymm10, ymm10) - vmulps(ymm0, ymm12, ymm12) - vmulps(ymm0, ymm14, ymm14) - - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) - - lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; - lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - // now avoid loading C if beta == 0 - - vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. - vucomiss(xmm0, xmm3) // set ZF if beta == 0. - je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. - jz(.SCOLSTORED) // jump to column storage case - - - label(.SROWSTORED) - - vfmadd231ps(mem(rcx), ymm3, ymm4) - vmovups(ymm4, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rcx), ymm3, ymm6) - vmovups(ymm6, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rcx), ymm3, ymm8) - vmovups(ymm8, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rcx), ymm3, ymm10) - vmovups(ymm10, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rcx), ymm3, ymm12) - vmovups(ymm12, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rcx), ymm3, ymm14) - vmovups(ymm14, mem(rcx)) - - jmp(.SDONE) // jump to end. - - label(.SCOLSTORED) - - /****6x8 tile is transposed and saved in col major as 8x6*****/ - vunpcklps(ymm6, ymm4, ymm0) - vunpcklps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm0) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm1) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) - - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vunpckhps(ymm6, ymm4, ymm0) - vunpckhps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm0) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm1) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) - - vunpcklps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vpermilps(imm(0xe),xmm0,xmm5) - vpermilps(imm(0xe),xmm2,xmm6) - vmovq(mem(rdx),xmm4) - vmovq(mem(rdx, rsi, 4),xmm1) - vfmadd231ps(xmm4, xmm3, xmm0) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) - vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) - lea(mem(rdx, rsi, 1), rdx) - vmovq(mem(rdx),xmm4) - vmovq(mem(rdx, rsi, 4),xmm1) - vfmadd231ps(xmm4, xmm3, xmm5) - vfmadd231ps(xmm1, xmm3, xmm6) - vmovlpd(xmm5, mem(rdx)) // store ( gamma41..gamma51 ) - vmovlpd(xmm6, mem(rdx, 
rsi, 4)) // store ( gamma45..gamma55 ) - lea(mem(rdx, rsi, 1), rdx) - - vunpckhps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vpermilps(imm(0xe),xmm0,xmm5) - vpermilps(imm(0xe),xmm2,xmm6) - vmovq(mem(rdx),xmm4) - vmovq(mem(rdx, rsi, 4),xmm1) - vfmadd231ps(xmm4, xmm3, xmm0) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) - vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) - lea(mem(rdx, rsi, 1), rdx) - vmovq(mem(rdx),xmm4) - vmovq(mem(rdx, rsi, 4),xmm1) - vfmadd231ps(xmm4, xmm3, xmm5) - vfmadd231ps(xmm1, xmm3, xmm6) - vmovlpd(xmm5, mem(rdx)) // store ( gamma43..gamma53 ) - vmovlpd(xmm6, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) - - jmp(.SDONE) // jump to end. - - label(.SBETAZERO) - - cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. - jz(.SCOLSTORBZ) // jump to column storage case - - label(.SROWSTORBZ) - - vmovups(ymm4, mem(rcx)) - add(rdi, rcx) - vmovups(ymm6, mem(rcx)) - add(rdi, rcx) - vmovups(ymm8, mem(rcx)) - add(rdi, rcx) - vmovups(ymm10, mem(rcx)) - add(rdi, rcx) - vmovups(ymm12, mem(rcx)) - add(rdi, rcx) - vmovups(ymm14, mem(rcx)) - - jmp(.SDONE) // jump to end. - - label(.SCOLSTORBZ) - - vunpcklps(ymm6, ymm4, ymm0) - vunpcklps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) - - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vunpckhps(ymm6, ymm4, ymm0) - vunpckhps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) - /******************top right tile 8x2***************************/ - vunpcklps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) - vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) - lea(mem(rdx, rsi, 1), rdx) - vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) - vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) - lea(mem(rdx, rsi, 1), rdx) - - vunpckhps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) - vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) - lea(mem(rdx, rsi, 1), rdx) - vmovhpd(xmm0, mem(rdx)) // store ( gamma43..gamma53 ) - vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) - - - label(.SDONE) - - lea(mem(r12, rdi, 4), r12) // - lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c - - //lea(mem(r14, r8, 4), r14) // - //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a - mov(var(ps_a4), rax) // load ps_a4 - lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 - - dec(r11) // ii -= 1; - jne(.SLOOP6X8I) // iterate again if ii != 0. 
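/* Editor's sketch (illustrative only): the pointer arithmetic at .SDONE,
   r12 += 6*rs_c (in bytes) and r14 += ps_a4, together with the dec/jne just
   above, implements the outer loop over m_iter blocks of 6 rows. ps_a is the
   packed panel stride of A queried from the auxinfo at the top of the
   function. Roughly, in C: */
static void sketch_m_loop_advance( float* c, const float* a,
                                   long long rs_c, long long ps_a,
                                   long long m_iter )
{
    for ( long long ii = 0; ii < m_iter; ++ii )
    {
        /* ... compute one 6xNR tile at (c, a) as the asm above does ... */
        c += 6 * rs_c;   /* next block of 6 rows of C     */
        a += ps_a;       /* next packed micro-panel of A  */
    }
}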
- - label(.SRETURN) + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm1, ymm1, ymm1) // zero ymm1 since we only use the lower + vxorps(ymm4, ymm4, ymm4) // half (xmm1), and nans/infs may slow us down. + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm14, ymm14, ymm14) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. 
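/* Editor's sketch (illustrative only): rdi was set to rs_c*sizeof(float), so
   the cmp(imm(4), rdi) above effectively tests 4*rs_c == 4, i.e. rs_c == 1;
   the "(8*rs_c) == 8" wording appears to be inherited from the
   double-precision kernel. A unit row stride means C is column-stored, which
   is why the jz that follows selects the column-oriented prefetch and store
   paths. Equivalent check in C: */
static int sketch_c_is_column_stored( long long rs_c )
{
    /* unit row stride <=> consecutive rows are adjacent <=> column-major C */
    return ( rs_c * (long long)sizeof(float) ) == (long long)sizeof(float);
}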
+ jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + prefetch(0, mem(rdx, 5*8)) + + vmovups(mem(rbx), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovups(mem(rbx), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovups(mem(rbx), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + + vmovups(mem(rbx), ymm0) + add(r10, rbx) // b += rs_b; + + 
vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx), ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vbroadcastss(mem(rax, r8, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm0, ymm3, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vbroadcastss(mem(rax, r13, 1), ymm3) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm0, ymm3, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vbroadcastss(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm0, ymm3, ymm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm3) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm14, ymm14) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm14) + vmovups(ymm14, mem(rcx)) + + jmp(.SDONE) // jump to end. 
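/* Editor's sketch (illustrative only): in .SROWSTORED each
   vfmadd231ps(mem(rcx), ymm3, ymmN) computes t_row := beta*c_row + t_row for
   eight floats at once, where t_row is the alpha-scaled accumulator row. The
   vucomiss branch taken earlier routes beta == 0 to .SBETAZERO so C is never
   read in that case (per BLAS semantics C may then be uninitialized, and a
   stray NaN times zero would otherwise contaminate the result). One row of
   the row-stored path, in scalar form: */
static void sketch_row_store( float* c_row, const float* t_row, float beta )
{
    for ( int j = 0; j < 8; ++j )
    {
        if ( beta != 0.0f )
            c_row[ j ] = beta * c_row[ j ] + t_row[ j ];  /* .SROWSTORED */
        else
            c_row[ j ] = t_row[ j ];                      /* .SROWSTORBZ */
    }
}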
+ + label(.SCOLSTORED) + + /****6x8 tile is transposed and saved in col major as 8x6*****/ + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm1) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vpermilps(imm(0xe),xmm0,xmm5) + vpermilps(imm(0xe),xmm2,xmm6) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + lea(mem(rdx, rsi, 1), rdx) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm5) + vfmadd231ps(xmm1, xmm3, xmm6) + vmovlpd(xmm5, mem(rdx)) // store ( gamma41..gamma51 ) + vmovlpd(xmm6, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 1), rdx) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vpermilps(imm(0xe),xmm0,xmm5) + vpermilps(imm(0xe),xmm2,xmm6) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) + lea(mem(rdx, rsi, 1), rdx) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm5) + vfmadd231ps(xmm1, xmm3, xmm6) + vmovlpd(xmm5, mem(rdx)) // store ( gamma43..gamma53 ) + vmovlpd(xmm6, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx)) + add(rdi, rcx) + vmovups(ymm6, mem(rcx)) + add(rdi, rcx) + vmovups(ymm8, mem(rcx)) + add(rdi, rcx) + vmovups(ymm10, mem(rcx)) + add(rdi, rcx) + vmovups(ymm12, mem(rcx)) + add(rdi, rcx) + vmovups(ymm14, mem(rcx)) + + jmp(.SDONE) // jump to end. 
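/* Editor's sketch (illustrative only): for column-stored C the 6x8
   accumulator tile is transposed in registers, as noted in the asm comment:
   the vunpcklps/vunpckhps/vshufps/vblendps sequences interleave rows 0..3
   while rows 4..5 are handled with 64-bit moves, so that each column of C
   can be written with contiguous stores. This applies both to .SCOLSTORED
   above (with the beta update folded into the FMAs) and to .SCOLSTORBZ
   below. Net effect of the beta == 0 variant, in scalar form: */
static void sketch_col_store_bz( float* c, long long cs_c,
                                 const float t[6][8] )
{
    for ( int j = 0; j < 8; ++j )          /* one column of C per iteration */
        for ( int i = 0; i < 6; ++i )
            c[ j*cs_c + i ] = t[ i ][ j ]; /* rs_c == 1 on this path */
}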
+ + label(.SCOLSTORBZ) + + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma05..gamma35 ) + + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma06..gamma36 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) + /******************top right tile 8x2***************************/ + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) + lea(mem(rdx, rsi, 1), rdx) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma43..gamma53 ) + vmovhpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) + + + label(.SDONE) + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X8I) // iterate again if ii != 0. + + label(.SRETURN) end_asm( - : // output operands (none) - : // input operands + : // output operands (none) + : // input operands [m_iter] "m" (m_iter), [k_iter] "m" (k_iter), [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), - [ps_a4] "m" (ps_a4), + [ps_a4] "m" (ps_a4), [b] "m" (b), [rs_b] "m" (rs_b), [cs_b] "m" (cs_b), @@ -1436,51 +1601,51 @@ void bli_sgemmsup_rv_zen_asm_6x8m [cs_c] "m" (cs_c)/*, [a_next] "m" (a_next), [b_next] "m" (b_next)*/ - : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", - "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", - "xmm0", "xmm1", "xmm2", "xmm3", - "xmm4", "xmm5", "xmm6", "xmm7", - "xmm8", "xmm9", "xmm10", "xmm11", - "xmm12", "xmm13", "xmm14", "xmm15", - "ymm0", "ymm1", "ymm2", "ymm3", - "ymm4", "ymm6", "ymm8", "ymm10", - "ymm12", "ymm14", - "memory" - ) - - consider_edge_cases: - - // Handle edge cases in the m dimension, if they exist. 
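/* Editor's sketch (illustrative only): the edge-case block below dispatches
   the leftover m_left (< 6) rows to a smaller fixed-m kernel selected from
   the ker_fps table. The pointer arithmetic works because packed micro-panels
   of A are laid out back to back with stride ps_a, so the leftover rows begin
   at the m_iter-th panel, while C is simply offset by the rows already
   written. Hypothetical helper showing that index math: */
static void sketch_edge_pointers( float* c, const float* a,
                                  long long m0, long long rs_c, long long ps_a,
                                  float** cij, const float** ai )
{
    long long m_iter = m0 / 6;
    long long m_left = m0 % 6;
    long long i_edge = m0 - m_left;     /* index of the first leftover row   */

    *cij = c + i_edge * rs_c;           /* leftover rows of C                */
    *ai  = a + m_iter * ps_a;           /* leftover packed panel of A        */
    (void)m_left;                       /* used to index ker_fps[ m_left ]   */
}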
- if ( m_left ) - { - const dim_t nr_cur = 8; - const dim_t i_edge = m0 - ( dim_t )m_left; - - float* restrict cij = c + i_edge*rs_c; - float* restrict ai = a + m_iter*ps_a; - float* restrict bj = b; - - sgemmsup_ker_ft ker_fps[6] = - { - NULL, - bli_sgemmsup_rv_zen_asm_1x8, - bli_sgemmsup_rv_zen_asm_2x8, - bli_sgemmsup_rv_zen_asm_3x8, - bli_sgemmsup_rv_zen_asm_4x8, - bli_sgemmsup_rv_zen_asm_5x8 - }; - - sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; - - ker_fp - ( - conja, conjb, m_left, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - return; - } + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "ymm0", "ymm1", "ymm2", "ymm3", + "ymm4", "ymm6", "ymm8", "ymm10", + "ymm12", "ymm14", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. + if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict ai = a + m_iter*ps_a; + float* restrict bj = b; + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_zen_asm_1x8, + bli_sgemmsup_rv_zen_asm_2x8, + bli_sgemmsup_rv_zen_asm_3x8, + bli_sgemmsup_rv_zen_asm_4x8, + bli_sgemmsup_rv_zen_asm_5x8 + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; + } } void bli_sgemmsup_rv_zen_asm_6x4m @@ -1499,420 +1664,420 @@ void bli_sgemmsup_rv_zen_asm_6x4m cntx_t* restrict cntx ) { - //void* a_next = bli_auxinfo_next_a( data ); - //void* b_next = bli_auxinfo_next_b( data ); - - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; - - uint64_t m_iter = m0 / 6; - uint64_t m_left = m0 % 6; - - uint64_t rs_a = rs_a0; - uint64_t cs_a = cs_a0; - uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - - // Query the panel stride of A and convert it to units of bytes. - uint64_t ps_a = bli_auxinfo_ps_a( data ); - uint64_t ps_a4 = ps_a * sizeof( float ); - - if ( m_iter == 0 ) goto consider_edge_cases; - - // ------------------------------------------------------------------------- - begin_asm() - - mov(var(a), r14) // load address of a. - mov(var(rs_a), r8) // load rs_a - mov(var(cs_a), r9) // load cs_a - lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) - lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) - - lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a - lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a - - mov(var(rs_b), r10) // load rs_b - lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) - - // NOTE: We cannot pre-load elements of a or b - // because it could eventually, in the last - // unrolled iter or the cleanup loop, result - // in reading beyond the bounds allocated mem - // (the likely result: a segmentation fault). 
- - mov(var(c), r12) // load address of c - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) - - // During preamble and loops: - // r12 = rcx = c // r14 = rax = a - // read rbx from var(b) near beginning of loop - // r11 = m dim index ii - - mov(var(m_iter), r11) // ii = m_iter; - - label(.SLOOP6X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] - - vxorps(xmm1, xmm1, xmm1) - vxorps(xmm4, xmm4, xmm4) - vxorps(xmm6, xmm6, xmm6) - vxorps(xmm8, xmm8, xmm8) - vxorps(xmm10, xmm10, xmm10) - vxorps(xmm12, xmm12, xmm12) - vxorps(xmm14, xmm14, xmm14) - - mov(var(b), rbx) // load address of b. - mov(r14, rax) // reset rax to current upanel of a. - - cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. - jz(.SCOLPFETCH) // jump to column storage case - label(.SROWPFETCH) // row-stored prefetching on c - - lea(mem(r12, rdi, 2), rdx) // - lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c - prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c - - jmp(.SPOSTPFETCH) // jump to end of prefetching c - label(.SCOLPFETCH) // column-stored prefetching c - - mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) - lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) - lea(mem(r12, rsi, 2), rdx) // - lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c - prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c - prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c - - label(.SPOSTPFETCH) // done prefetching c - - lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; - lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines - lea(mem(rdx, r8, 2), rdx) // from next upanel of a. - - mov(var(k_iter), rsi) // i = k_iter; - test(rsi, rsi) // check i via logical AND. - je(.SCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. 
- label(.SLOOPKITER) // MAIN LOOP - - // ---------------------------------- iteration 0 - prefetch(0, mem(rdx, 5*8)) - - vmovups(mem(rbx), xmm0) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), xmm2) - vbroadcastss(mem(rax, r8, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm4) - vfmadd231ps(xmm0, xmm3, xmm6) - - vbroadcastss(mem(rax, r8, 2), xmm2) - vbroadcastss(mem(rax, r13, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm8) - vfmadd231ps(xmm0, xmm3, xmm10) - - vbroadcastss(mem(rax, r8, 4), xmm2) - vbroadcastss(mem(rax, r15, 1), xmm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(xmm0, xmm2, xmm12) - vfmadd231ps(xmm0, xmm3, xmm14) - - // ---------------------------------- iteration 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) - - vmovups(mem(rbx), xmm0) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), xmm2) - vbroadcastss(mem(rax, r8, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm4) - vfmadd231ps(xmm0, xmm3, xmm6) - - vbroadcastss(mem(rax, r8, 2), xmm2) - vbroadcastss(mem(rax, r13, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm8) - vfmadd231ps(xmm0, xmm3, xmm10) - - vbroadcastss(mem(rax, r8, 4), xmm2) - vbroadcastss(mem(rax, r15, 1), xmm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(xmm0, xmm2, xmm12) - vfmadd231ps(xmm0, xmm3, xmm14) - - // ---------------------------------- iteration 2 - prefetch(0, mem(rdx, r9, 2, 5*8)) - - vmovups(mem(rbx), xmm0) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), xmm2) - vbroadcastss(mem(rax, r8, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm4) - vfmadd231ps(xmm0, xmm3, xmm6) - - vbroadcastss(mem(rax, r8, 2), xmm2) - vbroadcastss(mem(rax, r13, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm8) - vfmadd231ps(xmm0, xmm3, xmm10) - - vbroadcastss(mem(rax, r8, 4), xmm2) - vbroadcastss(mem(rax, r15, 1), xmm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(xmm0, xmm2, xmm12) - vfmadd231ps(xmm0, xmm3, xmm14) - - // ---------------------------------- iteration 3 - prefetch(0, mem(rdx, rcx, 1, 5*8)) - - vmovups(mem(rbx), xmm0) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), xmm2) - vbroadcastss(mem(rax, r8, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm4) - vfmadd231ps(xmm0, xmm3, xmm6) - - vbroadcastss(mem(rax, r8, 2), xmm2) - vbroadcastss(mem(rax, r13, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm8) - vfmadd231ps(xmm0, xmm3, xmm10) - - vbroadcastss(mem(rax, r8, 4), xmm2) - vbroadcastss(mem(rax, r15, 1), xmm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(xmm0, xmm2, xmm12) - vfmadd231ps(xmm0, xmm3, xmm14) - - dec(rsi) // i -= 1; - jne(.SLOOPKITER) // iterate again if i != 0. - - label(.SCONSIDKLEFT) - - mov(var(k_left), rsi) // i = k_left; - test(rsi, rsi) // check i via logical AND. - je(.SPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. - - label(.SLOOPKLEFT) // EDGE LOOP - - vmovups(mem(rbx), xmm0) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), xmm2) - vbroadcastss(mem(rax, r8, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm4) - vfmadd231ps(xmm0, xmm3, xmm6) - - vbroadcastss(mem(rax, r8, 2), xmm2) - vbroadcastss(mem(rax, r13, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm8) - vfmadd231ps(xmm0, xmm3, xmm10) - - vbroadcastss(mem(rax, r8, 4), xmm2) - vbroadcastss(mem(rax, r15, 1), xmm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(xmm0, xmm2, xmm12) - vfmadd231ps(xmm0, xmm3, xmm14) - - - dec(rsi) // i -= 1; - jne(.SLOOPKLEFT) // iterate again if i != 0. - - - label(.SPOSTACCUM) - - mov(r12, rcx) // reset rcx to current utile of c. 
- mov(var(alpha), rax) // load address of alpha - mov(var(beta), rbx) // load address of beta - vbroadcastss(mem(rax), xmm0) // load alpha and duplicate - vbroadcastss(mem(rbx), xmm3) // load beta and duplicate - - vmulps(xmm0, xmm4, xmm4) // scale by alpha - vmulps(xmm0, xmm6, xmm6) - vmulps(xmm0, xmm8, xmm8) - vmulps(xmm0, xmm10, xmm10) - vmulps(xmm0, xmm12, xmm12) - vmulps(xmm0, xmm14, xmm14) - - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) - - lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; - lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - // now avoid loading C if beta == 0 - - vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. - vucomiss(xmm0, xmm3) // set ZF if beta == 0. - je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. - jz(.SCOLSTORED) // jump to column storage case - - - label(.SROWSTORED) - - vfmadd231ps(mem(rcx), xmm3, xmm4) - vmovups(xmm4, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm6) - vmovups(xmm6, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm8) - vmovups(xmm8, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm10) - vmovups(xmm10, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm12) - vmovups(xmm12, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm14) - vmovups(xmm14, mem(rcx)) - - jmp(.SDONE) // jump to end. - - label(.SCOLSTORED) - - /****6x4 tile is transposed and saved in col major as 4x6*****/ - vunpcklps(xmm6, xmm4, xmm0) - vunpcklps(xmm10, xmm8, xmm1) - vshufps(imm(0x4e), xmm1, xmm0, xmm2) - vblendps(imm(0xcc), xmm2, xmm0, xmm0) - vblendps(imm(0x33), xmm2, xmm1, xmm1) - vfmadd231ps(mem(rcx), xmm3, xmm0) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vfmadd231ps(mem(rcx), xmm3, xmm1) - vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - - vunpckhps(xmm6, xmm4, xmm0) - vunpckhps(xmm10, xmm8, xmm1) - vshufps(imm(0x4e), xmm1, xmm0, xmm2) - vblendps(imm(0xcc), xmm2, xmm0, xmm0) - vblendps(imm(0x33), xmm2, xmm1, xmm1) - vfmadd231ps(mem(rcx), xmm3, xmm0) - vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vfmadd231ps(mem(rcx), xmm3, xmm1) - vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) - - vunpcklps(xmm14, xmm12, xmm0) - vpermilps(imm(0x4e), xmm0, xmm5) + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(a), r14) // load address of a. 
+ mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + // During preamble and loops: + // r12 = rcx = c // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorps(xmm1, xmm1, xmm1) + vxorps(xmm4, xmm4, xmm4) + vxorps(xmm6, xmm6, xmm6) + vxorps(xmm8, xmm8, xmm8) + vxorps(xmm10, xmm10, xmm10) + vxorps(xmm12, xmm12, xmm12) + vxorps(xmm14, xmm14, xmm14) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
+ label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rdx, 5*8)) + + vmovups(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovups(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovups(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + + vmovups(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovups(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. 
+ mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm0, xmm14, xmm14) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovups(xmm6, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm8) + vmovups(xmm8, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm10) + vmovups(xmm10, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm12) + vmovups(xmm12, mem(rcx)) + add(rdi, rcx) + vfmadd231ps(mem(rcx), xmm3, xmm14) + vmovups(xmm14, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + /****6x4 tile is transposed and saved in col major as 4x6*****/ + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vfmadd231ps(mem(rcx), xmm3, xmm1) + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + + vunpckhps(xmm6, xmm4, xmm0) + vunpckhps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vfmadd231ps(mem(rcx), xmm3, xmm1) + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + + vunpcklps(xmm14, xmm12, xmm0) + vpermilps(imm(0x4e), xmm0, xmm5) vmovq(mem(rdx),xmm4) - vfmadd231ps(xmm4, xmm3, xmm0) - vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) - - lea(mem(rdx, rsi, 1), rdx) - vmovq(mem(rdx),xmm4) - vfmadd231ps(xmm4, xmm3, xmm5) - vmovlpd(xmm5, mem(rdx)) // store ( gamma41..gamma51 ) - - lea(mem(rdx, rsi, 1), rdx) - vunpckhps(xmm14, xmm12, xmm0) - vpermilps(imm(0x4e), xmm0, xmm5) - vmovq(mem(rdx),xmm4) - vfmadd231ps(xmm4, xmm3, xmm0) - vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) - - lea(mem(rdx, rsi, 1), rdx) - vmovq(mem(rdx),xmm4) - vfmadd231ps(xmm4, xmm3, xmm5) - vmovlpd(xmm5, mem(rdx)) // store ( gamma43..gamma53 ) - - jmp(.SDONE) // jump to end. - - label(.SBETAZERO) - - cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. - jz(.SCOLSTORBZ) // jump to column storage case - - label(.SROWSTORBZ) - - vmovups(xmm4, mem(rcx)) - add(rdi, rcx) - vmovups(xmm6, mem(rcx)) - add(rdi, rcx) - vmovups(xmm8, mem(rcx)) - add(rdi, rcx) - vmovups(xmm10, mem(rcx)) - add(rdi, rcx) - vmovups(xmm12, mem(rcx)) - add(rdi, rcx) - vmovups(xmm14, mem(rcx)) - - jmp(.SDONE) // jump to end. 
- - label(.SCOLSTORBZ) - - vunpcklps(xmm6, xmm4, xmm0) - vunpcklps(xmm10, xmm8, xmm1) - vshufps(imm(0x4e), xmm1, xmm0, xmm2) - vblendps(imm(0xcc), xmm2, xmm0, xmm0) - vblendps(imm(0x33), xmm2, xmm1, xmm1) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vunpckhps(xmm6, xmm4, xmm0) - vunpckhps(xmm10, xmm8, xmm1) - vshufps(imm(0x4e), xmm1, xmm0, xmm2) - vblendps(imm(0xcc), xmm2, xmm0, xmm0) - vblendps(imm(0x33), xmm2, xmm1, xmm1) - vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) - - vunpcklps(xmm14, xmm12, xmm0) - vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) - lea(mem(rdx, rsi, 1), rdx) - vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) - lea(mem(rdx, rsi, 1), rdx) - vunpckhps(xmm14, xmm12, xmm0) - vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) - lea(mem(rdx, rsi, 1), rdx) - vmovhpd(xmm0, mem(rdx)) // store ( gamma43..gamma53 ) - - label(.SDONE) - - lea(mem(r12, rdi, 4), r12) // - lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c - - //lea(mem(r14, r8, 4), r14) // - //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a - mov(var(ps_a4), rax) // load ps_a4 - lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 - - dec(r11) // ii -= 1; - jne(.SLOOP6X4I) // iterate again if ii != 0. - - label(.SRETURN) + vfmadd231ps(xmm4, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + + lea(mem(rdx, rsi, 1), rdx) + vmovq(mem(rdx),xmm4) + vfmadd231ps(xmm4, xmm3, xmm5) + vmovlpd(xmm5, mem(rdx)) // store ( gamma41..gamma51 ) + + lea(mem(rdx, rsi, 1), rdx) + vunpckhps(xmm14, xmm12, xmm0) + vpermilps(imm(0x4e), xmm0, xmm5) + vmovq(mem(rdx),xmm4) + vfmadd231ps(xmm4, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + + lea(mem(rdx, rsi, 1), rdx) + vmovq(mem(rdx),xmm4) + vfmadd231ps(xmm4, xmm3, xmm5) + vmovlpd(xmm5, mem(rdx)) // store ( gamma43..gamma53 ) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(xmm4, mem(rcx)) + add(rdi, rcx) + vmovups(xmm6, mem(rcx)) + add(rdi, rcx) + vmovups(xmm8, mem(rcx)) + add(rdi, rcx) + vmovups(xmm10, mem(rcx)) + add(rdi, rcx) + vmovups(xmm12, mem(rcx)) + add(rdi, rcx) + vmovups(xmm14, mem(rcx)) + + jmp(.SDONE) // jump to end. 
+ + label(.SCOLSTORBZ) + + vunpcklps(xmm6, xmm4, xmm0) + vunpcklps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vmovups(xmm1, mem(rcx)) // store ( gamma01..gamma31 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vunpckhps(xmm6, xmm4, xmm0) + vunpckhps(xmm10, xmm8, xmm1) + vshufps(imm(0x4e), xmm1, xmm0, xmm2) + vblendps(imm(0xcc), xmm2, xmm0, xmm0) + vblendps(imm(0x33), xmm2, xmm1, xmm1) + vmovups(xmm0, mem(rcx)) // store ( gamma02..gamma32 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vmovups(xmm1, mem(rcx)) // store ( gamma03..gamma33 ) + + vunpcklps(xmm14, xmm12, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) + lea(mem(rdx, rsi, 1), rdx) + vunpckhps(xmm14, xmm12, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma43..gamma53 ) + + label(.SDONE) + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X4I) // iterate again if ii != 0. + + label(.SRETURN) end_asm( - : // output operands (none) - : // input operands + : // output operands (none) + : // input operands [m_iter] "m" (m_iter), [k_iter] "m" (k_iter), [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), - [ps_a4] "m" (ps_a4), + [ps_a4] "m" (ps_a4), [b] "m" (b), [rs_b] "m" (rs_b), [cs_b] "m" (cs_b), @@ -1923,48 +2088,48 @@ void bli_sgemmsup_rv_zen_asm_6x4m [cs_c] "m" (cs_c)/*, [a_next] "m" (a_next), [b_next] "m" (b_next)*/ - : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", - "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", - "xmm0", "xmm1", "xmm2", "xmm3", - "xmm4", "xmm5", "xmm6", "xmm7", - "xmm8", "xmm9", "xmm10", "xmm11", - "xmm12", "xmm13", "xmm14", "xmm15", - "memory" - ) - - consider_edge_cases: - - // Handle edge cases in the m dimension, if they exist. - if ( m_left ) - { - const dim_t nr_cur = 4; - const dim_t i_edge = m0 - ( dim_t )m_left; - - float* restrict cij = c + i_edge*rs_c; - float* restrict ai = a + m_iter*ps_a; - float* restrict bj = b; - - sgemmsup_ker_ft ker_fps[6] = - { - NULL, - bli_sgemmsup_rv_zen_asm_1x4, - bli_sgemmsup_rv_zen_asm_2x4, - bli_sgemmsup_rv_zen_asm_3x4, - bli_sgemmsup_rv_zen_asm_4x4, - bli_sgemmsup_rv_zen_asm_5x4 - }; - - sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; - - ker_fp - ( - conja, conjb, m_left, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - return; - } + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
+ if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict ai = a + m_iter*ps_a; + float* restrict bj = b; + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_zen_asm_1x4, + bli_sgemmsup_rv_zen_asm_2x4, + bli_sgemmsup_rv_zen_asm_3x4, + bli_sgemmsup_rv_zen_asm_4x4, + bli_sgemmsup_rv_zen_asm_5x4 + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; + } } void bli_sgemmsup_rv_zen_asm_6x2m @@ -1983,386 +2148,386 @@ void bli_sgemmsup_rv_zen_asm_6x2m cntx_t* restrict cntx ) { - //void* a_next = bli_auxinfo_next_a( data ); - //void* b_next = bli_auxinfo_next_b( data ); - - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; - - uint64_t m_iter = m0 / 6; - uint64_t m_left = m0 % 6; - - uint64_t rs_a = rs_a0; - uint64_t cs_a = cs_a0; - uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - - // Query the panel stride of A and convert it to units of bytes. - uint64_t ps_a = bli_auxinfo_ps_a( data ); - uint64_t ps_a4 = ps_a * sizeof( float ); - - if ( m_iter == 0 ) goto consider_edge_cases; - - // ------------------------------------------------------------------------- - begin_asm() - - mov(var(a), r14) // load address of a. - mov(var(rs_a), r8) // load rs_a - mov(var(cs_a), r9) // load cs_a - lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) - lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) - - lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a - lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a - - mov(var(rs_b), r10) // load rs_b - lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) - - // NOTE: We cannot pre-load elements of a or b - // because it could eventually, in the last - // unrolled iter or the cleanup loop, result - // in reading beyond the bounds allocated mem - // (the likely result: a segmentation fault). - - mov(var(c), r12) // load address of c - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) - - // During preamble and loops: - // r12 = rcx = c // r14 = rax = a - // read rbx from var(b) near beginning of loop - // r11 = m dim index ii - - mov(var(m_iter), r11) // ii = m_iter; - - label(.SLOOP6X2I) // LOOP OVER ii = [ m_iter ... 1 0 ] - - vxorps(xmm1, xmm1, xmm1) - vxorps(xmm4, xmm4, xmm4) - vxorps(xmm6, xmm6, xmm6) - vxorps(xmm8, xmm8, xmm8) - vxorps(xmm10, xmm10, xmm10) - vxorps(xmm12, xmm12, xmm12) - vxorps(xmm14, xmm14, xmm14) - - mov(var(b), rbx) // load address of b. - mov(r14, rax) // reset rax to current upanel of a. - - cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. 
- jz(.SCOLPFETCH) // jump to column storage case - label(.SROWPFETCH) // row-stored prefetching on c - - lea(mem(r12, rdi, 2), rdx) // - lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c - prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c - - jmp(.SPOSTPFETCH) // jump to end of prefetching c - label(.SCOLPFETCH) // column-stored prefetching c - - mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) - lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) - lea(mem(r12, rsi, 2), rdx) // - lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c - prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c - prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c - - label(.SPOSTPFETCH) // done prefetching c - - lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; - lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines - lea(mem(rdx, r8, 2), rdx) // from next upanel of a. - - mov(var(k_iter), rsi) // i = k_iter; - test(rsi, rsi) // check i via logical AND. - je(.SCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. - label(.SLOOPKITER) // MAIN LOOP - - // ---------------------------------- iteration 0 - prefetch(0, mem(rdx, 5*8)) - vmovq(mem(rbx), xmm0) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), xmm2) - vbroadcastss(mem(rax, r8, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm4) - vfmadd231ps(xmm0, xmm3, xmm6) - - vbroadcastss(mem(rax, r8, 2), xmm2) - vbroadcastss(mem(rax, r13, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm8) - vfmadd231ps(xmm0, xmm3, xmm10) - - vbroadcastss(mem(rax, r8, 4), xmm2) - vbroadcastss(mem(rax, r15, 1), xmm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(xmm0, xmm2, xmm12) - vfmadd231ps(xmm0, xmm3, xmm14) - - // ---------------------------------- iteration 1 - prefetch(0, mem(rdx, r9, 1, 5*8)) - - vmovq(mem(rbx), xmm0) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), xmm2) - vbroadcastss(mem(rax, r8, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm4) - vfmadd231ps(xmm0, xmm3, xmm6) - - vbroadcastss(mem(rax, r8, 2), xmm2) - vbroadcastss(mem(rax, r13, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm8) - vfmadd231ps(xmm0, xmm3, xmm10) - - vbroadcastss(mem(rax, r8, 4), xmm2) - vbroadcastss(mem(rax, r15, 1), xmm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(xmm0, xmm2, xmm12) - vfmadd231ps(xmm0, xmm3, xmm14) - - // ---------------------------------- iteration 2 - prefetch(0, mem(rdx, r9, 2, 5*8)) - - vmovq(mem(rbx), xmm0) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), xmm2) - vbroadcastss(mem(rax, r8, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm4) - vfmadd231ps(xmm0, xmm3, xmm6) - - vbroadcastss(mem(rax, r8, 2), xmm2) - vbroadcastss(mem(rax, r13, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm8) - vfmadd231ps(xmm0, xmm3, xmm10) - - vbroadcastss(mem(rax, r8, 4), xmm2) - vbroadcastss(mem(rax, r15, 1), xmm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(xmm0, xmm2, xmm12) - vfmadd231ps(xmm0, xmm3, xmm14) - - // ---------------------------------- iteration 3 - prefetch(0, mem(rdx, rcx, 1, 5*8)) - - vmovq(mem(rbx), xmm0) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), xmm2) - 
vbroadcastss(mem(rax, r8, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm4) - vfmadd231ps(xmm0, xmm3, xmm6) - - vbroadcastss(mem(rax, r8, 2), xmm2) - vbroadcastss(mem(rax, r13, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm8) - vfmadd231ps(xmm0, xmm3, xmm10) - - vbroadcastss(mem(rax, r8, 4), xmm2) - vbroadcastss(mem(rax, r15, 1), xmm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(xmm0, xmm2, xmm12) - vfmadd231ps(xmm0, xmm3, xmm14) - - dec(rsi) // i -= 1; - jne(.SLOOPKITER) // iterate again if i != 0. - - label(.SCONSIDKLEFT) - - mov(var(k_left), rsi) // i = k_left; - test(rsi, rsi) // check i via logical AND. - je(.SPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. - - label(.SLOOPKLEFT) // EDGE LOOP - - vmovq(mem(rbx), xmm0) - add(r10, rbx) // b += rs_b; - - vbroadcastss(mem(rax ), xmm2) - vbroadcastss(mem(rax, r8, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm4) - vfmadd231ps(xmm0, xmm3, xmm6) - - vbroadcastss(mem(rax, r8, 2), xmm2) - vbroadcastss(mem(rax, r13, 1), xmm3) - vfmadd231ps(xmm0, xmm2, xmm8) - vfmadd231ps(xmm0, xmm3, xmm10) - - vbroadcastss(mem(rax, r8, 4), xmm2) - vbroadcastss(mem(rax, r15, 1), xmm3) - add(r9, rax) // a += cs_a; - vfmadd231ps(xmm0, xmm2, xmm12) - vfmadd231ps(xmm0, xmm3, xmm14) - - - dec(rsi) // i -= 1; - jne(.SLOOPKLEFT) // iterate again if i != 0. - - - label(.SPOSTACCUM) - - mov(r12, rcx) // reset rcx to current utile of c. - mov(var(alpha), rax) // load address of alpha - mov(var(beta), rbx) // load address of beta - vbroadcastss(mem(rax), xmm0) // load alpha and duplicate - vbroadcastss(mem(rbx), xmm3) // load beta and duplicate - - vmulps(xmm0, xmm4, xmm4) // scale by alpha - vmulps(xmm0, xmm6, xmm6) - vmulps(xmm0, xmm8, xmm8) - vmulps(xmm0, xmm10, xmm10) - vmulps(xmm0, xmm12, xmm12) - vmulps(xmm0, xmm14, xmm14) - - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) - - lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; - lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - // now avoid loading C if beta == 0 - - vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. - vucomiss(xmm0, xmm3) // set ZF if beta == 0. - je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. - jz(.SCOLSTORED) // jump to column storage case - - - label(.SROWSTORED) - - vmovsd(mem(rcx), xmm0) - vfmadd231ps(xmm0, xmm3, xmm4) - vmovlpd(xmm4, mem(rcx)) - add(rdi, rcx) - vmovsd(mem(rcx), xmm0) - vfmadd231ps(xmm0, xmm3, xmm6) - vmovlpd(xmm6, mem(rcx)) - add(rdi, rcx) - vmovsd(mem(rcx), xmm0) - vfmadd231ps(xmm0, xmm3, xmm8) - vmovlpd(xmm8, mem(rcx)) - add(rdi, rcx) - vmovsd(mem(rcx), xmm0) - vfmadd231ps(xmm0, xmm3, xmm10) - vmovlpd(xmm10, mem(rcx)) - add(rdi, rcx) - vmovsd(mem(rcx), xmm0) - vfmadd231ps(xmm0, xmm3, xmm12) - vmovlpd(xmm12, mem(rcx)) - add(rdi, rcx) - vmovsd(mem(rcx), xmm0) - vfmadd231ps(xmm0, xmm3, xmm14) - vmovlpd(xmm14, mem(rcx)) - - jmp(.SDONE) // jump to end. 
- - label(.SCOLSTORED) - - /****6x2 tile is transposed and saved in col major as 2x6*****/ - vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 - vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 - vshufps(imm(0x44), xmm1, xmm0, xmm2) //01-00-01-00 - vshufps(imm(0xee), xmm1, xmm0, xmm4) //11-10-11-10 - - vfmadd231ps(mem(rcx), xmm3, xmm2) - vmovupd(xmm2, mem(rcx)) // store ( gamma00..gamma30 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vfmadd231ps(mem(rcx), xmm3, xmm4) - vmovupd(xmm4, mem(rcx)) // store ( gamma01..gamma31 ) - - vunpcklps(xmm14, xmm12, xmm0)//eof0e1f1 - vpermilps(imm(0x4e),xmm0,xmm5) - vmovq(mem(rdx), xmm4) - vfmadd231ps(xmm4, xmm3, xmm0) - vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) - lea(mem(rdx, rsi, 1), rdx) - vmovq(mem(rdx), xmm4) - vfmadd231ps(xmm4, xmm3, xmm5) - vmovlpd(xmm5, mem(rdx)) // store ( gamma41..gamma51 ) - - jmp(.SDONE) // jump to end. - - label(.SBETAZERO) - - cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. - jz(.SCOLSTORBZ) // jump to column storage case - - label(.SROWSTORBZ) - - vmovlpd(xmm4, mem(rcx)) - add(rdi, rcx) - vmovlpd(xmm6, mem(rcx)) - add(rdi, rcx) - vmovlpd(xmm8, mem(rcx)) - add(rdi, rcx) - vmovlpd(xmm10, mem(rcx)) - add(rdi, rcx) - vmovlpd(xmm12, mem(rcx)) - add(rdi, rcx) - vmovlpd(xmm14, mem(rcx)) - - jmp(.SDONE) // jump to end. - - label(.SCOLSTORBZ) - - vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 - vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 - vshufps(imm(0x44), xmm1, xmm0, xmm2) //01-00-01-00 - vshufps(imm(0xee), xmm1, xmm0, xmm4) //11-10-11-10 - - vmovupd(xmm2, mem(rcx)) // store ( gamma00..gamma30 ) - lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c - vmovupd(xmm4, mem(rcx)) // store ( gamma01..gamma31 ) - - vunpcklps(xmm14, xmm12, xmm0)//eof0e1f1 - vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) - lea(mem(rdx, rsi, 1), rdx) - vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) - - label(.SDONE) - - lea(mem(r12, rdi, 4), r12) // - lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c - - //lea(mem(r14, r8, 4), r14) // - //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a - mov(var(ps_a4), rax) // load ps_a4 - lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 - - dec(r11) // ii -= 1; - jne(.SLOOP6X2I) // iterate again if ii != 0. - - label(.SRETURN) + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(a), r14) // load address of a. 
+ mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + // During preamble and loops: + // r12 = rcx = c // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X2I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + vxorps(xmm1, xmm1, xmm1) + vxorps(xmm4, xmm4, xmm4) + vxorps(xmm6, xmm6, xmm6) + vxorps(xmm8, xmm8, xmm8) + vxorps(xmm10, xmm10, xmm10) + vxorps(xmm12, xmm12, xmm12) + vxorps(xmm14, xmm14, xmm14) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
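/*
   For reference, a minimal scalar sketch of how the k dimension is walked by
   the unrolled main loop and the edge loop that follow (illustrative only;
   rank1_update_6x2 is a hypothetical helper, not a BLIS symbol):

     uint64_t k;
     for ( k = 0; k < k_iter; ++k )     // each trip covers 4 values of k
     {
         rank1_update_6x2( a, b ); a += cs_a; b += rs_b;
         rank1_update_6x2( a, b ); a += cs_a; b += rs_b;
         rank1_update_6x2( a, b ); a += cs_a; b += rs_b;
         rank1_update_6x2( a, b ); a += cs_a; b += rs_b;
     }
     for ( k = 0; k < k_left; ++k )     // 0 <= k_left <= 3 leftover updates
     {
         rank1_update_6x2( a, b ); a += cs_a; b += rs_b;
     }
*/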
+ label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rdx, 5*8)) + vmovq(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovq(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovq(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + + vmovq(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmovq(mem(rbx), xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. 
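/*
   The post-accumulation code below applies the usual gemm update to each
   element of the 6x2 accumulator tile:

     c(i,j) = beta * c(i,j) + alpha * acc(i,j)    (general case)
     c(i,j) =                 alpha * acc(i,j)    (beta == 0 fast path)

   The vucomiss test against zero selects the second form so that C is not
   read at all when beta == 0, consistent with the convention that C need not
   hold valid values in that case.
*/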
+ mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm0, xmm14, xmm14) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + + label(.SROWSTORED) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmovlpd(xmm4, mem(rcx)) + add(rdi, rcx) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) + vmovlpd(xmm6, mem(rcx)) + add(rdi, rcx) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) + vmovlpd(xmm8, mem(rcx)) + add(rdi, rcx) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) + vmovlpd(xmm10, mem(rcx)) + add(rdi, rcx) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) + vmovlpd(xmm12, mem(rcx)) + add(rdi, rcx) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm14) + vmovlpd(xmm14, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + /****6x2 tile is transposed and saved in col major as 2x6*****/ + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vshufps(imm(0x44), xmm1, xmm0, xmm2) //01-00-01-00 + vshufps(imm(0xee), xmm1, xmm0, xmm4) //11-10-11-10 + + vfmadd231ps(mem(rcx), xmm3, xmm2) + vmovupd(xmm2, mem(rcx)) // store ( gamma00..gamma30 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovupd(xmm4, mem(rcx)) // store ( gamma01..gamma31 ) + + vunpcklps(xmm14, xmm12, xmm0)//eof0e1f1 + vpermilps(imm(0x4e),xmm0,xmm5) + vmovq(mem(rdx), xmm4) + vfmadd231ps(xmm4, xmm3, xmm0) + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + lea(mem(rdx, rsi, 1), rdx) + vmovq(mem(rdx), xmm4) + vfmadd231ps(xmm4, xmm3, xmm5) + vmovlpd(xmm5, mem(rdx)) // store ( gamma41..gamma51 ) + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovlpd(xmm4, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm6, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm8, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm10, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm12, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm14, mem(rcx)) + + jmp(.SDONE) // jump to end. 
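/*
   Both column-stored paths (.SCOLSTORED above and .SCOLSTORBZ below) first
   transpose the 6x2 accumulator tile into two columns of C. A minimal
   intrinsics sketch of the unpack/shuffle pattern for rows 0..3 (illustrative
   only; the kernel itself uses the equivalent vunpcklps/vshufps instructions):

     #include <immintrin.h>

     // r0..r3 each hold the two accumulated floats of one row in their low lanes.
     static inline void transpose_rows_0_to_3
          ( __m128 r0, __m128 r1, __m128 r2, __m128 r3, __m128* col0, __m128* col1 )
     {
         __m128 t0 = _mm_unpacklo_ps( r0, r1 );        // a0 b0 a1 b1
         __m128 t1 = _mm_unpacklo_ps( r2, r3 );        // c0 d0 c1 d1
         *col0 = _mm_shuffle_ps( t0, t1, 0x44 );       // a0 b0 c0 d0 -> column 0
         *col1 = _mm_shuffle_ps( t0, t1, 0xee );       // a1 b1 c1 d1 -> column 1
     }

   Rows 4..5 are handled separately with vunpcklps/vpermilps and 64-bit
   loads/stores, as shown in the surrounding code.
*/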
+ + label(.SCOLSTORBZ) + + vunpcklps(xmm6, xmm4, xmm0)//a0b0a1b1 + vunpcklps(xmm10, xmm8, xmm1)//c0d0c1d1 + vshufps(imm(0x44), xmm1, xmm0, xmm2) //01-00-01-00 + vshufps(imm(0xee), xmm1, xmm0, xmm4) //11-10-11-10 + + vmovupd(xmm2, mem(rcx)) // store ( gamma00..gamma30 ) + lea(mem(rcx, rsi, 1), rcx) // rcx += 1*cs_c + vmovupd(xmm4, mem(rcx)) // store ( gamma01..gamma31 ) + + vunpcklps(xmm14, xmm12, xmm0)//eof0e1f1 + vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) + lea(mem(rdx, rsi, 1), rdx) + vmovhpd(xmm0, mem(rdx)) // store ( gamma41..gamma51 ) + + label(.SDONE) + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + //lea(mem(r14, r8, 4), r14) // + //lea(mem(r14, r8, 2), r14) // a_ii = r14 += 6*rs_a + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X2I) // iterate again if ii != 0. + + label(.SRETURN) end_asm( - : // output operands (none) - : // input operands + : // output operands (none) + : // input operands [m_iter] "m" (m_iter), [k_iter] "m" (k_iter), [k_left] "m" (k_left), [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), - [ps_a4] "m" (ps_a4), + [ps_a4] "m" (ps_a4), [b] "m" (b), [rs_b] "m" (rs_b), [cs_b] "m" (cs_b), @@ -2373,46 +2538,1533 @@ void bli_sgemmsup_rv_zen_asm_6x2m [cs_c] "m" (cs_c)/*, [a_next] "m" (a_next), [b_next] "m" (b_next)*/ - : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", - "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", - "xmm0", "xmm1", "xmm2", "xmm3", - "xmm4", "xmm5", "xmm6", "xmm7", - "xmm8", "xmm9", "xmm10", "xmm11", - "xmm12", "xmm13", "xmm14", "xmm15", - "memory" - ) - - consider_edge_cases: - - // Handle edge cases in the m dimension, if they exist. - if ( m_left ) - { - const dim_t nr_cur = 2; - const dim_t i_edge = m0 - ( dim_t )m_left; - - float* restrict cij = c + i_edge*rs_c; - float* restrict ai = a + m_iter*ps_a; - float* restrict bj = b; - - sgemmsup_ker_ft ker_fps[6] = - { - NULL, - bli_sgemmsup_rv_zen_asm_1x2, - bli_sgemmsup_rv_zen_asm_2x2, - bli_sgemmsup_rv_zen_asm_3x2, - bli_sgemmsup_rv_zen_asm_4x2, - bli_sgemmsup_rv_zen_asm_5x2 - }; - - sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; - - ker_fp - ( - conja, conjb, m_left, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - return; - } + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
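/*
   Worked example of the dispatch below: for m0 = 14, m_iter = 14/6 = 2 full
   6-row panels are handled by the assembly loop above, leaving m_left = 2
   rows. Then i_edge = 12, so cij is advanced past the 12 completed rows of C,
   ai is advanced past two packed A panels (2*ps_a), and the 2-row tail is
   handed to bli_sgemmsup_rv_zen_asm_2x2 with nr_cur = 2 columns.
*/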
+ if ( m_left ) + { + const dim_t nr_cur = 2; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict ai = a + m_iter*ps_a; + float* restrict bj = b; + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_zen_asm_1x2, + bli_sgemmsup_rv_zen_asm_2x2, + bli_sgemmsup_rv_zen_asm_3x2, + bli_sgemmsup_rv_zen_asm_4x2, + bli_sgemmsup_rv_zen_asm_5x2 + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; + } +} + +/* Mask elements to specify how many elements to be loaded from C buffer */ +static const int32_t mask[8][8] = { {0, 0, 0, 0, 0, 0, 0, 0}, //load no values, not used currently + {-1, 0, 0, 0, 0, 0, 0, 0}, // load 1 value from memory + {-1, -1, 0, 0, 0, 0, 0, 0}, // load 2 values from memory + {-1, -1, -1, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, 0}, + }; + +void bli_sgemmsup_rv_zen_asm_6x16m_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // This kernel is called when n_left is greater than 8. This kernel operates 16 columns at time. + // First 8 elements can be loaded directly and next elements will be loaded based the mask reg + // + // Sets up the mask for loading relevant remainder elements in load direction + // + // ______ymm0______ __________ymm1_________ + // | | | | | | | | | | | | | | | | | | + // |0|1|2|3|4|5|6|7| |8|9|10|11|12|13|14|15| ----> Source vector + // |_|_|_|_|_|_|_|_| |_|_|__|__|__|__|__|__| + // + // ________________ ______ymm3_______ + // | | | | | | | | | | | | | | | | | | + // |NoMASK Required| |x|x|x|x|x|x|x|x| ----> Mask vector[x can be -1/0] + // |_|_|_|_|_|_|_|_| |_|_|_|_|_|_|_|_| + // + // For example when n_left = 13 + // ________________ ________ymm3__________ + // | | | | | | | | | | | | | | | | | | + // |NoMASK Required| |-1|-1|-1|-1|-1|0|0|0| ----> Mask vector + // |_|_|_|_|_|_|_|_| |__|__|__|__|__|_|_|_| + // + // ______ymm0_______ ________ymm1__________ + // | | | | | | | | | | | | | | | | | | + // |0|1|2|3|4|5|6|7| |8|9|10|11|12|0 |0 |0 | ----> Destination vector + // |_|_|_|_|_|_|_|_| |_|_|__|__|__|__|__|__| + // + + uint64_t n_mod8 = n0 % 8 ; + const int32_t *mask_vec = mask[n_mod8]; + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm3) //load mask values + + mov(var(a), r14) // load address of a. 
+ mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X15I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm5, ymm5, ymm5) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm7, ymm7, ymm7) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm9, ymm9, ymm9) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm11, ymm11, ymm11) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm13, ymm13, ymm13) + vxorps(ymm14, ymm14, ymm14) + vxorps(ymm15, ymm15, ymm15) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 8*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 8*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 8*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 8*4)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 8*4)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 8*4)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 8*cs_c + lea(mem(r12, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 12*cs_c + lea(mem(r12, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 14*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + lea(mem(r9, 
r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rdx, 5*8)) + + vmovups(mem(rbx, 0*32), ymm0) //load first 8 elements + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) //load next required elements + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + vfmadd231ps(ymm1, ymm2, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vbroadcastss(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm14) + vfmadd231ps(ymm1, ymm2, ymm15) + + // ---------------------------------- iteration 1 + + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + vfmadd231ps(ymm1, ymm2, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vbroadcastss(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm14) + vfmadd231ps(ymm1, ymm2, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + vfmadd231ps(ymm1, ymm2, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vbroadcastss(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm14) + vfmadd231ps(ymm1, ymm2, ymm15) + + // ---------------------------------- iteration 3 + + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + vfmadd231ps(ymm1, ymm2, ymm11) + + 
vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vbroadcastss(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm14) + vfmadd231ps(ymm1, ymm2, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // ee, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + vfmadd231ps(ymm1, ymm2, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + vbroadcastss(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm14) + vfmadd231ps(ymm1, ymm2, ymm15) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm1) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm13, ymm13) + vmulps(ymm0, ymm14, ymm14) + vmulps(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm1) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. 
+ jz(.SCOTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx, 0*32), ymm1, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm5) + vmaskmovps(ymm5, ymm3, mem(rcx, 1*32)) //store only required elements + + add(rdi, rcx) + + vfmadd231ps(mem(rcx, 0*32), ymm1, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm7) + vmaskmovps(ymm7, ymm3, mem(rcx, 1*32)) + + add(rdi, rcx) + + vfmadd231ps(mem(rcx, 0*32), ymm1, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm9) + vmaskmovps(ymm9, ymm3, mem(rcx, 1*32)) + + add(rdi, rcx) + + vfmadd231ps(mem(rcx, 0*32), ymm1, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm11) + vmaskmovps(ymm11, ymm3, mem(rcx, 1*32)) + + add(rdi, rcx) + + vfmadd231ps(mem(rcx, 0*32), ymm1, ymm12) + vmovups(ymm12, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm13) + vmaskmovps(ymm13, ymm3, mem(rcx, 1*32)) + + add(rdi, rcx) + + vfmadd231ps(mem(rcx, 0*32), ymm1, ymm14) + vmovups(ymm14, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm15) + vmaskmovps(ymm15, ymm3, mem(rcx, 1*32)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORED) + + /* TODO: Add column storage support */ + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx, 0*32)) + vmaskmovps(ymm5, ymm3, mem(rcx, 1*32)) //Store only required elements + add(rdi, rcx) + + vmovups(ymm6, mem(rcx, 0*32)) + vmaskmovps(ymm7, ymm3, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx, 0*32)) + vmaskmovps(ymm9, ymm3, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovups(ymm10, mem(rcx, 0*32)) + vmaskmovps(ymm11, ymm3, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovups(ymm12, mem(rcx, 0*32)) + vmaskmovps(ymm13, ymm3, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovups(ymm14, mem(rcx, 0*32)) + vmaskmovps(ymm15, ymm3, mem(rcx, 1*32)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORBZ) + + /* TODO: Add column storage support*/ + + label(.SDONE) + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X15I) // iterate again if ii != 0. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", + "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", + "ymm13", "ymm14", "ymm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
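/*
   Unlike the 6x2m dispatch earlier, which passes a fixed nr_cur = 2, the
   m-edge dispatch below forwards nr_cur = n0 (the remaining column count,
   greater than 8 per the comment above), presumably so that the *_x16_mask
   edge kernels can apply the same column masking that this kernel derived
   from n0 % 8.
*/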
+ if ( m_left ) + { + const dim_t nr_cur = n0; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict ai = a + m_iter * ps_a; + float* restrict bj = b; + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_zen_asm_1x16_mask, + bli_sgemmsup_rv_zen_asm_2x16_mask, + bli_sgemmsup_rv_zen_asm_3x16_mask, + bli_sgemmsup_rv_zen_asm_4x16_mask, + bli_sgemmsup_rv_zen_asm_5x16_mask + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +} + +void bli_sgemmsup_rv_zen_asm_6x8m_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // This kernel is called when n_left 7, 6, 5. This kernel operates 8 columns at time. + // + // Sets up the mask for loading relevant remainder elements in load direction + // + // ______ymm0_______ + // | | | | | | | | | + // |0|1|2|3|4|5|6|7| ----> Source vector + // |_|_|_|_|_|_|_|_| + // + //______ymm3_______ + //| | | | | | | | | + //|x|x|x|x|x|x|x|x| ----> Mask vector[x can be -1/0] + //|_|_|_|_|_|_|_|_| + // + // For example when n_left = 6 + // ________ymm3__________ + // | | | | | | | | | + // |-1|-1|-1|-1|-1|-1|0|0| ----> Mask vector + // |__|__|__|__|__|__|_|_| + // + // _______ymm0______ + // | | | | | | | | | + // |0|1|2|3|4|5|0|0| ----> Destination vector + // |_|_|_|_|_|_|_|_| + // + uint64_t n_mod8 = n0 % 8 ; + const int32_t *mask_vec = mask[n_mod8]; + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm3) //load mask values + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X7I) // LOOP OVER ii = [ m_iter ... 
1 0 ] + + // skylake can execute 3 vxorps ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorps(ymm4, ymm4, ymm4) + vxorps(ymm6, ymm6, ymm6) + vxorps(ymm8, ymm8, ymm8) + vxorps(ymm10, ymm10, ymm10) + vxorps(ymm12, ymm12, ymm12) + vxorps(ymm14, ymm14, ymm14) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 4*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 4*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 4*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 4*4)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 4*4)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 4*4)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 6*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
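/*
   Each k iteration below loads only the n0 % 8 valid columns of the current B
   row with vmaskmovps, driven by the -1/0 mask table defined earlier. A
   minimal intrinsics sketch of that masked load (illustrative only;
   load_b_row is a hypothetical helper, not a BLIS symbol):

     #include <immintrin.h>
     #include <stdint.h>

     static inline __m256 load_b_row( const float* b, const int32_t* mask_row )
     {
         // Lanes whose mask element has its sign bit set (-1) are read from
         // memory; the remaining lanes are zeroed and never touch memory,
         // which is what makes reading a partial row safe.
         __m256i m = _mm256_loadu_si256( (const __m256i*)mask_row );
         return _mm256_maskload_ps( b, m );
     }
*/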
+ + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rdx, 5*8)) + + vmaskmovps(mem(rbx, 0), ymm3, ymm0) //load required elements + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vbroadcastss(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm14) + + // ---------------------------------- iteration 1 + + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmaskmovps(mem(rbx, 0), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vbroadcastss(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm14) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmaskmovps(mem(rbx, 0), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vbroadcastss(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm14) + + // ---------------------------------- iteration 3 + + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmaskmovps(mem(rbx, 0), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vbroadcastss(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // ee, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmaskmovps(mem(rbx, 0), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vbroadcastss(mem(rax, r15, 1), ymm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(ymm0, ymm2, ymm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. 
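/*
   Scalar view of what one k step in the loops above computes for this
   6 x (n0 % 8) microtile: one element of the A panel is broadcast and
   multiplied against the masked B row, with one accumulator register per
   C row:

     for ( dim_t i = 0; i < 6; ++i )
         for ( dim_t j = 0; j < n0 % 8; ++j )
             acc[ i ][ j ] += a[ i*rs_a + k*cs_a ] * b[ k*rs_b + j*cs_b ];
*/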
+ + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm1) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm14, ymm14) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm1) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORED) // jump to column storage case + + label(.SROWSTORED) + + vmaskmovps(mem(rcx, 0), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm4) + vmaskmovps(ymm4, ymm3, mem(rcx, 0)) + + add(rdi, rcx) + + vmaskmovps(mem(rcx, 0), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm6) + vmaskmovps(ymm6, ymm3, mem(rcx, 0)) + + add(rdi, rcx) + + vmaskmovps(mem(rcx, 0), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm8) + vmaskmovps(ymm8, ymm3, mem(rcx, 0)) + + add(rdi, rcx) + + vmaskmovps(mem(rcx, 0), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm10) + vmaskmovps(ymm10, ymm3, mem(rcx, 0)) + + add(rdi, rcx) + + vmaskmovps(mem(rcx, 0), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm12) + vmaskmovps(ymm12, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(mem(rcx, 0), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm14) + vmaskmovps(ymm14, ymm3, mem(rcx, 0)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORED) + + /* TODO: Add column storage support*/ + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmaskmovps(ymm4, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(ymm6, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(ymm8, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(ymm10, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(ymm12, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(ymm14, ymm3, mem(rcx, 0)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORBZ) + + /* TODO: Add column storage support*/ + + label(.SDONE) + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X7I) // iterate again if ii != 0. + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", + "ymm8", "ymm10", "ymm12", "ymm14", + "memory" + ) + + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. 
+ if (m_left ) + { + const dim_t nr_cur = n0; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict ai = a + m_iter * ps_a; + float* restrict bj = b; + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_zen_asm_1x8_mask, + bli_sgemmsup_rv_zen_asm_2x8_mask, + bli_sgemmsup_rv_zen_asm_3x8_mask, + bli_sgemmsup_rv_zen_asm_4x8_mask, + bli_sgemmsup_rv_zen_asm_5x8_mask + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + + return; + } +} + +void bli_sgemmsup_rv_zen_asm_6x4m_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + // This kernel is called when n_left is 3/1. This kernel operates 4 columns at time. + // + // Sets up the mask for loading relevant remainder elements in load direction + // + // __xmm0___ + // | | | | | + // |0|1|2|3| ----> Source vector + // |_|_|_|_| + // + // __xmm7___ + // | | | | | + // |x|x|x|x| ----> Mask vector[x can be -1/0] + // |_|_|_|_| + // + // For example when n_left = 3 + // ___xmm7_____ + // | | | | | + // |-1|-1|-1|0| ----> Mask vector + // |__|__|__|_| + // + // For example when n_left = 1 + // ___xmm7___ + // | | | | | + // |-1|0|0|0| ----> Mask vector + // |__|_|_|_| + // + // __xmm0___ + // | | | | | + // |0|1|2|3| ----> Destination vector + // |_|_|_|_| + // + const int32_t *mask_vec = mask[n0]; + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), xmm7) //load mask values + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + // During preamble and loops: + // r12 = rcx = c // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.SLOOP6X4I) // LOOP OVER ii = [ m_iter ... 
1 0 ] + + vxorps(xmm1, xmm1, xmm1) + vxorps(xmm4, xmm4, xmm4) + vxorps(xmm6, xmm6, xmm6) + vxorps(xmm8, xmm8, xmm8) + vxorps(xmm10, xmm10, xmm10) + vxorps(xmm12, xmm12, xmm12) + vxorps(xmm14, xmm14, xmm14) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 0)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 0)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 0)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 0)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 0)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 0)) // prefetch c + 5*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
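/*
   The same masking idea is applied here at xmm width for n0 in {1, 3}. A
   minimal sketch of the masked load/store pair used on B and C (illustrative
   only; the helper names are hypothetical):

     #include <immintrin.h>
     #include <stdint.h>

     static inline __m128 load_partial4( const float* p, const int32_t* mask_row )
     {
         __m128i m = _mm_loadu_si128( (const __m128i*)mask_row );
         return _mm_maskload_ps( p, m );     // masked-out lanes read as 0.0f
     }

     static inline void store_partial4( float* p, const int32_t* mask_row, __m128 v )
     {
         __m128i m = _mm_loadu_si128( (const __m128i*)mask_row );
         _mm_maskstore_ps( p, m, v );        // masked-out lanes are not written
     }
*/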
+ label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rdx, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + vbroadcastss(mem(rax, r15, 1), xmm3) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + vfmadd231ps(xmm0, xmm3, xmm14) + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. 
+ mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + vmulps(xmm0, xmm14, xmm14) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + + label(.SROWSTORED) + + vmaskmovps(mem(rcx), xmm7, xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmaskmovps(xmm4, xmm7, mem(rcx)) + add(rdi, rcx) + + vmaskmovps(mem(rcx), xmm7, xmm1) + vfmadd231ps(xmm1, xmm3, xmm6) + vmaskmovps(xmm6, xmm7, mem(rcx)) + add(rdi, rcx) + + vmaskmovps(mem(rcx), xmm7, xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) + vmaskmovps(xmm8, xmm7, mem(rcx)) + add(rdi, rcx) + + vmaskmovps(mem(rcx), xmm7, xmm1) + vfmadd231ps(xmm1, xmm3, xmm10) + vmaskmovps(xmm10, xmm7, mem(rcx)) + add(rdi, rcx) + + vmaskmovps(mem(rcx), xmm7, xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) + vmaskmovps(xmm12, xmm7, mem(rcx)) + add(rdi, rcx) + + vmaskmovps(mem(rcx), xmm7, xmm1) + vfmadd231ps(xmm1, xmm3, xmm14) + vmaskmovps(xmm14, xmm7, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + /* TODO: Add column storage support*/ + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmaskmovps(xmm4, xmm7, mem(rcx)) + add(rdi, rcx) + vmaskmovps(xmm6, xmm7, mem(rcx)) + add(rdi, rcx) + vmaskmovps(xmm8, xmm7, mem(rcx)) + add(rdi, rcx) + vmaskmovps(xmm10, xmm7, mem(rcx)) + add(rdi, rcx) + vmaskmovps(xmm12, xmm7, mem(rcx)) + add(rdi, rcx) + vmaskmovps(xmm14, xmm7, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + /* TODO: Add column storage support*/ + + label(.SDONE) + + lea(mem(r12, rdi, 4), r12) // + lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c + + mov(var(ps_a4), rax) // load ps_a4 + lea(mem(r14, rax, 1), r14) // a_ii = r14 += ps_a4 + + dec(r11) // ii -= 1; + jne(.SLOOP6X4I) // iterate again if ii != 0. + + label(.SRETURN) + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm6", "xmm7", + "xmm8", "xmm10", "xmm12", "xmm14", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
+ if ( m_left ) + { + const dim_t nr_cur = n0; + const dim_t i_edge = m0 - ( dim_t )m_left; + + float* restrict cij = c + i_edge*rs_c; + float* restrict ai = a + m_iter*ps_a; + float* restrict bj = b; + + sgemmsup_ker_ft ker_fps[6] = + { + NULL, + bli_sgemmsup_rv_zen_asm_1x4_mask, + bli_sgemmsup_rv_zen_asm_2x4_mask, + bli_sgemmsup_rv_zen_asm_3x4_mask, + bli_sgemmsup_rv_zen_asm_4x4_mask, + bli_sgemmsup_rv_zen_asm_5x4_mask + }; + + sgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; + } } diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16n.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16n.c index eb690e9f6c..3c77fda899 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16n.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16n.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2021-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4.c index 298ede7204..bc351ce956 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c index b12f67ca9d..4e90b444d5 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.All rights reserved. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8.c b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8.c index 03c1627f15..b39b091753 100644 --- a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8.c +++ b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -548,7 +548,8 @@ void bli_cgemmsup_rv_zen_asm_2x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "memory" ) } @@ -910,7 +911,7 @@ void bli_cgemmsup_rv_zen_asm_1x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "memory" ) } @@ -1286,7 +1287,7 @@ void bli_cgemmsup_rv_zen_asm_2x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "ymm8", "ymm10", "memory" ) } @@ -1604,7 +1605,7 @@ void bli_cgemmsup_rv_zen_asm_1x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "ymm8", "ymm10", "memory" ) } diff --git a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8m.c b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8m.c index 07fbd26296..d0f86f4ce6 100644 --- a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8m.c +++ b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -739,7 +739,9 @@ void bli_cgemmsup_rv_zen_asm_3x8m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -1230,7 +1232,8 @@ void bli_cgemmsup_rv_zen_asm_3x4m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "ymm8", "ymm10", + "ymm12", "ymm14", "memory" ) consider_edge_cases: diff --git a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8n.c b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8n.c index 6c68707e18..2e2b888f08 100644 --- a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8n.c +++ b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_c3x8n.c @@ -6,7 +6,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4.c b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4.c index 1638eaba0b..3b2aedc7e2 100644 --- a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4.c +++ b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -540,7 +540,8 @@ void bli_zgemmsup_rv_zen_asm_2x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "memory" ) } @@ -926,7 +927,7 @@ void bli_zgemmsup_rv_zen_asm_1x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "memory" ) } @@ -1314,7 +1315,7 @@ void bli_zgemmsup_rv_zen_asm_2x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "ymm8", "ymm10", "memory" ) } @@ -1650,7 +1651,7 @@ void bli_zgemmsup_rv_zen_asm_1x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "memory" ) } diff --git a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4m.c b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4m.c index 898e4006e9..cadba52ce4 100644 --- a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4m.c +++ b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -702,7 +702,9 @@ void bli_zgemmsup_rv_zen_asm_3x4m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -1194,7 +1196,8 @@ void bli_zgemmsup_rv_zen_asm_3x2m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "ymm8", "ymm10", + "ymm12", "ymm14", "memory" ) consider_edge_cases: diff --git a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4n.c b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4n.c index 872d048685..0d7893e9c7 100644 --- a/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4n.c +++ b/kernels/zen/3/sup/broken/bli_gemmsup_rv_zen_asm_z3x4n.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16.c b/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16.c index 96bc927499..c0c4d5f198 100644 --- a/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16.c +++ b/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16.c @@ -2,8 +2,10 @@ BLIS An object-based framework for developing high-performance BLAS-like libraries. + Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. 
+ Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -328,7 +330,8 @@ void bli_sgemmsup_rd_zen_asm_2x16 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm7", "ymm8", + "ymm10", "ymm11", "ymm13", "ymm14", "memory" ) } @@ -559,7 +562,7 @@ void bli_sgemmsup_rd_zen_asm_1x16 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm7", "ymm10", "ymm13", "memory" ) } @@ -857,7 +860,8 @@ void bli_sgemmsup_rd_zen_asm_2x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm7", "ymm8", + "ymm10", "ymm11", "ymm13", "ymm14", "memory" ) } @@ -1087,7 +1091,7 @@ void bli_sgemmsup_rd_zen_asm_1x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm7", "ymm10", "ymm13", "memory" ) } @@ -1353,7 +1357,8 @@ void bli_sgemmsup_rd_zen_asm_2x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm7", "ymm8", + "ymm10", "ymm11", "ymm13", "ymm14", "memory" ) } void bli_sgemmsup_rd_zen_asm_1x4 @@ -1567,7 +1572,7 @@ void bli_sgemmsup_rd_zen_asm_1x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm7", "ymm10", "ymm13", "memory" ) } @@ -1791,7 +1796,7 @@ void bli_sgemmsup_rd_zen_asm_2x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "memory" ) } @@ -1978,7 +1983,7 @@ void bli_sgemmsup_rd_zen_asm_1x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "memory" ) } @@ -2369,7 +2374,8 @@ void bli_sgemmsup_rd_zen_asm_6x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", + "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) consider_edge_cases: // Handle edge cases in the m dimension, if they exist. @@ -2663,6 +2669,7 @@ void bli_sgemmsup_rd_zen_asm_3x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", + "ymm9", "memory" ) } diff --git a/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16m.c b/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16m.c index 00773b3b58..7599b26d4e 100644 --- a/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16m.c +++ b/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -556,7 +556,9 @@ void bli_sgemmsup_rd_zen_asm_6x16m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -1035,7 +1037,9 @@ void bli_sgemmsup_rd_zen_asm_6x8m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -1517,7 +1521,9 @@ void bli_sgemmsup_rd_zen_asm_6x4m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -1923,7 +1929,9 @@ void bli_sgemmsup_rd_zen_asm_6x2m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: diff --git a/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16n.c b/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16n.c index dfe5ca28af..824189992b 100644 --- a/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16n.c +++ b/kernels/zen/3/sup/other/bli_gemmsup_rd_zen_asm_s6x16n.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -594,7 +594,9 @@ void bli_sgemmsup_rd_zen_asm_6x16n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -1061,7 +1063,9 @@ void bli_sgemmsup_rd_zen_asm_3x16n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -1471,7 +1475,8 @@ void bli_sgemmsup_rd_zen_asm_2x16n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm7", "ymm8", + "ymm10", "ymm11", "ymm13", "ymm14", "memory" ) consider_edge_cases: @@ -1828,7 +1833,7 @@ void bli_sgemmsup_rd_zen_asm_1x16n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm7", "ymm10", "ymm13", "memory" ) consider_edge_cases: diff --git a/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16.c b/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16.c index 6c9f8cabe1..8915ec8e5d 100644 --- a/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16.c +++ b/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -720,7 +720,9 @@ void bli_sgemmsup_rv_zen_asm_5x16 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -1214,7 +1216,9 @@ void bli_sgemmsup_rv_zen_asm_4x16 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -1772,7 +1776,9 @@ void bli_sgemmsup_rv_zen_asm_3x16 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -2165,7 +2171,9 @@ void bli_sgemmsup_rv_zen_asm_2x16 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -2525,7 +2533,9 @@ void bli_sgemmsup_rv_zen_asm_1x16 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -2973,7 +2983,9 @@ void bli_sgemmsup_rv_zen_asm_6x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -3426,7 +3438,9 @@ void bli_sgemmsup_rv_zen_asm_5x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -3792,7 +3806,9 @@ void bli_sgemmsup_rv_zen_asm_4x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) } @@ -4204,7 +4220,8 @@ void bli_sgemmsup_rv_zen_asm_3x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", + "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -4530,7 +4547,8 @@ void bli_sgemmsup_rv_zen_asm_2x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", + "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -4793,7 +4811,8 @@ void bli_sgemmsup_rv_zen_asm_1x8 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", 
"ymm7", "ymm8", + "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -5194,7 +5213,8 @@ void bli_sgemmsup_rv_zen_asm_6x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", + "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -5582,7 +5602,8 @@ void bli_sgemmsup_rv_zen_asm_5x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", + "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -5920,7 +5941,8 @@ void bli_sgemmsup_rv_zen_asm_4x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -6245,7 +6267,8 @@ void bli_sgemmsup_rv_zen_asm_3x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -6518,7 +6541,8 @@ void bli_sgemmsup_rv_zen_asm_2x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -6772,7 +6796,8 @@ void bli_sgemmsup_rv_zen_asm_1x4 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", + "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -7159,7 +7184,8 @@ void bli_sgemmsup_rv_zen_asm_6x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -7532,7 +7558,8 @@ void bli_sgemmsup_rv_zen_asm_5x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", + "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -7868,7 +7895,8 @@ void bli_sgemmsup_rv_zen_asm_4x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", + "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -8167,7 +8195,8 @@ void bli_sgemmsup_rv_zen_asm_3x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", + "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -8427,7 +8456,8 @@ void bli_sgemmsup_rv_zen_asm_2x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", + "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -8663,7 +8693,8 @@ void bli_sgemmsup_rv_zen_asm_1x2 "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", 
"ymm2", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", + "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } diff --git a/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16m.c b/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16m.c index 41dbbd699e..31918565b9 100644 --- a/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16m.c +++ b/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -895,7 +895,9 @@ void bli_sgemmsup_rv_zen_asm_6x16m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -1431,7 +1433,8 @@ void bli_sgemmsup_rv_zen_asm_6x8m "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm6", "ymm8", "ymm10", + "ymm12", "ymm14", "memory" ) consider_edge_cases: diff --git a/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16n.c b/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16n.c index a7ab770cb2..be8c9b065d 100644 --- a/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16n.c +++ b/kernels/zen/3/sup/other/bli_gemmsup_rv_zen_asm_s6x16n.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -855,7 +855,9 @@ void bli_sgemmsup_rv_zen_asm_6x16n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -1621,7 +1623,8 @@ void bli_sgemmsup_rv_zen_asm_5x16n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "memory" ) consider_edge_cases: @@ -2230,7 +2233,8 @@ void bli_sgemmsup_rv_zen_asm_4x16n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "memory" ) consider_edge_cases: @@ -2876,7 +2880,9 @@ void bli_sgemmsup_rv_zen_asm_3x16n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", + "ymm15", "memory" ) consider_edge_cases: @@ -3366,7 +3372,8 @@ void bli_sgemmsup_rv_zen_asm_2x16n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", + "ymm11", "ymm12", "memory" ) consider_edge_cases: @@ -3818,7 +3825,7 @@ void bli_sgemmsup_rv_zen_asm_1x16n "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "memory" ) consider_edge_cases: diff --git a/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x16_mask.c b/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x16_mask.c new file mode 100644 index 0000000000..3b93fc6802 --- /dev/null +++ b/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x16_mask.c @@ -0,0 +1,1785 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materia provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +static const int32_t mask[8][8] = { {0, 0, 0, 0, 0, 0, 0, 0}, + {-1, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, 0}, + }; + +void bli_sgemmsup_rv_zen_asm_5x16_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + uint64_t n_mod8 = n0 % 8 ; + const int32_t *mask_vec = mask[n_mod8]; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm3) //load + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. 
+ jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 8*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1,8*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2,8*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 8*4)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1,8*4)) // prefetch c + 4*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 8*cs_c + lea(mem(r12, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 12*cs_c + lea(mem(r12, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 14*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + // use rcx, rdx for prefetching lines + // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
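+	// Main k loop, unrolled by 4. Each iteration loads one row of B as a
+	// full 8-float vector (ymm0) plus a masked tail (ymm1; mask ymm3
+	// selects the first n0 % 8 columns), broadcasts the five A elements of
+	// the current column (rows 0..4 via r8, 2*r8, r13 = 3*rs_a, 4*r8) and
+	// accumulates the outer product into ymm4..ymm13 with vfmadd231ps.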
+ + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + vfmadd231ps(ymm1, ymm2, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + vfmadd231ps(ymm1, ymm2, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + vfmadd231ps(ymm1, ymm2, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + vfmadd231ps(ymm1, ymm2, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // ee, we prepare to enter k_left loop. 
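+	// Left-over k iterations (k_left = k0 % 4): same body as a single
+	// unrolled iteration, executed one k step at a time, with a prefetch
+	// into the next micro-panel of A (rdx) on each pass.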
+ + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + vfmadd231ps(ymm1, ymm2, ymm11) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + vfmadd231ps(ymm1, ymm2, ymm13) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm1) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + vmulps(ymm0, ymm12, ymm12) + vmulps(ymm0, ymm13, ymm13) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm1) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx, 0*32), ymm1, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm5) + vmaskmovps(ymm5, ymm3, mem(rcx, 1*32)) + + add(rdi, rcx) + + vfmadd231ps(mem(rcx, 0*32), ymm1, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm7) + vmaskmovps(ymm7, ymm3, mem(rcx, 1*32)) + + add(rdi, rcx) + + vfmadd231ps(mem(rcx, 0*32), ymm1, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm9) + vmaskmovps(ymm9, ymm3, mem(rcx, 1*32)) + + add(rdi, rcx) + + vfmadd231ps(mem(rcx, 0*32), ymm1, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm11) + vmaskmovps(ymm11, ymm3, mem(rcx, 1*32)) + + add(rdi, rcx) + + vfmadd231ps(mem(rcx, 0*32), ymm1, ymm12) + vmovups(ymm12, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm1, ymm13) + vmaskmovps(ymm13, ymm3, mem(rcx, 1*32)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORED) + + /* TODO: Add column storage support*/ + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. 
+ jz(.SCOTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx, 0*32)) + vmaskmovps(ymm5, ymm3, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx, 0*32)) + vmaskmovps(ymm7, ymm3, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx, 0*32)) + vmaskmovps(ymm9, ymm3, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovups(ymm10, mem(rcx, 0*32)) + vmaskmovps(ymm11, ymm3, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovups(ymm12, mem(rcx, 0*32)) + vmaskmovps(ymm13, ymm3, mem(rcx, 1*32)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORBZ) + + /* TODO: Add column storage support*/ + + label(.SDONE) + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r12", "r13", "r14", + "xmm0", "xmm1", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", + "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_4x16_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + uint64_t n_mod8 = n0 % 8 ; + const int32_t *mask_vec = mask[n_mod8]; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm3) //load + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. 
+ jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 8*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1,8*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2,8*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 8*4)) // prefetch c + 3*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 8*cs_c + lea(mem(r12, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 12*cs_c + lea(mem(r12, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 14*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + // use rcx, rdx for prefetching lines + // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
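+	// Main k loop for the 4-row case, unrolled by 4: full + masked loads of
+	// B into ymm0/ymm1 as above, broadcasts of A rows 0..3 (offsets r8,
+	// 2*r8, r13 = 3*rs_a), accumulation into ymm4..ymm11.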
+ + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + vfmadd231ps(ymm1, ymm2, ymm11) + + add(r9, rax) // a += cs_a; + + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + vfmadd231ps(ymm1, ymm2, ymm11) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + vfmadd231ps(ymm1, ymm2, ymm11) + + add(r9, rax) // a += cs_a; + + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + vfmadd231ps(ymm1, ymm2, ymm11) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // ee, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + vfmadd231ps(ymm1, ymm2, ymm11) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. 
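+	// Epilogue: scale the eight accumulators by alpha; beta is broadcast
+	// into ymm12 and, unless beta == 0, each row of C is updated as
+	// C = beta*C + acc using a full vmovups for columns 0..7 and a masked
+	// load/store (mask ymm3) for the trailing columns.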
+ mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm12) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm11, ymm11) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm12) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx, 0*32), ymm12, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm12, ymm5) + vmaskmovps(ymm5, ymm3, mem(rcx, 1*32)) + + add(rdi, rcx) + + vfmadd231ps(mem(rcx, 0*32), ymm12, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm12, ymm7) + vmaskmovps(ymm7, ymm3, mem(rcx, 1*32)) + + add(rdi, rcx) + + vfmadd231ps(mem(rcx, 0*32), ymm12, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm12, ymm9) + vmaskmovps(ymm9, ymm3, mem(rcx, 1*32)) + + add(rdi, rcx) + + vfmadd231ps(mem(rcx, 0*32), ymm12, ymm10) + vmovups(ymm10, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm12, ymm11) + vmaskmovps(ymm11, ymm3, mem(rcx, 1*32)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORED) + + /* TODO: Add column storage support*/ + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx, 0*32)) + vmaskmovps(ymm5, ymm3, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx, 0*32)) + vmaskmovps(ymm7, ymm3, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx, 0*32)) + vmaskmovps(ymm9, ymm3, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovups(ymm10, mem(rcx, 0*32)) + vmaskmovps(ymm11, ymm3, mem(rcx, 1*32)) + + jmp(.SDONE) // jump to end. 
+ + label(.SCOTORBZ) + + /* TODO: Add column storage support*/ + + label(.SDONE) + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r12", "r13", "r14", + "xmm0", "xmm12", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", + "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_3x16_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + uint64_t n_mod8 = n0 % 8 ; + const int32_t *mask_vec = mask[n_mod8]; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm3) //load mask values + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. 
+ jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(r12, 8*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1,8*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2,8*4)) // prefetch c + 2*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 8*cs_c + lea(mem(r12, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 12*cs_c + lea(mem(r12, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 14*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
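+	// Main k loop for the 3-row case: accumulators ymm4..ymm9, with the
+	// three A elements broadcast from rax, rax + rs_a and rax + 2*rs_a,
+	// and B loaded as a full vector plus a masked tail as in the kernels above.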
+ + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // ee, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vfmadd231ps(ymm1, ymm2, ymm9) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. 
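+	// Post-accumulation: scale ymm4..ymm9 by alpha, test beta against zero
+	// with vucomiss, then either merge the result with beta*c (row-stored
+	// path) or store it directly in the beta == 0 path.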
+ mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm12) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm9, ymm9) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm12) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx, 0*32), ymm12, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm12, ymm5) + vmaskmovps(ymm5, ymm3, mem(rcx, 1*32)) + + add(rdi, rcx) + + vfmadd231ps(mem(rcx, 0*32), ymm12, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm12, ymm7) + vmaskmovps(ymm7, ymm3, mem(rcx, 1*32)) + + add(rdi, rcx) + + vfmadd231ps(mem(rcx, 0*32), ymm12, ymm8) + vmovups(ymm8, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm12, ymm9) + vmaskmovps(ymm9, ymm3, mem(rcx, 1*32)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORED) + + /* TODO: Add column storage support*/ + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx, 0*32)) + vmaskmovps(ymm5, ymm3, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx, 0*32)) + vmaskmovps(ymm7, ymm3, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovups(ymm8, mem(rcx, 0*32)) + vmaskmovps(ymm9, ymm3, mem(rcx, 1*32)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORBZ) + + /* TODO: Add column storage support*/ + + label(.SDONE) + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r12", "r14", + "xmm0", "xmm12", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", + "ymm7", "ymm8", "ymm9", "ymm12", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_2x16_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
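+	// This 2x16 variant mirrors the 3x16 kernel above, but accumulates
+	// only two rows of c (ymm4..ymm7) and keeps beta in ymm14.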
+ uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + uint64_t n_mod8 = n0 % 8 ; + const int32_t *mask_vec = mask[n_mod8]; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm3) //load + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(r12, 8*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1,8*4)) // prefetch c + 1*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 8*cs_c + lea(mem(r12, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 12*cs_c + lea(mem(r12, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 14*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + // use rcx, rdx for prefetching lines + // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
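+	// Conceptually, each k iteration of the loop below performs the scalar
+	// update (illustrative sketch only; l denotes the current k index):
+	//   for ( dim_t i = 0; i < 2; ++i )
+	//     for ( dim_t j = 0; j < n0; ++j )
+	//       ab[ i ][ j ] += a[ i*rs_a + l*cs_a ] * b[ l*rs_b + j*cs_b ];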
+ + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // ee, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + vfmadd231ps(ymm1, ymm2, ymm7) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm14) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm7, ymm7) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm14) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. 
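+	// beta != 0 here: dispatch on the storage of c. Only the row-stored
+	// update is implemented; the column-stored case is still a TODO stub.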
+ jz(.SCOTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx, 0*32), ymm14, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm14, ymm5) + vmaskmovps(ymm5, ymm3, mem(rcx, 1*32)) + + add(rdi, rcx) + + vfmadd231ps(mem(rcx, 0*32), ymm14, ymm6) + vmovups(ymm6, mem(rcx, 0*32)) + + vmaskmovps(mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm14, ymm7) + vmaskmovps(ymm7, ymm3, mem(rcx, 1*32)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORED) + + /* TODO: Add column storage support*/ + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx, 0*32)) + vmaskmovps(ymm5, ymm3, mem(rcx, 1*32)) + + add(rdi, rcx) + + vmovups(ymm6, mem(rcx, 0*32)) + vmaskmovps(ymm7, ymm3, mem(rcx, 1*32)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORBZ) + + /* TODO: Add column storage support*/ + + label(.SDONE) + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r12", "r14", + "xmm0", "xmm14", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", + "ymm7", "ymm14", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_1x16_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + uint64_t n_mod8 = n0 % 8 ; + const int32_t *mask_vec = mask[n_mod8]; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm3) //load mask values + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). 
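+	// 1x16 variant: a single row of c is accumulated in ymm4 (the 8 full
+	// columns) and ymm5 (the masked tail columns).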
+ + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(r12, 8*4)) // prefetch c + 0*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 7*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 8*cs_c + lea(mem(r12, rsi, 8), rdx) // rdx = c + 8*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 9*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, rcx, 1, 5*4)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 4, 5*4)) // prefetch c + 12*cs_c + lea(mem(r12, rcx, 4), rdx) // rdx = c + 12*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 14*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + // use rcx, rdx for prefetching lines + // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. 
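+	// The remaining k0 % 4 updates are handled one at a time in the edge
+	// loop below, which also prefetches from the next upanel of a via rdx.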
+ + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // ee, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovups(mem(rbx, 0*32), ymm0) + vmaskmovps(mem(rbx, 1*32), ymm3, ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vfmadd231ps(ymm1, ymm2, ymm5) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm12) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm5, ymm5) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm12) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORED) // jump to column storage case + + label(.SROWSTORED) + + vfmadd231ps(mem(rcx, 0*32), ymm12, ymm4) + vmovups(ymm4, mem(rcx, 0*32)) + + vmaskmovps( mem(rcx, 1*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm12, ymm5) + vmaskmovps(ymm5, ymm3, mem(rcx, 1*32)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORED) + + /* TODO: Add column storage support*/ + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmovups(ymm4, mem(rcx, 0*32)) + vmaskmovps(ymm5, ymm3, mem(rcx, 1*32)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORBZ) + + /* TODO: Add column storage support*/ + + label(.SDONE) + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r12", "r14", + "xmm0","xmm12", + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm12", + "memory" + ) +} diff --git a/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x4_mask.c b/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x4_mask.c new file mode 100644 index 0000000000..55de26c884 --- /dev/null +++ b/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x4_mask.c @@ -0,0 +1,1568 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materia provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +static const int32_t mask[8][8] = { {0, 0, 0, 0, 0, 0, 0, 0}, + {-1, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, 0}, + }; + +void bli_sgemmsup_rv_zen_asm_5x4_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + const int32_t *mask_vec = mask[n0]; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), xmm7) //load mask elements + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). 
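+	// In this 5x4 kernel the valid columns are selected by the xmm7 mask
+	// loaded from mask[n0] (-1 keeps a lane, 0 skips it), and all loads
+	// and stores of b and c go through vmaskmovps so no lane past column
+	// n0 is ever read or written.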
+ + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + vxorps(xmm1, xmm1, xmm1) + vxorps(xmm4, xmm4, xmm4) + vxorps(xmm6, xmm6, xmm6) + vxorps(xmm8, xmm8, xmm8) + vxorps(xmm10, xmm10, xmm10) + vxorps(xmm12, xmm12, xmm12) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 0)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 0)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 0)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 0)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 0)) // prefetch c + 4*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
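+	// Main loop: each k iteration mask-loads one row of b into xmm0,
+	// broadcasts the five elements of the current column of a, and
+	// accumulates the 5x4 tile in xmm4/xmm6/xmm8/xmm10/xmm12 while
+	// prefetching from the next upanel of a through rdx.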
+ label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rdx, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + // ---------------------------------- iteration 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + // ---------------------------------- iteration 2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + // ---------------------------------- iteration 3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + vbroadcastss(mem(rax, r8, 4), xmm2) + add(r9, rax) // a += cs_a; + vfmadd231ps(xmm0, xmm2, xmm12) + + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. 
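+	// Post-accumulation: alpha is broadcast into xmm0 and beta into xmm3;
+	// each row of c is then updated with a masked load-fma-store when
+	// beta != 0, or a masked store alone when beta == 0.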
+ mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + vmulps(xmm0, xmm12, xmm12) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + + label(.SROWSTORED) + + vmaskmovps(mem(rcx), xmm7, xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmaskmovps(xmm4, xmm7, mem(rcx)) + add(rdi, rcx) + + vmaskmovps(mem(rcx), xmm7, xmm1) + vfmadd231ps(xmm1, xmm3, xmm6) + vmaskmovps(xmm6, xmm7, mem(rcx)) + add(rdi, rcx) + + vmaskmovps(mem(rcx), xmm7, xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) + vmaskmovps(xmm8, xmm7, mem(rcx)) + add(rdi, rcx) + + vmaskmovps(mem(rcx), xmm7, xmm1) + vfmadd231ps(xmm1, xmm3, xmm10) + vmaskmovps(xmm10, xmm7, mem(rcx)) + add(rdi, rcx) + + vmaskmovps(mem(rcx), xmm7, xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) + vmaskmovps(xmm12, xmm7, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + /* TODO: Add column storage support*/ + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmaskmovps(xmm4, xmm7, mem(rcx)) + add(rdi, rcx) + vmaskmovps(xmm6, xmm7, mem(rcx)) + add(rdi, rcx) + vmaskmovps(xmm8, xmm7, mem(rcx)) + add(rdi, rcx) + vmaskmovps(xmm10, xmm7, mem(rcx)) + add(rdi, rcx) + vmaskmovps(xmm12, xmm7, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + /* TODO: Add column storage support*/ + + label(.SDONE) + + label(.SRETURN) + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm6", "xmm7", + "xmm8", "xmm10", "xmm12", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_4x4_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + const int32_t *mask_vec = mask[n0]; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), xmm7) //load mask elements + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + vxorps(xmm1, xmm1, xmm1) + vxorps(xmm4, xmm4, xmm4) + vxorps(xmm6, xmm6, xmm6) + vxorps(xmm8, xmm8, xmm8) + vxorps(xmm10, xmm10, xmm10) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 0)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 0)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 0)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 0)) // prefetch c + 3*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
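+	// The four unrolled iterations below prefetch from rdx, rdx + cs_a,
+	// rdx + 2*cs_a and rdx + 3*cs_a, touching lines of the next upanel
+	// of a ahead of their use.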
+ label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rdx, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vbroadcastss(mem(rax, r13, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm8) + vfmadd231ps(xmm0, xmm3, xmm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + vmulps(xmm0, xmm10, xmm10) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. 
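+	// rdi holds 4*rs_c, so the compare above really detects rs_c == 1,
+	// i.e. a column-stored c; that path is currently a TODO stub.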
+ jz(.SCOLSTORED) // jump to column storage case + + + label(.SROWSTORED) + + vmaskmovps(mem(rcx), xmm7, xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmaskmovps(xmm4, xmm7, mem(rcx)) + add(rdi, rcx) + + vmaskmovps(mem(rcx), xmm7, xmm1) + vfmadd231ps(xmm1, xmm3, xmm6) + vmaskmovps(xmm6, xmm7, mem(rcx)) + add(rdi, rcx) + + vmaskmovps(mem(rcx), xmm7, xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) + vmaskmovps(xmm8, xmm7, mem(rcx)) + add(rdi, rcx) + + vmaskmovps(mem(rcx), xmm7, xmm1) + vfmadd231ps(xmm1, xmm3, xmm10) + vmaskmovps(xmm10, xmm7, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + /* TODO: Add column storage support*/ + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmaskmovps(xmm4, xmm7, mem(rcx)) + add(rdi, rcx) + vmaskmovps(xmm6, xmm7, mem(rcx)) + add(rdi, rcx) + vmaskmovps(xmm8, xmm7, mem(rcx)) + add(rdi, rcx) + vmaskmovps(xmm10, xmm7, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + /* TODO: Add column storage support*/ + + label(.SDONE) + + label(.SRETURN) + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm6", "xmm7", + "xmm8", "xmm10", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_3x4_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + const int32_t *mask_vec = mask[n0]; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), xmm7) //load mask elements + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). 
+ + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + vxorps(xmm1, xmm1, xmm1) + vxorps(xmm4, xmm4, xmm4) + vxorps(xmm6, xmm6, xmm6) + vxorps(xmm8, xmm8, xmm8) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(r12, 0)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 0)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 0)) // prefetch c + 2*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rdx, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vfmadd231ps(xmm0, xmm2, xmm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vfmadd231ps(xmm0, xmm2, xmm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vfmadd231ps(xmm0, xmm2, xmm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vfmadd231ps(xmm0, xmm2, xmm8) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. 
+ // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + vbroadcastss(mem(rax, r8, 2), xmm2) + vfmadd231ps(xmm0, xmm2, xmm8) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + vmulps(xmm0, xmm8, xmm8) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + + label(.SROWSTORED) + + vmaskmovps(mem(rcx), xmm7, xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmaskmovps(xmm4, xmm7, mem(rcx)) + add(rdi, rcx) + + vmaskmovps(mem(rcx), xmm7, xmm1) + vfmadd231ps(xmm1, xmm3, xmm6) + vmaskmovps(xmm6, xmm7, mem(rcx)) + add(rdi, rcx) + + vmaskmovps(mem(rcx), xmm7, xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) + vmaskmovps(xmm8, xmm7, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + /* TODO: Add column storage support*/ + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmaskmovps(xmm4, xmm7, mem(rcx)) + add(rdi, rcx) + vmaskmovps(xmm6, xmm7, mem(rcx)) + add(rdi, rcx) + vmaskmovps(xmm8, xmm7, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + /* TODO: Add column storage support*/ + + label(.SDONE) + + label(.SRETURN) + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r14", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm6", "xmm7", + "xmm8", "xmm10", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_2x4_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
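+	// 2x4 variant: two rows of c accumulated in xmm4 and xmm6; the k loop
+	// and the masked update of c follow the same pattern as the kernels
+	// above.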
+ uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + const int32_t *mask_vec = mask[n0]; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), xmm7) //load mask elements + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + vxorps(xmm1, xmm1, xmm1) + vxorps(xmm4, xmm4, xmm4) + vxorps(xmm6, xmm6, xmm6) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(r12, 0)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 0)) // prefetch c + 1*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
+ label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rdx, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vbroadcastss(mem(rax, r8, 1), xmm3) + vfmadd231ps(xmm0, xmm2, xmm4) + vfmadd231ps(xmm0, xmm3, xmm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + vmulps(xmm0, xmm6, xmm6) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORED) // jump to column storage case + + + label(.SROWSTORED) + + vmaskmovps(mem(rcx), xmm7, xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmaskmovps(xmm4, xmm7, mem(rcx)) + add(rdi, rcx) + + vmaskmovps(mem(rcx), xmm7, xmm1) + vfmadd231ps(xmm1, xmm3, xmm6) + vmaskmovps(xmm6, xmm7, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + /* TODO: Add column storage support*/ + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmaskmovps(xmm4, xmm7, mem(rcx)) + add(rdi, rcx) + vmaskmovps(xmm6, xmm7, mem(rcx)) + + jmp(.SDONE) // jump to end. 
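+	// The beta == 0 path above stores the alpha-scaled accumulators with
+	// masked stores and never reads c, so columns outside the mask are
+	// left untouched.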
+ + label(.SCOLSTORBZ) + + /* TODO: Add column storage support*/ + + label(.SDONE) + + label(.SRETURN) + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r14", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm6", "xmm7", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_1x4_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + const int32_t *mask_vec = mask[n0]; + + // ------------------------------------------------------------------------- + begin_asm() + + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), xmm7) //load mask elements + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + vxorps(xmm4, xmm4, xmm4) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. 
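+	// rdi now holds rs_c scaled to bytes, so comparing it against 4 (one
+	// float) detects a unit row stride, i.e. column-major storage of C; in
+	// that case the column-stored prefetch path below is taken instead of
+	// the row-stored one.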
+ jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(r12, 0)) // prefetch c + 0*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rdx, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + vmaskmovps(mem(rbx), xmm7, xmm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), xmm2) + vfmadd231ps(xmm0, xmm2, xmm4) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), xmm0) // load alpha and duplicate + vbroadcastss(mem(rbx), xmm3) // load beta and duplicate + + vmulps(xmm0, xmm4, xmm4) // scale by alpha + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. + vucomiss(xmm0, xmm3) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.SCOLSTORED) // jump to column storage case + + + label(.SROWSTORED) + + vmaskmovps(mem(rcx), xmm7, xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) + vmaskmovps(xmm4, xmm7, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + /* TODO: Add column storage support*/ + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (8*rs_c) == 8. + jz(.SCOLSTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmaskmovps(xmm4, xmm7, mem(rcx)) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORBZ) + + /* TODO: Add column storage support*/ + + label(.SDONE) + + label(.SRETURN) + + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r14", + "xmm0", "xmm2", "xmm3", + "xmm4", "xmm7", + "memory" + ) +} diff --git a/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x8_mask.c b/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x8_mask.c new file mode 100644 index 0000000000..74c1c51989 --- /dev/null +++ b/kernels/zen/3/sup/s6x16/bli_gemmsup_rv_zen_asm_s5x8_mask.c @@ -0,0 +1,1572 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materia provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ +#include "blis.h" + +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +static const int32_t mask[8][8] = { {0, 0, 0, 0, 0, 0, 0, 0}, + {-1, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, 0}, + }; + +void bli_sgemmsup_rv_zen_asm_5x8_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + uint64_t n_mod8 = n0 % 8 ; + const int32_t *mask_vec = mask[n_mod8]; + // ------------------------------------------------------------------------- + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm3) //load mask elements + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. 
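+	// The ymm3 mask loaded above from mask[ n0 % 8 ] is what makes these
+	// fringe kernels safe for n0 < 8: vmaskmovps only loads/stores lanes
+	// whose mask element has its sign bit set (-1), so partial rows of B
+	// and C are read and written without crossing the end of the buffer.
+	// Roughly, in intrinsics form (illustrative sketch only, not part of
+	// the kernel; n, src and dst are hypothetical names):
+	//
+	//     __m256i m = _mm256_loadu_si256( (const __m256i *)mask[ n ] );
+	//     __m256  v = _mm256_maskload_ps( src, m );  // lanes >= n read as 0
+	//     _mm256_maskstore_ps( dst, m, v );          // lanes >= n untouched
+	//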
+ jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 4*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1,4*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2,4*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 4*4)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1,4*4)) // prefetch c + 4*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 6*cs_c + label(.SPOSTPFETCH) // done prefetching c + + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmaskmovps(mem(rbx, 0), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + vmaskmovps(mem(rbx, 0), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + vmaskmovps(mem(rbx, 0), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + vmaskmovps(mem(rbx, 0), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + + 
vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // ee, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmaskmovps(mem(rbx, 0), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + + vbroadcastss(mem(rax, r8, 4), ymm2) + vfmadd231ps(ymm0, ymm2, ymm12) + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm7) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + vmulps(ymm0, ymm12, ymm12) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm7) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORED) // jump to column storage case + + label(.SROWSTORED) + + vmaskmovps(mem(rcx, 0), ymm3, ymm2) + vfmadd231ps(ymm2, ymm7, ymm4) + vmaskmovps(ymm4, ymm3, mem(rcx, 0)) + + add(rdi, rcx) + + vmaskmovps(mem(rcx, 0), ymm3, ymm2) + vfmadd231ps(ymm2, ymm7, ymm6) + vmaskmovps(ymm6, ymm3, mem(rcx, 0)) + + add(rdi, rcx) + + vmaskmovps(mem(rcx, 0), ymm3, ymm2) + vfmadd231ps(ymm2, ymm7, ymm8) + vmaskmovps(ymm8, ymm3, mem(rcx, 0)) + + add(rdi, rcx) + + vmaskmovps(mem(rcx, 0), ymm3, ymm2) + vfmadd231ps(ymm2, ymm7, ymm10) + vmaskmovps(ymm10, ymm3, mem(rcx, 0)) + + add(rdi, rcx) + + vmaskmovps(mem(rcx, 0), ymm3, ymm2) + vfmadd231ps(ymm2, ymm7, ymm12) + vmaskmovps(ymm12, ymm3, mem(rcx, 0)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORED) + + /* TODO: Add column storage support*/ + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmaskmovps(ymm4, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(ymm6, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(ymm8, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(ymm10, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(ymm12, ymm3, mem(rcx, 0)) + + jmp(.SDONE) // jump to end. 
+ + label(.SCOTORBZ) + + /* TODO: Add column storage support*/ + + label(.SDONE) + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r12", "r13", "r14", + "xmm0", "xmm7", + "ymm0", "ymm2", "ymm3", "ymm4", "ymm6", + "ymm7", "ymm8", "ymm10", "ymm12", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_4x8_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + uint64_t n_mod8 = n0 % 8 ; + const int32_t *mask_vec = mask[n_mod8]; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm3) //load mask elements + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. 
+ jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 4*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1,4*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2,4*4)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 4*4)) // prefetch c + 3*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 6*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. 
+ je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // ee, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + vbroadcastss(mem(rax, r13, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm10) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm7) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + vmulps(ymm0, ymm10, ymm10) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm7) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORED) // jump to column storage case + + label(.SROWSTORED) + + vmaskmovps(mem(rcx, 0*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm7, ymm4) + vmaskmovps(ymm4, ymm3, mem(rcx, 0*32)) + + add(rdi, rcx) + + vmaskmovps(mem(rcx, 0*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm7, ymm6) + vmaskmovps(ymm6, ymm3, mem(rcx, 0*32)) + + add(rdi, rcx) + + vmaskmovps(mem(rcx, 0*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm7, ymm8) + vmaskmovps(ymm8, ymm3, mem(rcx, 0*32)) + + add(rdi, rcx) + + vmaskmovps(mem(rcx, 0*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm7, ymm10) + vmaskmovps(ymm10, ymm3, mem(rcx, 0*32)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORED) + + /* TODO: Add column storage support*/ + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmaskmovps(ymm4, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(ymm6, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(ymm8, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(ymm10, ymm3, mem(rcx, 0)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORBZ) + + /* TODO: Add column storage support*/ + + jmp(.SDONE) // jump to end. 
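+	// The end_asm block that follows .SDONE/.SRETURN passes every value
+	// referenced via var(...) above as an "m" (memory) input operand and
+	// names all general-purpose and vector registers the kernel touches in
+	// the clobber list (plus "memory"), so the compiler neither keeps C
+	// variables cached in those registers nor reorders memory accesses
+	// across the asm statement.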
+ + label(.SDONE) + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r12", "r13", "r14", + "xmm0", "xmm7", + "ymm0", "ymm2", "ymm3", "ymm4", "ymm6", + "ymm7", "ymm8", "ymm10", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_3x8_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + uint64_t n_mod8 = n0 % 8 ; + const int32_t *mask_vec = mask[n_mod8]; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm3) //load mask elements + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. 
+ jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(r12, 4*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1,4*4)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2,4*4)) // prefetch c + 2*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 6*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + // use rcx, rdx for prefetching lines + // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.SLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // ee, we prepare to enter k_left loop. 
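+	// The edge loop below consumes the k0 % 4 iterations that the
+	// 4x-unrolled main loop above could not, one rank-1 update per pass,
+	// while also prefetching from the next micro-panel of A through rdx
+	// (rdx was set to a + ps_a4 in the .SPOSTPFETCH block).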
+ + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + vbroadcastss(mem(rax, r8, 2), ymm2) + vfmadd231ps(ymm0, ymm2, ymm8) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm7) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + vmulps(ymm0, ymm8, ymm8) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm7) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORED) // jump to column storage case + + label(.SROWSTORED) + + vmaskmovps(mem(rcx, 0*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm7, ymm4) + vmaskmovps(ymm4, ymm3, mem(rcx, 0*32)) + + add(rdi, rcx) + + vmaskmovps(mem(rcx, 0*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm7, ymm6) + vmaskmovps(ymm6, ymm3, mem(rcx, 0*32)) + + add(rdi, rcx) + + vmaskmovps(mem(rcx, 0*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm7, ymm8) + vmaskmovps(ymm8, ymm3, mem(rcx, 0*32)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORED) + + /* TODO: Add column storage support*/ + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmaskmovps(ymm4, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(ymm6, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(ymm8, ymm3, mem(rcx, 0)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORBZ) + + /* TODO: Add column storage support*/ + + jmp(.SDONE) // jump to end. + + label(.SDONE) + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r12", "r14", + "xmm0", "xmm7", + "ymm0", "ymm2", "ymm3", "ymm4", "ymm6", + "ymm7", "ymm8", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_2x8_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. 
+ uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + uint64_t n_mod8 = n0 % 8 ; + const int32_t *mask_vec = mask[n_mod8]; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm3) //load mask elements + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + + prefetch(0, mem(r12, 4*4)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1,4*4)) // prefetch c + 1*rs_c + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + // use rcx, rdx for prefetching lines + // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
+ + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 2 + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 3 + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // ee, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + vbroadcastss(mem(rax, r8, 1), ymm2) + vfmadd231ps(ymm0, ymm2, ymm6) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm7) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + vmulps(ymm0, ymm6, ymm6) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + + // now avoid loading C if beta == 0 + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm7) // set ZF if beta == 0. + je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORED) // jump to column storage case + + label(.SROWSTORED) + + vmaskmovps(mem(rcx, 0*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm7, ymm4) + vmaskmovps(ymm4, ymm3, mem(rcx, 0*32)) + + add(rdi, rcx) + + vmaskmovps(mem(rcx, 0*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm7, ymm6) + vmaskmovps(ymm6, ymm3, mem(rcx, 0*32)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORED) + + /* TODO: Add column storage support*/ + + jmp(.SDONE) // jump to end. + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmaskmovps(ymm4, ymm3, mem(rcx, 0)) + add(rdi, rcx) + + vmaskmovps(ymm6, ymm3, mem(rcx, 0)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORBZ) + + /* TODO: Add column storage support*/ + + jmp(.SDONE) // jump to end. 
+ + label(.SDONE) + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r12", "r14", + "xmm0", "xmm7", + "ymm0", "ymm2", "ymm3", "ymm4", "ymm6", "ymm7", + "memory" + ) +} + +void bli_sgemmsup_rv_zen_asm_1x8_mask + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + float* restrict alpha, + float* restrict a, inc_t rs_a0, inc_t cs_a0, + float* restrict b, inc_t rs_b0, inc_t cs_b0, + float* restrict beta, + float* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + + uint64_t n_mod8 = n0 % 8 ; + const int32_t *mask_vec = mask[n_mod8]; + // ------------------------------------------------------------------------- + + begin_asm() + + vzeroall() // zero all xmm/ymm registers. + mov(var(mask_vec), rdx) + vmovdqu(mem(rdx), ymm3) //load + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 4), r8) // rs_a *= sizeof(float) + + + lea(mem(, r9, 4), r9) // cs_a *= sizeof(float) + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 4), r10) // rs_b *= sizeof(float) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. 
+ jz(.SCOLPFETCH) // jump to column storage case + label(.SROWPFETCH) // row-stored prefetching on c + + prefetch(0, mem(r12, 4*4)) // prefetch c + 0*rs + + jmp(.SPOSTPFETCH) // jump to end of prefetching c + label(.SCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + lea(mem(rsi, rsi, 2), rcx) // rcx = 3*cs_c; + prefetch(0, mem(r12, 5*4)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*4)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*4)) // prefetch c + 2*cs_c + prefetch(0, mem(r12, rcx, 1, 5*4)) // prefetch c + 3*cs_c + prefetch(0, mem(r12, rsi, 4, 5*4)) // prefetch c + 4*cs_c + lea(mem(r12, rsi, 4), rdx) // rdx = c + 4*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*4)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*4)) // prefetch c + 6*cs_c + + label(.SPOSTPFETCH) // done prefetching c + + mov(var(ps_a4), rdx) // load ps_a4 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a4 + // use rcx, rdx for prefetching lines + // from next upanel of a. + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.SCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + label(.SLOOPKITER) // MAIN LOOP + + // ---------------------------------- iteration 0 + + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + // ---------------------------------- iteration 1 + + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + + // ---------------------------------- iteration 2 + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + + // ---------------------------------- iteration 3 + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKITER) // iterate again if i != 0. + + label(.SCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.SPOSTACCUM) // if i == 0, we're done; jump to end. + // ee, we prepare to enter k_left loop. + + label(.SLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmaskmovps(mem(rbx, 0*32), ymm3, ymm0) + add(r10, rbx) // b += rs_b; + + vbroadcastss(mem(rax ), ymm2) + vfmadd231ps(ymm0, ymm2, ymm4) + + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.SLOOPKLEFT) // iterate again if i != 0. + + label(.SPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastss(mem(rax), ymm0) // load alpha and duplicate + vbroadcastss(mem(rbx), ymm7) // load beta and duplicate + + vmulps(ymm0, ymm4, ymm4) // scale by alpha + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // rsi = cs_c * sizeof(float) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + lea(mem(rsi, rsi, 4), rbx) // rbx = 5*cs_c; + + // now avoid loading C if beta == 0 + + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomiss(xmm0, xmm7) // set ZF if beta == 0. 
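+	// vucomiss compares beta against zero; when beta == 0 the branch below
+	// takes the .SBETAZERO path and stores alpha*A*B without ever reading C.
+	// This both skips the C loads and keeps garbage (or NaN) in an
+	// uninitialized C from leaking into the result.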
+ je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORED) // jump to column storage case + + label(.SROWSTORED) + + vmaskmovps(mem(rcx, 0*32), ymm3, ymm2) + vfmadd231ps(ymm2, ymm7, ymm4) + vmaskmovps(ymm4, ymm3, mem(rcx, 0*32)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORED) + + /* TODO: Add column storage support*/ + + label(.SBETAZERO) + + cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4. + jz(.SCOTORBZ) // jump to column storage case + + label(.SROWSTORBZ) + + vmaskmovps(ymm4, ymm3, mem(rcx, 0)) + + jmp(.SDONE) // jump to end. + + label(.SCOTORBZ) + + /* TODO: Add column storage support*/ + + label(.SDONE) + + label(.SRETURN) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [mask_vec] "m" (mask_vec) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r12", "r14", + "xmm0", "xmm7", + "ymm0", "ymm2", "ymm3", "ymm4", "ymm7", + "memory" + ) +} + diff --git a/kernels/zen/CMakeLists.txt b/kernels/zen/CMakeLists.txt deleted file mode 100644 index 0ac346fb3e..0000000000 --- a/kernels/zen/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## - - -set(SUBDIRECTORIES "1" "1f" "2" "3" "util") - -#Add all subdirectories -foreach(VAR ${SUBDIRECTORIES}) - add_subdirectory(${VAR}) -endforeach() - - diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index e6a2f33f92..45817f08be 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -110,6 +110,7 @@ SETV_KER_PROT(double, d, setv_zen_int) AXPYF_KER_PROT( float, s, axpyf_zen_int_8 ) AXPYF_KER_PROT( double, d, axpyf_zen_int_8 ) AXPYF_KER_PROT( double, d, axpyf_zen_int_16x4 ) +AXPYF_KER_PROT( double, d, axpyf_zen_int_16x2 ) AXPYF_KER_PROT( float, s, axpyf_zen_int_5 ) AXPYF_KER_PROT( float, s, axpyf_zen_int_6 ) @@ -188,6 +189,33 @@ GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x16m ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x8m ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x4m ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x2m ) +//gemmsup_rv (mkernel in m dim) for mask load/store +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x16m_mask ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x8m_mask ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x4m_mask ) +GEMMSUP_KER_PROT( float, s, bli_sgemmsup_rv_zen_asm_6x8m ) +GEMMSUP_KER_PROT( float, s, bli_sgemmsup_rv_zen_asm_6x4m ) +GEMMSUP_KER_PROT( float, s, bli_sgemmsup_rv_zen_asm_6x2m ) + +//gemmsup_rv (mkernel in m dim) for fringe case +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_1x16_mask ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x16_mask ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_3x16_mask ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_4x16_mask ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x16_mask ) + +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_1x8_mask ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x8_mask ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_3x8_mask ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_4x8_mask ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x8_mask ) + +GEMMSUP_KER_PROT( float, s, 
gemmsup_rv_zen_asm_1x4_mask ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x4_mask ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_3x4_mask ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_4x4_mask ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x4_mask ) + // gemmsup_rv (mkernel in n dim) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x16n ) @@ -259,6 +287,20 @@ GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_1x4n ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x2 ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x1 ) +err_t bli_dgemm_tiny +( + trans_t transa, + trans_t transb, + dim_t m, + dim_t n, + dim_t k, + const double* alpha, + const double* a, const inc_t rs_a0, const inc_t cs_a0, + const double* b, const inc_t rs_b0, const inc_t cs_b0, + const double* beta, + double* c, const inc_t rs_c0, const inc_t cs_c0 +); + err_t bli_dgemm_small ( obj_t* alpha, @@ -303,7 +345,7 @@ err_t bli_zgemm_small_At cntl_t* cntl ); -void bli_dgemm_8x6_avx2_k1_nn +err_t bli_dgemm_8x6_avx2_k1_nn ( dim_t m, dim_t n, @@ -315,7 +357,7 @@ void bli_dgemm_8x6_avx2_k1_nn double* c, const inc_t ldc ); -void bli_zgemm_4x6_avx2_k1_nn +void bli_zgemm_4x4_avx2_k1_nn ( dim_t m, dim_t n, @@ -426,3 +468,8 @@ void bli_dznorm2fv_unb_var1_avx2 double* norm, cntx_t* cntx ); + +GEMM_UKR_PROT( dcomplex, z, gemm_zen_asm_2x6) + +GEMMTRSM_UKR_PROT( dcomplex, z, gemmtrsm_l_zen_asm_2x6) +GEMMTRSM_UKR_PROT( dcomplex, z, gemmtrsm_u_zen_asm_2x6) diff --git a/kernels/zen/lpgemm/math_utils_avx2.h b/kernels/zen/lpgemm/math_utils_avx2.h index e705adb8f7..5f503fa3e7 100644 --- a/kernels/zen/lpgemm/math_utils_avx2.h +++ b/kernels/zen/lpgemm/math_utils_avx2.h @@ -44,8 +44,8 @@ #define TBL_LN2 0x1.71547652b82fep+0 #define EXPF_HUGE 0x1.8p+23 -#define EXPF_MIN -88.7228393f -#define EXPF_MAX 88.7228393f +#define EXPF_MIN -88.0f +#define EXPF_MAX 88.0f #define inf 1.0/0.0 #define sign -2147483648 @@ -84,8 +84,8 @@ POLY_EVAL_6_AVX2 (r, r2, z); \ \ q = _mm256_add_epi32((__m256i) (r), _mm256_sllv_epi32 ((__m256i)dn, _mm256_set1_epi32 (23)) ); \ - q = (__m256i)_mm256_blendv_ps ((__m256)q, _mm256_set1_ps(inf), _mm256_cmp_ps (_mm256_set1_ps(88.0), x, 1)); \ - q = (__m256i)_mm256_blendv_ps ((__m256)q, _mm256_set1_ps(0.0), _mm256_cmp_ps (x, _mm256_set1_ps(-88.0), 1)); + q = (__m256i)_mm256_blendv_ps ((__m256)q, _mm256_set1_ps(inf), _mm256_cmp_ps (_mm256_set1_ps(EXPF_MAX), x, 1)); \ + q = (__m256i)_mm256_blendv_ps ((__m256)q, _mm256_set1_ps(0.0), _mm256_cmp_ps (x, _mm256_set1_ps(EXPF_MIN), 1)); #define TANHF_AVX2(x_tanh, r, r2, x, z, dn, q) \ x = _mm256_mul_ps (_mm256_andnot_ps(_mm256_set1_ps(-0.0f), x_tanh), _mm256_set1_ps(-2) ); \ @@ -112,7 +112,8 @@ \ POLY_EVAL_HORNER_16_0_AVX2(r,x); \ \ - x = _mm256_blendv_ps (x, _mm256_set1_ps(1), _mm256_cmp_ps (_mm256_set1_ps(3.9192059040069580078125f), r, 1)); \ + x = _mm256_blendv_ps (x, _mm256_set1_ps(1), _mm256_cmp_ps (_mm256_set1_ps(3.553f), r, 1)); \ + x = _mm256_blendv_ps (x, _mm256_set1_ps(1), _mm256_cmp_ps (_mm256_set1_ps(1.0f), x, 1)); \ x_erf = _mm256_or_ps(_mm256_and_ps (x_erf, (__m256)_mm256_set1_epi32(~(0x7FFFFFFF))), x); //Trignometric EXP, TANH and ERF functions for SSE @@ -132,8 +133,8 @@ POLY_EVAL_6_SSE (r, r2, z); \ \ q = _mm_add_epi32((__m128i) (r), _mm_sllv_epi32 ((__m128i)dn, _mm_set1_epi32 (23)) ); \ - q = (__m128i)_mm_blendv_ps ((__m128)q, _mm_set1_ps(inf), _mm_cmp_ps (_mm_set1_ps(88.0), x, 1)); \ - q = (__m128i)_mm_blendv_ps ((__m128)q, _mm_set1_ps(0.0), _mm_cmp_ps (x, _mm_set1_ps(-88.0), 1)); + q = (__m128i)_mm_blendv_ps ((__m128)q, _mm_set1_ps(inf), _mm_cmp_ps 
(_mm_set1_ps(EXPF_MAX), x, 1)); \ + q = (__m128i)_mm_blendv_ps ((__m128)q, _mm_set1_ps(0.0), _mm_cmp_ps (x, _mm_set1_ps(EXPF_MIN), 1)); #define TANHF_SSE(x_tanh, r, r2, x, z, dn, q) \ x = _mm_mul_ps (_mm_andnot_ps(_mm_set1_ps(-0.0f), x_tanh), _mm_set1_ps(-2) ); \ @@ -160,7 +161,8 @@ \ POLY_EVAL_HORNER_16_0_SSE(r,x); \ \ - x = _mm_blendv_ps (x, _mm_set1_ps(1), _mm_cmp_ps (_mm_set1_ps(3.9192059040069580078125f), r, 1)); \ + x = _mm_blendv_ps (x, _mm_set1_ps(1), _mm_cmp_ps (_mm_set1_ps(3.553f), r, 1)); \ + x = _mm_blendv_ps (x, _mm_set1_ps(1), _mm_cmp_ps (_mm_set1_ps(1.0f), x, 1)); \ x_erf = _mm_or_ps(_mm_and_ps (x_erf, (__m128)_mm_set1_epi32(~(0x7FFFFFFF))), x); #endif // AOCL_LPGEMM_MATH_UTILS_AVX2_H diff --git a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_6x32rowmajor_amd256.c b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_6x32rowmajor_amd256.c index 8b41f0e6da..c102a89dea 100644 --- a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_6x32rowmajor_amd256.c +++ b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_6x32rowmajor_amd256.c @@ -151,7 +151,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) __m256i b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 1))); - // Seperate register for intermediate op + // Separate register for intermediate op __m256i inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -168,7 +168,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -185,7 +185,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] @@ -201,7 +201,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -218,7 +218,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -236,7 +236,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -262,7 +262,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op __m256i inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -278,7 +278,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -294,7 +294,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -311,7 +311,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -327,7 +327,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -343,7 +343,7 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -775,13 +775,28 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (1 * 8)); + // Load zero points (2 byte values). + __m128i _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m256i zero_point_0 = _mm256_setzero_si256(); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 6 rows. 
- CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2, zero_point_0) scale_1 = _mm256_loadu_ps( @@ -792,13 +807,26 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (3 * 8)); + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale next 16 columns of the 6 rows. - CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_4p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_5p1, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_4p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_5p1, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_m_fringe_amd256.c b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_m_fringe_amd256.c index 8d0bea859b..8d5a99968c 100644 --- a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_m_fringe_amd256.c +++ b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_m_fringe_amd256.c @@ -104,7 +104,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_1 = _mm256_add_epi8( a_int32_1, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -119,7 +119,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); @@ -134,7 +134,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_1 = _mm256_add_epi8( a_int32_1, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -143,7 +143,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32) c_int16_2p0 = _mm256_add_epi16(inter_vec[0], c_int16_2p0); 
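The CVT_MULRND_CVT16 call sites above gain a fourth argument carrying the per-column zero point. A minimal scalar model of one lane of that downscale step, assuming the macro scales the s16 accumulator in float, rounds back to integer, and then adds the zero point before the saturating 8-bit store (the exact macro body is not shown in this diff):

#include <stdint.h>
#include <math.h>

/* Scalar model of CVT_MULRND_CVT16(acc, scale_1, scale_2, zero_point_0) for a
 * single lane: scale the s16 accumulator in float, round back to integer
 * (cvtps_epi32 rounds to nearest even; rintf models that under the default
 * rounding mode), then add the per-column zero point. Saturation to s8/u8 is
 * left to the store stage. */
static int16_t downscale_lane(int16_t acc, float scale, int16_t zero_point)
{
    int32_t rounded = (int32_t)rintf((float)acc * scale);
    return (int16_t)(rounded + zero_point);
}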
c_int16_2p1 = _mm256_add_epi16(inter_vec[1], c_int16_2p1); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); @@ -167,7 +167,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -182,7 +182,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_1 = _mm256_add_epi8( a_int32_1, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); @@ -197,7 +197,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -212,7 +212,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_1 = _mm256_add_epi8( a_int32_1, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); @@ -521,6 +521,8 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; /* Load the scale vector values into the register*/ @@ -533,11 +535,25 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (1 * 8)); + // Load zero points (2 byte values). + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 4 rows. 
- CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) scale_1 = _mm256_loadu_ps( @@ -548,11 +564,24 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (3 * 8)); + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale next 16 columns of the 4 rows. - CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p1, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p1, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -668,7 +697,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_1 = _mm256_add_epi8( a_int32_1, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -677,7 +706,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32) c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); @@ -700,7 +729,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -715,7 +744,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_1 = _mm256_add_epi8( a_int32_1, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); @@ -918,6 +947,8 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; /* Load the scale vector values into the register*/ @@ -930,9 +961,23 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (1 * 8)); + // Load zero points (2 byte 
values). + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 4 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) scale_1 = _mm256_loadu_ps( @@ -943,9 +988,22 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (3 * 8)); + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale next 16 columns of the 4 rows. - CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p1, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p1, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1032,7 +1090,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32) b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 0))); b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 1))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -1055,7 +1113,7 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -1205,6 +1263,8 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; /* Load the scale vector values into the register*/ @@ -1217,8 +1277,22 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (1 * 8)); + // Load zero points (2 byte values). + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 4 rows. 
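The load-and-widen block that reappears throughout these kernels picks sign extension for S8 outputs and zero extension for U8 outputs, so an unsigned zero point such as 200 is not misread as -56 in the 16-bit domain. A small standalone illustration of that choice (the helper name is hypothetical):

#include <immintrin.h>
#include <stdint.h>

/* Widen 16 packed zero-point bytes to 16-bit lanes, choosing sign vs zero
 * extension from the destination element type; is_signed_c mirrors the
 * post_ops_attr.c_stor_type == S8 test in the kernels. */
static __m256i widen_zero_points(const void *zp_bytes, int is_signed_c)
{
    __m128i zp8 = _mm_loadu_si128((const __m128i *)zp_bytes);
    return is_signed_c ? _mm256_cvtepi8_epi16(zp8)   /* sign-extend  s8 -> s16 */
                       : _mm256_cvtepu8_epi16(zp8);  /* zero-extend  u8 -> s16 */
}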
- CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) scale_1 = _mm256_loadu_ps( @@ -1229,8 +1303,21 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (3 * 8)); + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale next 16 columns of the 4 rows. - CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_mn_fringe_amd256.c b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_mn_fringe_amd256.c index 79fa0bcd3f..9e2355a711 100644 --- a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_mn_fringe_amd256.c +++ b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_mn_fringe_amd256.c @@ -88,7 +88,7 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -101,7 +101,7 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -114,7 +114,7 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -127,7 +127,7 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -148,7 +148,7 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -161,7 +161,7 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -174,7 +174,7 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -187,7 +187,7 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -384,6 +384,8 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; /* Load the scale vector values into the register*/ @@ -396,11 +398,25 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (1 * 8)); + // Load zero points (2 byte values). + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 4 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -497,7 +513,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -510,7 +526,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -523,7 +539,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -536,7 +552,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -557,7 +573,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -570,7 +586,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -583,7 +599,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -596,7 +612,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -809,6 +825,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; float float_buf[16]; @@ -820,11 +838,30 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16) scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 6 rows. 
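In the *lt16 fringe kernels a full 16-byte load of the zero points could read past the end of the user buffer, so the diff stages the n0_rem valid bytes through a local array before the 128-bit load. A compact sketch of that pattern (the tail is zero-filled here for determinism; the kernels can leave it uninitialised because only n0_rem columns are ever stored):

#include <immintrin.h>
#include <stdint.h>
#include <string.h>

/* Safe partial load of n0_rem (< 16) zero-point bytes: copy only the valid
 * bytes into a 16-byte stack buffer, then issue one full 128-bit load. */
static __m128i load_zero_points_partial(const int8_t *zp, int n0_rem)
{
    int8_t buf[16] = { 0 };                    /* tail zero-filled for clarity */
    memcpy(buf, zp, (size_t)n0_rem * sizeof(int8_t));
    return _mm_loadu_si128((const __m128i *)buf);
}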
- CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -932,7 +969,7 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -945,7 +982,7 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -965,7 +1002,7 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -978,7 +1015,7 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -1121,6 +1158,8 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; /* Load the scale vector values into the register*/ @@ -1133,9 +1172,23 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x16) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (1 * 8)); + // Load zero points (2 byte values). + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 2 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1217,7 +1270,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -1230,7 +1283,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 4. @@ -1250,7 +1303,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -1263,7 +1316,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -1418,6 +1471,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; float float_buf[16]; @@ -1429,9 +1484,28 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16) scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 6 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1520,7 +1594,7 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x16) b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -1540,7 +1614,7 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -1656,6 +1730,8 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; /* Load the scale vector values into the register*/ @@ -1668,8 +1744,22 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x16) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (1 * 8)); + // Load zero points (2 byte values). + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 2 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1746,7 +1836,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -1766,7 +1856,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -1892,6 +1982,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1xlt16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; float float_buf[16]; @@ -1903,8 +1995,27 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1xlt16) scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 2 rows. 
- CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_n_fringe_amd256.c b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_n_fringe_amd256.c index 69b7a9baa9..36cad252a6 100644 --- a/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_n_fringe_amd256.c +++ b/kernels/zen/lpgemm/s8s8s16/lpgemm_s8_n_fringe_amd256.c @@ -102,7 +102,7 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -115,7 +115,7 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -128,7 +128,7 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -141,7 +141,7 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -154,7 +154,7 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -167,7 +167,7 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -188,7 +188,7 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -201,7 +201,7 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -214,7 +214,7 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -227,7 +227,7 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -240,7 +240,7 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -253,7 +253,7 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -505,6 +505,8 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; /* Load the scale vector values into the register*/ @@ -517,13 +519,27 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (1 * 8)); + // Load zero points (2 byte values). + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 6 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -698,7 +714,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -711,7 +727,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -724,7 +740,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -737,7 +753,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -750,7 +766,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -763,7 +779,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -796,7 +812,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -809,7 +825,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -822,7 +838,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -835,7 +851,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -848,7 +864,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) //convert signed int8 to uint8 for u8s8s16 FMA ops a_int32_0 = _mm256_add_epi8( a_int32_0, vec_uint8 ); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -1120,6 +1136,8 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; float float_buf[16]; @@ -1131,13 +1149,32 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16) scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 6 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/kernels/zen/lpgemm/u8s8s16/lpgemm_6x32rowmajor_amd256.c b/kernels/zen/lpgemm/u8s8s16/lpgemm_6x32rowmajor_amd256.c index 859a377ce0..3c92c49da2 100644 --- a/kernels/zen/lpgemm/u8s8s16/lpgemm_6x32rowmajor_amd256.c +++ b/kernels/zen/lpgemm/u8s8s16/lpgemm_6x32rowmajor_amd256.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -144,7 +144,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) __m256i b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 1))); - // Seperate register for intermediate op + // Separate register for intermediate op __m256i inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -158,7 +158,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -172,7 +172,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] @@ -185,7 +185,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -199,7 +199,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 4) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -214,7 +214,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 5) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -237,7 +237,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) uint8_t a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); __m256i a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op __m256i inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -250,7 +250,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -263,7 +263,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -277,7 +277,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -290,7 +290,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) a_kfringe = *(a + (rs_a * 4) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -303,7 +303,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) a_kfringe = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -348,41 +348,82 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_first_k == TRUE ) ) { - // c[0,0-15] - S8_S16_BETA_OP(c_int16_0p0,ir,0,0,alphav,betav) + if ( post_ops_attr.c_stor_type == S8 ) + { + // c[0,0-15] + S8_S16_BETA_OP(c_int16_0p0,ir,0,0,alphav,betav) - // c[0, 16-31] - S8_S16_BETA_OP(c_int16_0p1,ir,0,1,alphav,betav) + // c[0, 16-31] + S8_S16_BETA_OP(c_int16_0p1,ir,0,1,alphav,betav) - // c[1,0-15] - S8_S16_BETA_OP(c_int16_1p0,ir,1,0,alphav,betav) + // c[1,0-15] + S8_S16_BETA_OP(c_int16_1p0,ir,1,0,alphav,betav) - // c[1,16-31] - S8_S16_BETA_OP(c_int16_1p1,ir,1,1,alphav,betav) + // c[1,16-31] + S8_S16_BETA_OP(c_int16_1p1,ir,1,1,alphav,betav) - // c[2,0-15] - S8_S16_BETA_OP(c_int16_2p0,ir,2,0,alphav,betav) + // c[2,0-15] + S8_S16_BETA_OP(c_int16_2p0,ir,2,0,alphav,betav) - // c[2,16-31] - S8_S16_BETA_OP(c_int16_2p1,ir,2,1,alphav,betav) + // c[2,16-31] + S8_S16_BETA_OP(c_int16_2p1,ir,2,1,alphav,betav) - // c[3,0-15] - S8_S16_BETA_OP(c_int16_3p0,ir,3,0,alphav,betav) + // c[3,0-15] + S8_S16_BETA_OP(c_int16_3p0,ir,3,0,alphav,betav) - // c[3,16-31] - S8_S16_BETA_OP(c_int16_3p1,ir,3,1,alphav,betav) + // c[3,16-31] + S8_S16_BETA_OP(c_int16_3p1,ir,3,1,alphav,betav) - // c[4,0-15] - S8_S16_BETA_OP(c_int16_4p0,ir,4,0,alphav,betav) + // c[4,0-15] + S8_S16_BETA_OP(c_int16_4p0,ir,4,0,alphav,betav) - // c[4,16-31] - S8_S16_BETA_OP(c_int16_4p1,ir,4,1,alphav,betav) + // c[4,16-31] + S8_S16_BETA_OP(c_int16_4p1,ir,4,1,alphav,betav) - // c[5,0-15] - S8_S16_BETA_OP(c_int16_5p0,ir,5,0,alphav,betav) + // c[5,0-15] + S8_S16_BETA_OP(c_int16_5p0,ir,5,0,alphav,betav) - // c[5,16-31] - S8_S16_BETA_OP(c_int16_5p1,ir,5,1,alphav,betav) + // c[5,16-31] + S8_S16_BETA_OP(c_int16_5p1,ir,5,1,alphav,betav) + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // c[0,0-15] + U8_S16_BETA_OP(c_int16_0p0,ir,0,0,alphav,betav) + + // c[0, 16-31] + U8_S16_BETA_OP(c_int16_0p1,ir,0,1,alphav,betav) + + // c[1,0-15] + U8_S16_BETA_OP(c_int16_1p0,ir,1,0,alphav,betav) + + // c[1,16-31] + U8_S16_BETA_OP(c_int16_1p1,ir,1,1,alphav,betav) + + // c[2,0-15] + U8_S16_BETA_OP(c_int16_2p0,ir,2,0,alphav,betav) + + // c[2,16-31] + U8_S16_BETA_OP(c_int16_2p1,ir,2,1,alphav,betav) + + // c[3,0-15] + U8_S16_BETA_OP(c_int16_3p0,ir,3,0,alphav,betav) + + // c[3,16-31] + U8_S16_BETA_OP(c_int16_3p1,ir,3,1,alphav,betav) + + // c[4,0-15] + U8_S16_BETA_OP(c_int16_4p0,ir,4,0,alphav,betav) + + // c[4,16-31] + U8_S16_BETA_OP(c_int16_4p1,ir,4,1,alphav,betav) + + // c[5,0-15] + U8_S16_BETA_OP(c_int16_5p0,ir,5,0,alphav,betav) + + // c[5,16-31] + U8_S16_BETA_OP(c_int16_5p1,ir,5,1,alphav,betav) + } } else { @@ -703,20 +744,35 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) /* Load the scale vector values into the register*/ scale_1 = _mm256_loadu_ps( - 
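The beta path above now branches on the output storage type: previously stored 8-bit C values are widened with matching signedness (S8_S16_BETA_OP vs U8_S16_BETA_OP) before the 16-bit alpha/beta combine. A simplified scalar model of one lane, not the exact macro definition:

#include <stdint.h>

/* Scalar model of the S8_/U8_S16_BETA_OP split: the previously stored 8-bit
 * C element is widened with the signedness of the output type before the
 * 16-bit alpha/beta combine. The alpha/beta handling is a simplification of
 * the vector macros, and final saturation is left to the store stage. */
static int16_t beta_op_lane(int16_t acc, const void *c_elem,
                            int is_signed_c, int16_t alpha, int16_t beta)
{
    int16_t c_prev = is_signed_c ? (int16_t)(*(const int8_t  *)c_elem)
                                 : (int16_t)(*(const uint8_t *)c_elem);
    return (int16_t)(alpha * acc + beta * c_prev);
}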
(float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (0 * 8)); + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 0 * 8 ) ); scale_2 = _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_ops_attr.post_op_c_j + (1 * 8)); + ( float* )post_ops_list_temp->scale_factor + + post_ops_attr.post_op_c_j + ( 1 * 8 ) ); + + // Load zero points (2 byte values). + __m128i _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m256i zero_point_0 = _mm256_setzero_si256(); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } // Scale first 16 columns of the 6 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2, zero_point_0) scale_1 = _mm256_loadu_ps( @@ -727,13 +783,26 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (3 * 8)); + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale next 16 columns of the 6 rows. - CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_4p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_5p1, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_4p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_5p1, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -745,25 +814,49 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { - // Store the results in downscaled type (int8 instead of int32). - // c[0,0-31] - CVT_STORE_S16_S8(c_int16_0p0, c_int16_0p1, 0, 0); + if ( post_ops_attr.c_stor_type == S8 ) + { + // Store the results in downscaled type (int8 instead of int16). 
+ // c[0,0-31] + CVT_STORE_S16_S8(c_int16_0p0, c_int16_0p1, 0, 0); - // c[1,0-31] - CVT_STORE_S16_S8(c_int16_1p0, c_int16_1p1, 1, 0); + // c[1,0-31] + CVT_STORE_S16_S8(c_int16_1p0, c_int16_1p1, 1, 0); - // c[2,0-31] - CVT_STORE_S16_S8(c_int16_2p0, c_int16_2p1, 2, 0); + // c[2,0-31] + CVT_STORE_S16_S8(c_int16_2p0, c_int16_2p1, 2, 0); - // c[3,0-31] - CVT_STORE_S16_S8(c_int16_3p0, c_int16_3p1, 3, 0); + // c[3,0-31] + CVT_STORE_S16_S8(c_int16_3p0, c_int16_3p1, 3, 0); - // c[4,0-31] - CVT_STORE_S16_S8(c_int16_4p0, c_int16_4p1, 4, 0); + // c[4,0-31] + CVT_STORE_S16_S8(c_int16_4p0, c_int16_4p1, 4, 0); - // c[5,0-31] - CVT_STORE_S16_S8(c_int16_5p0, c_int16_5p1, 5, 0); - } + // c[5,0-31] + CVT_STORE_S16_S8(c_int16_5p0, c_int16_5p1, 5, 0); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // Store the results in downscaled type (uint8 instead of int16). + // c[0,0-31] + CVT_STORE_S16_U8(c_int16_0p0, c_int16_0p1, 0, 0); + + // c[1,0-31] + CVT_STORE_S16_U8(c_int16_1p0, c_int16_1p1, 1, 0); + + // c[2,0-31] + CVT_STORE_S16_U8(c_int16_2p0, c_int16_2p1, 2, 0); + + // c[3,0-31] + CVT_STORE_S16_U8(c_int16_3p0, c_int16_3p1, 3, 0); + + // c[4,0-31] + CVT_STORE_S16_U8(c_int16_4p0, c_int16_4p1, 4, 0); + + // c[5,0-31] + CVT_STORE_S16_U8(c_int16_5p0, c_int16_5p1, 5, 0); + } + } // Case where the output C matrix is s16 or is the temp buffer used to // store intermediate s16 accumulated values for downscaled (C-s8) api. else diff --git a/kernels/zen/lpgemm/u8s8s16/lpgemm_m_fringe_amd256.c b/kernels/zen/lpgemm/u8s8s16/lpgemm_m_fringe_amd256.c index 863c57a5b6..b6094c878d 100644 --- a/kernels/zen/lpgemm/u8s8s16/lpgemm_m_fringe_amd256.c +++ b/kernels/zen/lpgemm/u8s8s16/lpgemm_m_fringe_amd256.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -95,7 +95,7 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) // Broadcast a[1,kr:kr+2]. a_int32_1 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -107,7 +107,7 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) // Broadcast a[2,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); @@ -119,7 +119,7 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) // Broadcast a[3,kr:kr+2]. 
a_int32_1 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -128,7 +128,7 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) c_int16_2p0 = _mm256_add_epi16(inter_vec[0], c_int16_2p0); c_int16_2p1 = _mm256_add_epi16(inter_vec[1], c_int16_2p1); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); @@ -149,7 +149,7 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -161,7 +161,7 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); a_int32_1 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); @@ -173,7 +173,7 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -185,7 +185,7 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); a_int32_1 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); @@ -223,29 +223,58 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_first_k == TRUE ) ) { - // c[0,0-15] - S8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) + if ( post_ops_attr.c_stor_type == S8 ) + { + // c[0,0-15] + S8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) - // c[0, 16-31] - S8_S16_BETA_OP(c_int16_0p1,0,0,1,selector1,selector2) + // c[0, 16-31] + S8_S16_BETA_OP(c_int16_0p1,0,0,1,selector1,selector2) - // c[1,0-15] - S8_S16_BETA_OP(c_int16_1p0,0,1,0,selector1,selector2) + // c[1,0-15] + S8_S16_BETA_OP(c_int16_1p0,0,1,0,selector1,selector2) - // c[1,16-31] - S8_S16_BETA_OP(c_int16_1p1,0,1,1,selector1,selector2) + // c[1,16-31] + S8_S16_BETA_OP(c_int16_1p1,0,1,1,selector1,selector2) - // c[2,0-15] - S8_S16_BETA_OP(c_int16_2p0,0,2,0,selector1,selector2) + // c[2,0-15] + S8_S16_BETA_OP(c_int16_2p0,0,2,0,selector1,selector2) - // c[2,16-31] - S8_S16_BETA_OP(c_int16_2p1,0,2,1,selector1,selector2) + // c[2,16-31] + S8_S16_BETA_OP(c_int16_2p1,0,2,1,selector1,selector2) - // c[3,0-15] - S8_S16_BETA_OP(c_int16_3p0,0,3,0,selector1,selector2) + // c[3,0-15] + S8_S16_BETA_OP(c_int16_3p0,0,3,0,selector1,selector2) - // c[3,16-31] - S8_S16_BETA_OP(c_int16_3p1,0,3,1,selector1,selector2) + // c[3,16-31] + S8_S16_BETA_OP(c_int16_3p1,0,3,1,selector1,selector2) + } + else if ( 
post_ops_attr.c_stor_type == U8 ) + { + // c[0,0-15] + U8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + U8_S16_BETA_OP(c_int16_0p1,0,0,1,selector1,selector2) + + // c[1,0-15] + U8_S16_BETA_OP(c_int16_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + U8_S16_BETA_OP(c_int16_1p1,0,1,1,selector1,selector2) + + // c[2,0-15] + U8_S16_BETA_OP(c_int16_2p0,0,2,0,selector1,selector2) + + // c[2,16-31] + U8_S16_BETA_OP(c_int16_2p1,0,2,1,selector1,selector2) + + // c[3,0-15] + U8_S16_BETA_OP(c_int16_3p0,0,3,0,selector1,selector2) + + // c[3,16-31] + U8_S16_BETA_OP(c_int16_3p1,0,3,1,selector1,selector2) + } } else { @@ -473,6 +502,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; /* Load the scale vector values into the register*/ @@ -485,11 +516,25 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (1 * 8)); + // Load zero points (2 byte values). + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 4 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) scale_1 = _mm256_loadu_ps( @@ -500,11 +545,24 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (3 * 8)); + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale next 16 columns of the 4 rows. - CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p1, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p1, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -516,18 +574,36 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { - // Store the results in downscaled type (int8 instead of int32). - // c[0,0-31] - CVT_STORE_S16_S8(c_int16_0p0, c_int16_0p1, 0, 0); + if ( post_ops_attr.c_stor_type == S8 ) + { + // Store the results in downscaled type (int8 instead of int16). 
+ // c[0,0-31] + CVT_STORE_S16_S8(c_int16_0p0, c_int16_0p1, 0, 0); + + // c[1,0-31] + CVT_STORE_S16_S8(c_int16_1p0, c_int16_1p1, 1, 0); - // c[1,0-31] - CVT_STORE_S16_S8(c_int16_1p0, c_int16_1p1, 1, 0); + // c[2,0-31] + CVT_STORE_S16_S8(c_int16_2p0, c_int16_2p1, 2, 0); - // c[2,0-31] - CVT_STORE_S16_S8(c_int16_2p0, c_int16_2p1, 2, 0); + // c[3,0-31] + CVT_STORE_S16_S8(c_int16_3p0, c_int16_3p1, 3, 0); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // Store the results in downscaled type (uint8 instead of int16). + // c[0,0-31] + CVT_STORE_S16_U8(c_int16_0p0, c_int16_0p1, 0, 0); + + // c[1,0-31] + CVT_STORE_S16_U8(c_int16_1p0, c_int16_1p1, 1, 0); - // c[3,0-31] - CVT_STORE_S16_S8(c_int16_3p0, c_int16_3p1, 3, 0); + // c[2,0-31] + CVT_STORE_S16_U8(c_int16_2p0, c_int16_2p1, 2, 0); + + // c[3,0-31] + CVT_STORE_S16_U8(c_int16_3p0, c_int16_3p1, 3, 0); + } } // Case where the output C matrix is s16 or is the temp buffer used to // store intermediate s16 accumulated values for downscaled (C-s8) api. @@ -611,7 +687,7 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) // Broadcast a[1,kr:kr+2]. a_int32_1 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -620,7 +696,7 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); @@ -640,7 +716,7 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -652,7 +728,7 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); a_int32_1 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); @@ -684,17 +760,34 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_first_k == TRUE ) ) { - // c[0,0-15] - S8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) - - // c[0, 16-31] - S8_S16_BETA_OP(c_int16_0p1,0,0,1,selector1,selector2) - - // c[1,0-15] - S8_S16_BETA_OP(c_int16_1p0,0,1,0,selector1,selector2) - - // c[1,16-31] - S8_S16_BETA_OP(c_int16_1p1,0,1,1,selector1,selector2) + if ( post_ops_attr.c_stor_type == S8 ) + { + // c[0,0-15] + S8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + S8_S16_BETA_OP(c_int16_0p1,0,0,1,selector1,selector2) + + // c[1,0-15] + S8_S16_BETA_OP(c_int16_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + S8_S16_BETA_OP(c_int16_1p1,0,1,1,selector1,selector2) + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // c[0,0-15] + U8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + U8_S16_BETA_OP(c_int16_0p1,0,0,1,selector1,selector2) + + // c[1,0-15] + 
U8_S16_BETA_OP(c_int16_1p0,0,1,0,selector1,selector2) + + // c[1,16-31] + U8_S16_BETA_OP(c_int16_1p1,0,1,1,selector1,selector2) + } } else { @@ -838,6 +931,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; /* Load the scale vector values into the register*/ @@ -850,9 +945,23 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (1 * 8)); + // Load zero points (2 byte values). + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 4 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) scale_1 = _mm256_loadu_ps( @@ -863,9 +972,22 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (3 * 8)); + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale next 16 columns of the 4 rows. - CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p1, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p1, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -877,12 +999,24 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { - // Store the results in downscaled type (int8 instead of int32). - // c[0,0-31] - CVT_STORE_S16_S8(c_int16_0p0, c_int16_0p1, 0, 0); + if ( post_ops_attr.c_stor_type == S8 ) + { + // Store the results in downscaled type (int8 instead of int16). + // c[0,0-31] + CVT_STORE_S16_S8(c_int16_0p0, c_int16_0p1, 0, 0); - // c[1,0-31] - CVT_STORE_S16_S8(c_int16_1p0, c_int16_1p1, 1, 0); + // c[1,0-31] + CVT_STORE_S16_S8(c_int16_1p0, c_int16_1p1, 1, 0); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // Store the results in downscaled type (uint8 instead of int16). + // c[0,0-31] + CVT_STORE_S16_U8(c_int16_0p0, c_int16_0p1, 0, 0); + + // c[1,0-31] + CVT_STORE_S16_U8(c_int16_1p0, c_int16_1p1, 1, 0); + } } // Case where the output C matrix is s16 or is the temp buffer used to // store intermediate s16 accumulated values for downscaled (C-s8) api. 
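/*
 * The hunks above thread a per-column zero point through the existing
 * CVT_MULRND_CVT16 downscale step. The macro bodies are not part of this
 * diff, so the scalar model below is only a sketch of the assumed
 * per-element math (widen s16 -> float, scale, round to nearest, add the
 * widened 8-bit zero point, then saturate to s8 or u8 at store time);
 * downscale_one() is a hypothetical name used for illustration.
 */
#include <math.h>    /* nearbyintf */
#include <stdint.h>

static inline int32_t downscale_one( int16_t acc, float scale, int16_t zp )
{
    /* Scale the s16 accumulator in float and round to nearest. */
    int32_t v = ( int32_t )nearbyintf( ( float )acc * scale );

    /* Shift by the zero point; the store step then clamps to [-128,127]
       for S8 output or [0,255] for U8 output. */
    return v + zp;
}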
@@ -946,7 +1080,7 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 0))); b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 1))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -966,7 +1100,7 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); @@ -995,11 +1129,22 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_first_k == TRUE ) ) { - // c[0,0-15] - S8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) - - // c[0, 16-31] - S8_S16_BETA_OP(c_int16_0p1,0,0,1,selector1,selector2) + if ( post_ops_attr.c_stor_type == S8 ) + { + // c[0,0-15] + S8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + S8_S16_BETA_OP(c_int16_0p1,0,0,1,selector1,selector2) + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // c[0,0-15] + U8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) + + // c[0, 16-31] + U8_S16_BETA_OP(c_int16_0p1,0,0,1,selector1,selector2) + } } else { @@ -1101,6 +1246,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; /* Load the scale vector values into the register*/ @@ -1113,8 +1260,22 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (1 * 8)); + // Load zero points (2 byte values). + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 4 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) scale_1 = _mm256_loadu_ps( @@ -1125,8 +1286,21 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (3 * 8)); + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale next 16 columns of the 4 rows. - CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p1, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1138,9 +1312,18 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { - // Store the results in downscaled type (int8 instead of int32). 
- // c[0,0-31] - CVT_STORE_S16_S8(c_int16_0p0, c_int16_0p1, 0, 0); + if ( post_ops_attr.c_stor_type == S8 ) + { + // Store the results in downscaled type (int8 instead of int16). + // c[0,0-31] + CVT_STORE_S16_S8(c_int16_0p0, c_int16_0p1, 0, 0); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // Store the results in downscaled type (uint8 instead of int16). + // c[0,0-31] + CVT_STORE_S16_U8(c_int16_0p0, c_int16_0p1, 0, 0); + } } // Case where the output C matrix is s16 or is the temp buffer used to // store intermediate s16 accumulated values for downscaled (C-s8) api. diff --git a/kernels/zen/lpgemm/u8s8s16/lpgemm_mn_fringe_amd256.c b/kernels/zen/lpgemm/u8s8s16/lpgemm_mn_fringe_amd256.c index e4b04e80e1..b19abe413d 100644 --- a/kernels/zen/lpgemm/u8s8s16/lpgemm_mn_fringe_amd256.c +++ b/kernels/zen/lpgemm/u8s8s16/lpgemm_mn_fringe_amd256.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -82,7 +82,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) // Broadcast a[0,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -92,7 +92,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) // Broadcast a[1,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -102,7 +102,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) // Broadcast a[2,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -112,7 +112,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) // Broadcast a[3,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -130,7 +130,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -140,7 +140,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -150,7 +150,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -160,7 +160,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -192,17 +192,34 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_first_k == TRUE ) ) { - // c[0,0-15] - S8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) - - // c[1,0-15] - S8_S16_BETA_OP(c_int16_1p0,0,1,0,selector1,selector2) - - // c[2,0-15] - S8_S16_BETA_OP(c_int16_2p0,0,2,0,selector1,selector2) - - // c[3,0-15] - S8_S16_BETA_OP(c_int16_3p0,0,3,0,selector1,selector2) + if ( post_ops_attr.c_stor_type == S8 ) + { + // c[0,0-15] + S8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) + + // c[1,0-15] + S8_S16_BETA_OP(c_int16_1p0,0,1,0,selector1,selector2) + + // c[2,0-15] + S8_S16_BETA_OP(c_int16_2p0,0,2,0,selector1,selector2) + + // c[3,0-15] + S8_S16_BETA_OP(c_int16_3p0,0,3,0,selector1,selector2) + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // c[0,0-15] + U8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) + + // c[1,0-15] + U8_S16_BETA_OP(c_int16_1p0,0,1,0,selector1,selector2) + + // c[2,0-15] + U8_S16_BETA_OP(c_int16_2p0,0,2,0,selector1,selector2) + + // c[3,0-15] + U8_S16_BETA_OP(c_int16_3p0,0,3,0,selector1,selector2) + } } else { @@ -343,6 +360,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; /* Load the scale vector values into the register*/ @@ -355,11 +374,25 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (1 * 8)); + // Load zero points (2 byte values). + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 4 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -371,14 +404,28 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { - // Store the results in downscaled type (int8 instead of int32). 
- __m128i temp[2]; + if ( post_ops_attr.c_stor_type == S8 ) + { + // Store the results in downscaled type (int8 instead of int32). + __m128i temp[2]; - // c[0-1,0-15] - CVT_STORE_S16_S8_2ROW(c_int16_0p0, c_int16_1p0, 0, 1, 0); + // c[0-1,0-15] + CVT_STORE_S16_S8_2ROW(c_int16_0p0, c_int16_1p0, 0, 1, 0); - // c[2-3,0-15] - CVT_STORE_S16_S8_2ROW(c_int16_2p0, c_int16_3p0, 2, 3, 0); + // c[2-3,0-15] + CVT_STORE_S16_S8_2ROW(c_int16_2p0, c_int16_3p0, 2, 3, 0); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // Store the results in downscaled type (uint8 instead of int16). + __m128i temp[2]; + + // c[0-1,0-15] + CVT_STORE_S16_U8_2ROW(c_int16_0p0, c_int16_1p0, 0, 1, 0); + + // c[2-3,0-15] + CVT_STORE_S16_U8_2ROW(c_int16_2p0, c_int16_3p0, 2, 3, 0); + } } // Case where the output C matrix is s16 or is the temp buffer used to // store intermediate s16 accumulated values for downscaled (C-s8) api. @@ -450,7 +497,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) // Broadcast a[0,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -460,7 +507,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) // Broadcast a[1,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -470,7 +517,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) // Broadcast a[2,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -480,7 +527,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) // Broadcast a[3,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -498,7 +545,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -508,7 +555,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -518,7 +565,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -528,7 +575,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -560,24 +607,48 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_first_k == TRUE ) ) { - dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); + if ( post_ops_attr.c_stor_type == S8 ) + { + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); - S8_S16_BETA_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); - S8_S16_BETA_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); - S8_S16_BETA_NLT16_MEMCP_UTIL(buf2, 2, n0_rem_dscale_bytes); - S8_S16_BETA_NLT16_MEMCP_UTIL(buf3, 3, n0_rem_dscale_bytes); + S8_S16_BETA_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + S8_S16_BETA_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); + S8_S16_BETA_NLT16_MEMCP_UTIL(buf2, 2, n0_rem_dscale_bytes); + S8_S16_BETA_NLT16_MEMCP_UTIL(buf3, 3, n0_rem_dscale_bytes); - // c[0,0-15] - S8_S16_BETA_OP_NLT16(c_int16_0p0,buf0,selector1,selector2) + // c[0,0-15] + S8_S16_BETA_OP_NLT16(c_int16_0p0,buf0,selector1,selector2) - // c[1,0-15] - S8_S16_BETA_OP_NLT16(c_int16_1p0,buf1,selector1,selector2) + // c[1,0-15] + S8_S16_BETA_OP_NLT16(c_int16_1p0,buf1,selector1,selector2) - // c[2,0-15] - S8_S16_BETA_OP_NLT16(c_int16_2p0,buf2,selector1,selector2) + // c[2,0-15] + S8_S16_BETA_OP_NLT16(c_int16_2p0,buf2,selector1,selector2) - // c[3,0-15] - S8_S16_BETA_OP_NLT16(c_int16_3p0,buf3,selector1,selector2) + // c[3,0-15] + S8_S16_BETA_OP_NLT16(c_int16_3p0,buf3,selector1,selector2) + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( uint8_t ); + + U8_S16_BETA_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + U8_S16_BETA_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); + U8_S16_BETA_NLT16_MEMCP_UTIL(buf2, 2, n0_rem_dscale_bytes); + U8_S16_BETA_NLT16_MEMCP_UTIL(buf3, 3, n0_rem_dscale_bytes); + + // c[0,0-15] + U8_S16_BETA_OP_NLT16(c_int16_0p0,buf0,selector1,selector2) + + // c[1,0-15] + U8_S16_BETA_OP_NLT16(c_int16_1p0,buf1,selector1,selector2) + + // c[2,0-15] + U8_S16_BETA_OP_NLT16(c_int16_2p0,buf2,selector1,selector2) + + // c[3,0-15] + U8_S16_BETA_OP_NLT16(c_int16_3p0,buf3,selector1,selector2) + } } else { @@ -727,6 +798,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; float float_buf[16]; @@ -738,11 +811,30 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) 
) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 6 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -754,21 +846,42 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { - // Store the results in downscaled type (int8 instead of int32). - __m128i temp[2]; + if ( post_ops_attr.c_stor_type == S8 ) + { + // Store the results in downscaled type (int8 instead of int16). + __m128i temp[2]; + + // c[0-1,0-15] + CVT_STORE_S16_S8_2ROW_NLT16(c_int16_0p0, c_int16_1p0, buf0, buf1); + + // c[2-3,0-15] + CVT_STORE_S16_S8_2ROW_NLT16(c_int16_2p0, c_int16_3p0, buf2, buf3); + + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); + + CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); + CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf2, 2, n0_rem_dscale_bytes); + CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf3, 3, n0_rem_dscale_bytes); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // Store the results in downscaled type (uint8 instead of int16). + __m128i temp[2]; - // c[0-1,0-15] - CVT_STORE_S16_S8_2ROW_NLT16(c_int16_0p0, c_int16_1p0, buf0, buf1); + // c[0-1,0-15] + CVT_STORE_S16_U8_2ROW_NLT16(c_int16_0p0, c_int16_1p0, buf0, buf1); - // c[2-3,0-15] - CVT_STORE_S16_S8_2ROW_NLT16(c_int16_2p0, c_int16_3p0, buf2, buf3); + // c[2-3,0-15] + CVT_STORE_S16_U8_2ROW_NLT16(c_int16_2p0, c_int16_3p0, buf2, buf3); - dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( uint8_t ); - CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); - CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); - CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf2, 2, n0_rem_dscale_bytes); - CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf3, 3, n0_rem_dscale_bytes); + CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); + CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf2, 2, n0_rem_dscale_bytes); + CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf3, 3, n0_rem_dscale_bytes); + } } // Case where the output C matrix is s16 or is the temp buffer used to // store intermediate s16 accumulated values for downscaled (C-s8) api. @@ -844,7 +957,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) // Broadcast a[0,kr:kr+2]. 
a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -854,7 +967,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) // Broadcast a[1,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -871,7 +984,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -881,7 +994,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -909,11 +1022,22 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_first_k == TRUE ) ) { - // c[0,0-15] - S8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) - - // c[1,0-15] - S8_S16_BETA_OP(c_int16_1p0,0,1,0,selector1,selector2) + if ( post_ops_attr.c_stor_type == S8 ) + { + // c[0,0-15] + S8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) + + // c[1,0-15] + S8_S16_BETA_OP(c_int16_1p0,0,1,0,selector1,selector2) + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // c[0,0-15] + U8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) + + // c[1,0-15] + U8_S16_BETA_OP(c_int16_1p0,0,1,0,selector1,selector2) + } } else { @@ -1012,6 +1136,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; /* Load the scale vector values into the register*/ @@ -1024,9 +1150,23 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (1 * 8)); + // Load zero points (2 byte values). + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 2 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1038,11 +1178,22 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { - // Store the results in downscaled type (int8 instead of int32). 
- __m128i temp[2]; + if ( post_ops_attr.c_stor_type == S8 ) + { + // Store the results in downscaled type (int8 instead of int32). + __m128i temp[2]; + + // c[0-1,0-15] + CVT_STORE_S16_S8_2ROW(c_int16_0p0, c_int16_1p0, 0, 1, 0); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // Store the results in downscaled type (uint8 instead of int16). + __m128i temp[2]; - // c[0-1,0-15] - CVT_STORE_S16_S8_2ROW(c_int16_0p0, c_int16_1p0, 0, 1, 0); + // c[0-1,0-15] + CVT_STORE_S16_U8_2ROW(c_int16_0p0, c_int16_1p0, 0, 1, 0); + } } // Case where the output C matrix is s16 or is the temp buffer used to // store intermediate s16 accumulated values for downscaled (C-s8) api. @@ -1102,7 +1253,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) // Broadcast a[0,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -1112,7 +1263,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) // Broadcast a[1,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 4. @@ -1129,7 +1280,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -1139,7 +1290,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -1167,16 +1318,32 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_first_k == TRUE ) ) { - dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); + if ( post_ops_attr.c_stor_type == S8 ) + { + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); - S8_S16_BETA_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); - S8_S16_BETA_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); + S8_S16_BETA_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + S8_S16_BETA_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); - // c[0,0-15] - S8_S16_BETA_OP_NLT16(c_int16_0p0,buf0,selector1,selector2) + // c[0,0-15] + S8_S16_BETA_OP_NLT16(c_int16_0p0,buf0,selector1,selector2) - // c[1,0-15] - S8_S16_BETA_OP_NLT16(c_int16_1p0,buf1,selector1,selector2) + // c[1,0-15] + S8_S16_BETA_OP_NLT16(c_int16_1p0,buf1,selector1,selector2) + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( uint8_t ); + + U8_S16_BETA_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + U8_S16_BETA_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); + + // c[0,0-15] + U8_S16_BETA_OP_NLT16(c_int16_0p0,buf0,selector1,selector2) + + // c[1,0-15] + U8_S16_BETA_OP_NLT16(c_int16_1p0,buf1,selector1,selector2) + } } else { @@ -1282,6 +1449,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; float float_buf[16]; @@ -1293,9 +1462,28 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 6 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1307,16 +1495,32 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { - // Store the results in downscaled type (int8 instead of int32). - __m128i temp[2]; + if ( post_ops_attr.c_stor_type == S8 ) + { + // Store the results in downscaled type (int8 instead of int16). 
+ __m128i temp[2]; + + // c[0-1,0-15] + CVT_STORE_S16_S8_2ROW_NLT16(c_int16_0p0, c_int16_1p0, buf0, buf1); + + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); + + CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // Store the results in downscaled type (uint8 instead of int16). + __m128i temp[2]; - // c[0-1,0-15] - CVT_STORE_S16_S8_2ROW_NLT16(c_int16_0p0, c_int16_1p0, buf0, buf1); + // c[0-1,0-15] + CVT_STORE_S16_U8_2ROW_NLT16(c_int16_0p0, c_int16_1p0, buf0, buf1); - dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( uint8_t ); - CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); - CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); + CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); + } } // Case where the output C matrix is s16 or is the temp buffer used to // store intermediate s16 accumulated values for downscaled (C-s8) api. @@ -1378,7 +1582,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -1395,7 +1599,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -1421,8 +1625,16 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_first_k == TRUE ) ) { - // c[0,0-15] - S8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) + if ( post_ops_attr.c_stor_type == S8 ) + { + // c[0,0-15] + S8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // c[0,0-15] + U8_S16_BETA_OP(c_int16_0p0,0,0,0,selector1,selector2) + } } else { @@ -1500,6 +1712,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; /* Load the scale vector values into the register*/ @@ -1512,8 +1726,22 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (1 * 8)); + // Load zero points (2 byte values). + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 2 rows. 
- CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1525,12 +1753,24 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { - // Store the results in downscaled type (int8 instead of int32). - __m128i temp[2]; - __m256i zero_reg = _mm256_setzero_si256(); + if ( post_ops_attr.c_stor_type == S8 ) + { + // Store the results in downscaled type (int8 instead of int16). + __m128i temp[2]; + __m256i zero_reg = _mm256_setzero_si256(); + + // c[0-1,0-15] + CVT_STORE_S16_S8_1ROW(c_int16_0p0, zero_reg, 0, 0); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // Store the results in downscaled type (uint8 instead of int16). + __m128i temp[2]; + __m256i zero_reg = _mm256_setzero_si256(); - // c[0-1,0-15] - CVT_STORE_S16_S8_1ROW(c_int16_0p0, zero_reg, 0, 0); + // c[0-1,0-15] + CVT_STORE_S16_U8_1ROW(c_int16_0p0, zero_reg, 0, 0); + } } // Case where the output C matrix is s16 or is the temp buffer used to // store intermediate s16 accumulated values for downscaled (C-s8) api. @@ -1584,7 +1824,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) // Broadcast a[0,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -1601,7 +1841,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
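/*
 * For the *xlt16 kernels the column remainder n0_rem can be smaller than 16,
 * so the added code stages the zero points through a 16-byte stack buffer
 * before the 128-bit load rather than reading past the end of op_args1.
 * Below is a minimal sketch of that staging step for the S8 case (the U8
 * branch differs only in using the unsigned widen, _mm256_cvtepu8_epi16);
 * load_zero_points_nlt16() is a hypothetical helper name.
 */
#include <immintrin.h>
#include <stdint.h>
#include <string.h>

static inline __m256i load_zero_points_nlt16( const int8_t* zp, int64_t n0_rem )
{
    int8_t zero_point_buf[16] = { 0 };

    /* Copy only the valid n0_rem (<= 16) zero points; the rest stay zero. */
    memcpy( zero_point_buf, zp, ( size_t )n0_rem );

    __m128i zp8 = _mm_loadu_si128( ( __m128i const* )zero_point_buf );

    /* Widen to 16 x s16 so the values can be added to the s16 accumulators. */
    return _mm256_cvtepi8_epi16( zp8 );
}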
@@ -1627,12 +1867,24 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_first_k == TRUE ) ) { - dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); + if ( post_ops_attr.c_stor_type == S8 ) + { + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); - S8_S16_BETA_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + S8_S16_BETA_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); - // c[0,0-15] - S8_S16_BETA_OP_NLT16(c_int16_0p0,buf0,selector1,selector2) + // c[0,0-15] + S8_S16_BETA_OP_NLT16(c_int16_0p0,buf0,selector1,selector2) + } + if ( post_ops_attr.c_stor_type == U8 ) + { + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( uint8_t ); + + U8_S16_BETA_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + + // c[0,0-15] + U8_S16_BETA_OP_NLT16(c_int16_0p0,buf0,selector1,selector2) + } } else { @@ -1716,6 +1968,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; float float_buf[16]; @@ -1727,8 +1981,27 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 2 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1740,16 +2013,32 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { - // Store the results in downscaled type (int8 instead of int32). - __m128i temp[2]; - __m256i zero_reg = _mm256_setzero_si256(); + if ( post_ops_attr.c_stor_type == S8 ) + { + // Store the results in downscaled type (int8 instead of int16). + __m128i temp[2]; + __m256i zero_reg = _mm256_setzero_si256(); - // c[0-1,0-15] - CVT_STORE_S16_S8_1ROW_NLT16(c_int16_0p0, zero_reg, buf0); + // c[0-1,0-15] + CVT_STORE_S16_S8_1ROW_NLT16(c_int16_0p0, zero_reg, buf0); - dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); + + CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // Store the results in downscaled type (uint8 instead of int16). 
+ __m128i temp[2]; + __m256i zero_reg = _mm256_setzero_si256(); - CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + // c[0-1,0-15] + CVT_STORE_S16_U8_1ROW_NLT16(c_int16_0p0, zero_reg, buf0); + + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( uint8_t ); + + CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + } } // Case where the output C matrix is s16 or is the temp buffer used to // store intermediate s16 accumulated values for downscaled (C-s8) api. diff --git a/kernels/zen/lpgemm/u8s8s16/lpgemm_n_fringe_amd256.c b/kernels/zen/lpgemm/u8s8s16/lpgemm_n_fringe_amd256.c index a3270f3091..1947de5542 100644 --- a/kernels/zen/lpgemm/u8s8s16/lpgemm_n_fringe_amd256.c +++ b/kernels/zen/lpgemm/u8s8s16/lpgemm_n_fringe_amd256.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -96,7 +96,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) // Broadcast a[0,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -106,7 +106,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) // Broadcast a[1,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -116,7 +116,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) // Broadcast a[2,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -126,7 +126,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) // Broadcast a[3,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -136,7 +136,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) // Broadcast a[4,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 4) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -146,7 +146,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) // Broadcast a[5,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 5) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -164,7 +164,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -174,7 +174,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -184,7 +184,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -194,7 +194,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -204,7 +204,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) a_kfringe = *(a + (rs_a * 4) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -214,7 +214,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) a_kfringe = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -250,23 +250,46 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_first_k == TRUE ) ) { - // c[0,0-15] - S8_S16_BETA_OP(c_int16_0p0,ir,0,0,selector1,selector2) + if ( post_ops_attr.c_stor_type == S8 ) + { + // c[0,0-15] + S8_S16_BETA_OP(c_int16_0p0,ir,0,0,selector1,selector2) - // c[1,0-15] - S8_S16_BETA_OP(c_int16_1p0,ir,1,0,selector1,selector2) + // c[1,0-15] + S8_S16_BETA_OP(c_int16_1p0,ir,1,0,selector1,selector2) - // c[2,0-15] - S8_S16_BETA_OP(c_int16_2p0,ir,2,0,selector1,selector2) + // c[2,0-15] + S8_S16_BETA_OP(c_int16_2p0,ir,2,0,selector1,selector2) - // c[3,0-15] - S8_S16_BETA_OP(c_int16_3p0,ir,3,0,selector1,selector2) + // c[3,0-15] + S8_S16_BETA_OP(c_int16_3p0,ir,3,0,selector1,selector2) - // c[4,0-15] - S8_S16_BETA_OP(c_int16_4p0,ir,4,0,selector1,selector2) + // c[4,0-15] + S8_S16_BETA_OP(c_int16_4p0,ir,4,0,selector1,selector2) - // c[5,0-15] - S8_S16_BETA_OP(c_int16_5p0,ir,5,0,selector1,selector2) + // c[5,0-15] + S8_S16_BETA_OP(c_int16_5p0,ir,5,0,selector1,selector2) + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // c[0,0-15] + U8_S16_BETA_OP(c_int16_0p0,ir,0,0,selector1,selector2) + + // c[1,0-15] + U8_S16_BETA_OP(c_int16_1p0,ir,1,0,selector1,selector2) + + // c[2,0-15] + U8_S16_BETA_OP(c_int16_2p0,ir,2,0,selector1,selector2) + + // c[3,0-15] + U8_S16_BETA_OP(c_int16_3p0,ir,3,0,selector1,selector2) + + // c[4,0-15] + U8_S16_BETA_OP(c_int16_4p0,ir,4,0,selector1,selector2) + + // c[5,0-15] + U8_S16_BETA_OP(c_int16_5p0,ir,5,0,selector1,selector2) + } } else { @@ -449,6 +472,8 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; /* Load the scale vector values into the register*/ @@ -461,13 +486,27 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) (float *)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + (1 * 8)); + // Load zero points (2 byte values). + _zero_point_0 = + _mm_loadu_si128( + ( __m128i const* )( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + if ( post_ops_attr.c_stor_type == S8 ) + { + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 6 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -479,17 +518,34 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { - // Store the results in downscaled type (int8 instead of int32). 
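(Aside: the S8/U8 branching around the zero-point load above boils down to one choice, sketched here with an assumed helper name: the 16 packed zero-point bytes are sign-extended when C is stored as s8 and zero-extended when C is stored as u8, so the later 16-bit addition is exact.)

#include <immintrin.h>
#include <stdbool.h>

/* Hypothetical helper: widen 16 zero-point bytes to 16-bit lanes. */
static inline __m256i widen_zero_points( const void* zp, bool c_is_signed )
{
    __m128i zp8 = _mm_loadu_si128( ( const __m128i* )zp );
    return c_is_signed ? _mm256_cvtepi8_epi16( zp8 )   /* S8 output: sign-extend */
                       : _mm256_cvtepu8_epi16( zp8 );  /* U8 output: zero-extend */
}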
- __m128i temp[2]; + if ( post_ops_attr.c_stor_type == S8 ) + { + // Store the results in downscaled type (int8 instead of int16). + __m128i temp[2]; + + // c[0-1,0-15] + CVT_STORE_S16_S8_2ROW(c_int16_0p0, c_int16_1p0, 0, 1, 0); + + // c[2-3,0-15] + CVT_STORE_S16_S8_2ROW(c_int16_2p0, c_int16_3p0, 2, 3, 0); + + // c[4-5,0-15] + CVT_STORE_S16_S8_2ROW(c_int16_4p0, c_int16_5p0, 4, 5, 0); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // Store the results in downscaled type (uint8 instead of int16). + __m128i temp[2]; - // c[0-1,0-15] - CVT_STORE_S16_S8_2ROW(c_int16_0p0, c_int16_1p0, 0, 1, 0); + // c[0-1,0-15] + CVT_STORE_S16_U8_2ROW(c_int16_0p0, c_int16_1p0, 0, 1, 0); - // c[2-3,0-15] - CVT_STORE_S16_S8_2ROW(c_int16_2p0, c_int16_3p0, 2, 3, 0); + // c[2-3,0-15] + CVT_STORE_S16_U8_2ROW(c_int16_2p0, c_int16_3p0, 2, 3, 0); - // c[4-5,0-15] - CVT_STORE_S16_S8_2ROW(c_int16_4p0, c_int16_5p0, 4, 5, 0); + // c[4-5,0-15] + CVT_STORE_S16_U8_2ROW(c_int16_4p0, c_int16_5p0, 4, 5, 0); + } } // Case where the output C matrix is s16 or is the temp buffer used to // store intermediate s16 accumulated values for downscaled (C-s8) api. @@ -636,7 +692,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) // Broadcast a[0,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -646,7 +702,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) // Broadcast a[1,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -656,7 +712,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) // Broadcast a[2,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -666,7 +722,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) // Broadcast a[3,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -676,7 +732,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) // Broadcast a[4,kr:kr+2]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 4) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -686,7 +742,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) // Broadcast a[5,kr:kr+4]. a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 5) + (cs_a * offset))); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
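(Aside: the only functional difference between the S8 and U8 store paths branched on in this hunk is the packing instruction. A scalar sketch of the two saturation rules, with illustrative helper names:)

#include <stdint.h>

/* packs_epi16  -> clamp to [-128, 127]  (signed s8 output)    */
static inline int8_t  sat_s8( int16_t v )
{ return ( int8_t  )( v < -128 ? -128 : ( v > 127 ? 127 : v ) ); }

/* packus_epi16 -> clamp to [0, 255]     (unsigned u8 output)  */
static inline uint8_t sat_u8( int16_t v )
{ return ( uint8_t )( v < 0 ? 0 : ( v > 255 ? 255 : v ) ); }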
@@ -713,7 +769,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -723,7 +779,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -733,7 +789,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -743,7 +799,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) a_kfringe = *(a + (rs_a * 4) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. @@ -753,7 +809,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) a_kfringe = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); - // Seperate register for intermediate op + // Separate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
@@ -789,32 +845,64 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_first_k == TRUE ) ) { - dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); + if ( post_ops_attr.c_stor_type == S8 ) + { + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); - S8_S16_BETA_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); - S8_S16_BETA_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); - S8_S16_BETA_NLT16_MEMCP_UTIL(buf2, 2, n0_rem_dscale_bytes); - S8_S16_BETA_NLT16_MEMCP_UTIL(buf3, 3, n0_rem_dscale_bytes); - S8_S16_BETA_NLT16_MEMCP_UTIL(buf4, 4, n0_rem_dscale_bytes); - S8_S16_BETA_NLT16_MEMCP_UTIL(buf5, 5, n0_rem_dscale_bytes); + S8_S16_BETA_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + S8_S16_BETA_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); + S8_S16_BETA_NLT16_MEMCP_UTIL(buf2, 2, n0_rem_dscale_bytes); + S8_S16_BETA_NLT16_MEMCP_UTIL(buf3, 3, n0_rem_dscale_bytes); + S8_S16_BETA_NLT16_MEMCP_UTIL(buf4, 4, n0_rem_dscale_bytes); + S8_S16_BETA_NLT16_MEMCP_UTIL(buf5, 5, n0_rem_dscale_bytes); - // c[0,0-15] - S8_S16_BETA_OP_NLT16(c_int16_0p0,buf0,selector1,selector2) + // c[0,0-15] + S8_S16_BETA_OP_NLT16(c_int16_0p0,buf0,selector1,selector2) - // c[1,0-15] - S8_S16_BETA_OP_NLT16(c_int16_1p0,buf1,selector1,selector2) + // c[1,0-15] + S8_S16_BETA_OP_NLT16(c_int16_1p0,buf1,selector1,selector2) - // c[2,0-15] - S8_S16_BETA_OP_NLT16(c_int16_2p0,buf2,selector1,selector2) + // c[2,0-15] + S8_S16_BETA_OP_NLT16(c_int16_2p0,buf2,selector1,selector2) - // c[3,0-15] - S8_S16_BETA_OP_NLT16(c_int16_3p0,buf3,selector1,selector2) + // c[3,0-15] + S8_S16_BETA_OP_NLT16(c_int16_3p0,buf3,selector1,selector2) - // c[4,0-15] - S8_S16_BETA_OP_NLT16(c_int16_4p0,buf4,selector1,selector2) + // c[4,0-15] + S8_S16_BETA_OP_NLT16(c_int16_4p0,buf4,selector1,selector2) - // c[5,0-15] - S8_S16_BETA_OP_NLT16(c_int16_5p0,buf5,selector1,selector2) + // c[5,0-15] + S8_S16_BETA_OP_NLT16(c_int16_5p0,buf5,selector1,selector2) + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( uint8_t ); + + U8_S16_BETA_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + U8_S16_BETA_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); + U8_S16_BETA_NLT16_MEMCP_UTIL(buf2, 2, n0_rem_dscale_bytes); + U8_S16_BETA_NLT16_MEMCP_UTIL(buf3, 3, n0_rem_dscale_bytes); + U8_S16_BETA_NLT16_MEMCP_UTIL(buf4, 4, n0_rem_dscale_bytes); + U8_S16_BETA_NLT16_MEMCP_UTIL(buf5, 5, n0_rem_dscale_bytes); + + // c[0,0-15] + U8_S16_BETA_OP_NLT16(c_int16_0p0,buf0,selector1,selector2) + + // c[1,0-15] + U8_S16_BETA_OP_NLT16(c_int16_1p0,buf1,selector1,selector2) + + // c[2,0-15] + U8_S16_BETA_OP_NLT16(c_int16_2p0,buf2,selector1,selector2) + + // c[3,0-15] + U8_S16_BETA_OP_NLT16(c_int16_3p0,buf3,selector1,selector2) + + // c[4,0-15] + U8_S16_BETA_OP_NLT16(c_int16_4p0,buf4,selector1,selector2) + + // c[5,0-15] + U8_S16_BETA_OP_NLT16(c_int16_5p0,buf5,selector1,selector2) + } } else { @@ -1008,6 +1096,8 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) __m256i temp_32[2]; __m256 temp_float[2]; __m256 scale_1, scale_2; + __m128i _zero_point_0; + __m256i zero_point_0 = _mm256_setzero_si256(); __m256 res_1, res_2; float float_buf[16]; @@ -1019,13 +1109,32 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + if ( post_ops_attr.c_stor_type == S8 ) + { + int8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( int8_t* 
)post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( int8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepi8_epi16( _zero_point_0 ); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + uint8_t zero_point_buf[16]; + + memcpy( zero_point_buf, ( ( uint8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ), ( n0_rem * sizeof( uint8_t ) ) ); + _zero_point_0 = _mm_loadu_si128( ( __m128i const* )zero_point_buf ); + zero_point_0 = _mm256_cvtepu8_epi16( _zero_point_0 ); + } + // Scale first 16 columns of the 6 rows. - CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2) - CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2) + CVT_MULRND_CVT16(c_int16_0p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_1p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_2p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_3p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_4p0, scale_1, scale_2, zero_point_0) + CVT_MULRND_CVT16(c_int16_5p0, scale_1, scale_2, zero_point_0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1037,26 +1146,52 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) if ( ( post_ops_attr.buf_downscale != NULL ) && ( post_ops_attr.is_last_k == TRUE ) ) { - // Store the results in downscaled type (int8 instead of int32). - __m128i temp[2]; + if ( post_ops_attr.c_stor_type == S8 ) + { + // Store the results in downscaled type (int8 instead of int16). + __m128i temp[2]; + + // c[0-1,0-15] + CVT_STORE_S16_S8_2ROW_NLT16(c_int16_0p0, c_int16_1p0, buf0, buf1); + + // c[2-3,0-15] + CVT_STORE_S16_S8_2ROW_NLT16(c_int16_2p0, c_int16_3p0, buf2, buf3); + + // c[4-5,0-15] + CVT_STORE_S16_S8_2ROW_NLT16(c_int16_4p0, c_int16_5p0, buf4, buf5); - // c[0-1,0-15] - CVT_STORE_S16_S8_2ROW_NLT16(c_int16_0p0, c_int16_1p0, buf0, buf1); + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); + + CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); + CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf2, 2, n0_rem_dscale_bytes); + CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf3, 3, n0_rem_dscale_bytes); + CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf4, 4, n0_rem_dscale_bytes); + CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf5, 5, n0_rem_dscale_bytes); + } + else if ( post_ops_attr.c_stor_type == U8 ) + { + // Store the results in downscaled type (uint8 instead of int16). 
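(Aside: the n < 16 paths above stage data through a 16-byte stack buffer and then memcpy only n0_rem bytes, so no out-of-bounds access on C ever occurs. A minimal sketch of the store side, with an assumed helper name:)

#include <immintrin.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical sketch of the n < 16 staging pattern. */
static void store_u8_tail( uint8_t* c_row, __m128i packed, int n0_rem )
{
    uint8_t buf[16];
    _mm_storeu_si128( ( __m128i* )buf, packed );   /* full-width store to scratch */
    memcpy( c_row, buf, ( size_t )n0_rem );        /* copy only the valid bytes   */
}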
+ __m128i temp[2]; - // c[2-3,0-15] - CVT_STORE_S16_S8_2ROW_NLT16(c_int16_2p0, c_int16_3p0, buf2, buf3); + // c[0-1,0-15] + CVT_STORE_S16_U8_2ROW_NLT16(c_int16_0p0, c_int16_1p0, buf0, buf1); - // c[4-5,0-15] - CVT_STORE_S16_S8_2ROW_NLT16(c_int16_4p0, c_int16_5p0, buf4, buf5); + // c[2-3,0-15] + CVT_STORE_S16_U8_2ROW_NLT16(c_int16_2p0, c_int16_3p0, buf2, buf3); - dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); + // c[4-5,0-15] + CVT_STORE_S16_U8_2ROW_NLT16(c_int16_4p0, c_int16_5p0, buf4, buf5); - CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); - CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); - CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf2, 2, n0_rem_dscale_bytes); - CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf3, 3, n0_rem_dscale_bytes); - CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf4, 4, n0_rem_dscale_bytes); - CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf5, 5, n0_rem_dscale_bytes); + dim_t n0_rem_dscale_bytes = n0_rem * sizeof( int8_t ); + + CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf0, 0, n0_rem_dscale_bytes); + CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf1, 1, n0_rem_dscale_bytes); + CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf2, 2, n0_rem_dscale_bytes); + CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf3, 3, n0_rem_dscale_bytes); + CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf4, 4, n0_rem_dscale_bytes); + CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf5, 5, n0_rem_dscale_bytes); + } } // Case where the output C matrix is s16 or is the temp buffer used to // store intermediate s16 accumulated values for downscaled (C-s8) api. diff --git a/kernels/zen/lpgemm/u8s8s16/lpgemm_packb_amd256.c b/kernels/zen/lpgemm/u8s8s16/lpgemm_packb_amd256.c index ef629707f1..1169f825c8 100644 --- a/kernels/zen/lpgemm/u8s8s16/lpgemm_packb_amd256.c +++ b/kernels/zen/lpgemm/u8s8s16/lpgemm_packb_amd256.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/lpgemm/u8s8s16/lpgemm_s16_kern_macros.h b/kernels/zen/lpgemm/u8s8s16/lpgemm_s16_kern_macros.h index 1ce68ed498..48a95ccd53 100644 --- a/kernels/zen/lpgemm/u8s8s16/lpgemm_s16_kern_macros.h +++ b/kernels/zen/lpgemm/u8s8s16/lpgemm_s16_kern_macros.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
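(Aside: the S8_S16_BETA_OP / U8_S16_BETA_OP macros defined in this header widen the previously downscaled C tile back to s16 and fold it in with beta. A scalar sketch of the U8 variant, under the assumption that S16_BETA_FMA performs acc += beta * c_old in 16-bit arithmetic; names are illustrative.)

#include <stdint.h>

/* Hypothetical scalar reference of the u8 -> s16 beta op. */
static void beta_op_u8_ref( int16_t* acc, const uint8_t* c_old,
                            int16_t beta, int n )
{
    for ( int j = 0; j < n; ++j )
    {
        int16_t widened = ( int16_t )c_old[ j ];            /* cvtepu8_epi16 */
        acc[ j ] = ( int16_t )( acc[ j ] + beta * widened );
    }
}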
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -77,7 +77,7 @@ scratch1 = _mm256_loadu_si256( ( __m256i const* )buf_ ); \ S16_BETA_FMA(reg,scratch1,scratch2) \ -// Downscale beta scale macro, scratch2=beta +// Downscale beta scale macro (s8 -> s16), scratch2=beta #define S8_S16_BETA_OP(reg,m_ir,m_ind,n_ind,scratch1,scratch2) \ scratch1 = \ _mm256_cvtepi8_epi16 \ @@ -91,22 +91,47 @@ ); \ S16_BETA_FMA(reg,scratch1,scratch2) \ -// Downscale beta n < 16 scale macro, scratch2=beta +// Downscale beta scale macro (u8 -> s16), scratch2=beta +#define U8_S16_BETA_OP(reg,m_ir,m_ind,n_ind,scratch1,scratch2) \ + scratch1 = \ + _mm256_cvtepu8_epi16 \ + ( \ + _mm_loadu_si128 \ + ( \ + ( __m128i const* )( ( uint8_t* )post_ops_attr.buf_downscale + \ + ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \ + post_ops_attr.post_op_c_j + ( n_ind * 16 ) )\ + ) \ + ); \ + S16_BETA_FMA(reg,scratch1,scratch2) \ + +// Downscale beta n < 16 scale macro (s8 -> s16), scratch2=beta #define S8_S16_BETA_OP_NLT16(reg,buf_,scratch1,scratch2) \ scratch1 = _mm256_cvtepi8_epi16( _mm_loadu_si128( ( __m128i const* )buf_ ) ); \ S16_BETA_FMA(reg,scratch1,scratch2) \ -#define S8_S16_BETA_NLT16_MEMCP_UTIL(buf_,m_ind,bytes) \ +// Downscale beta n < 16 scale macro (u8 -> s16), scratch2=beta +#define U8_S16_BETA_OP_NLT16(reg,buf_,scratch1,scratch2) \ + scratch1 = _mm256_cvtepu8_epi16( _mm_loadu_si128( ( __m128i const* )buf_ ) ); \ + S16_BETA_FMA(reg,scratch1,scratch2) \ + +#define US8_S16_BETA_NLT16_MEMCP_HELPER(buf_,m_ind,bytes, C_type) \ memcpy \ ( \ buf_, \ - ( ( int8_t* )post_ops_attr.buf_downscale + \ + ( ( C_type* )post_ops_attr.buf_downscale + \ ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \ post_ops_attr.post_op_c_j ), bytes \ ); \ + +#define S8_S16_BETA_NLT16_MEMCP_UTIL(buf_,m_ind,bytes) \ + US8_S16_BETA_NLT16_MEMCP_HELPER(buf_,m_ind,bytes,int8_t) \ + +#define U8_S16_BETA_NLT16_MEMCP_UTIL(buf_,m_ind,bytes) \ + US8_S16_BETA_NLT16_MEMCP_HELPER(buf_,m_ind,bytes,uint8_t) \ // Downscale macro -#define CVT_MULRND_CVT16(reg, scale0, scale1) \ +#define CVT_MULRND_CVT16(reg, scale0, scale1, zero_point_0) \ \ /* Extract the first 128 bits of the register*/ \ temp[0] = _mm256_extractf128_si256( reg, 0 ); \ @@ -122,33 +147,17 @@ res_1 = _mm256_mul_ps( temp_float[0], scale0 ); \ res_2 = _mm256_mul_ps( temp_float[1], scale1 ); \ \ - /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */ \ + /* Round the resultant value to the nearest float value. 
*/ \ res_1 = \ - _mm256_min_ps \ - ( \ - _mm256_max_ps \ - ( \ _mm256_round_ps \ ( \ res_1, ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \ - ), \ - _mm256_set1_ps( ( float )S8_MIN ) \ - ), \ - _mm256_set1_ps( ( float )S8_MAX ) \ - );\ + ); \ res_2 = \ - _mm256_min_ps \ - ( \ - _mm256_max_ps \ - ( \ _mm256_round_ps \ ( \ res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) \ - ), \ - _mm256_set1_ps( ( float )S8_MIN ) \ - ), \ - _mm256_set1_ps( ( float )S8_MAX ) \ - );\ + ); \ \ /* Convert the clipped float32 scaled rounded value to int32 */ \ temp_32[0] = _mm256_cvtps_epi32( res_1 ); \ @@ -159,97 +168,156 @@ \ /*Permute to make sure the order is correct*/ \ reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \ + \ + /* Zero point addition.*/ \ + reg = _mm256_add_epi16( reg, zero_point_0 ); \ -// Downscale store macro -#define CVT_STORE_S16_S8(reg0, reg1, m_ind, n_ind) \ - /* Convert the s16 to s8 */ \ - reg0 = _mm256_packs_epi16( reg0, reg1 ); \ - reg0 = _mm256_permute4x64_epi64( reg0, 0XD8 ); \ +// Downscale store macro helper +#define CVT_STORE_S16_SU8_HELPER(reg, m_ind, n_ind, C_type) \ + reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \ \ _mm256_storeu_si256 \ ( \ - ( __m256i* )( ( int8_t* )post_ops_attr.buf_downscale + \ + ( __m256i* )( ( C_type* )post_ops_attr.buf_downscale + \ ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \ post_ops_attr.post_op_c_j + ( n_ind * 32 ) ), \ - reg0 \ - ) \ + reg \ + ); \ -// Downscale store macro for fringe cases -#define CVT_STORE_S16_S8_2ROW(reg0, reg1, m_ind0, m_ind1, n_ind) \ - /* Convert the s16 to s8 */ \ +// Downscale store macro (s16 -> s8) +#define CVT_STORE_S16_S8(reg0, reg1, m_ind, n_ind) \ + /* Convert the s16 to s8 */ \ reg0 = _mm256_packs_epi16( reg0, reg1 ); \ - reg0 = _mm256_permute4x64_epi64( reg0, 0XD8 ); \ + CVT_STORE_S16_SU8_HELPER(reg0, m_ind, n_ind, int8_t) \ + +// Downscale store macro (s16 -> u8) +#define CVT_STORE_S16_U8(reg0, reg1, m_ind, n_ind) \ + /* Convert the s16 to s8 */ \ + reg0 = _mm256_packus_epi16( reg0, reg1 ); \ + CVT_STORE_S16_SU8_HELPER(reg0, m_ind, n_ind, uint8_t) \ + +// Downscale store helper macro for fringe cases +#define CVT_STORE_S16_US8_2ROW_HELPER(reg, m_ind0, m_ind1, n_ind, C_type) \ + reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \ \ /* Extract the first 128 bits of the register*/ \ - temp[0] = _mm256_extractf128_si256( reg0, 0 ); \ + temp[0] = _mm256_extractf128_si256( reg, 0 ); \ /* Extract the second 128 bits of the register*/ \ - temp[1] = _mm256_extractf128_si256( reg0, 1 ); \ + temp[1] = _mm256_extractf128_si256( reg, 1 ); \ \ _mm_storeu_si128 \ ( \ - ( __m128i* )( ( int8_t* )post_ops_attr.buf_downscale + \ + ( __m128i* )( ( C_type* )post_ops_attr.buf_downscale + \ ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind0 ) ) + \ post_ops_attr.post_op_c_j + ( n_ind * 16 ) ), \ temp[0] \ ); \ _mm_storeu_si128 \ ( \ - ( __m128i* )( ( int8_t* )post_ops_attr.buf_downscale + \ + ( __m128i* )( ( C_type* )post_ops_attr.buf_downscale + \ ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind1 ) ) + \ post_ops_attr.post_op_c_j + ( n_ind * 16 ) ), \ temp[1] \ ); \ -// Downscale store macro for fringe cases -#define CVT_STORE_S16_S8_1ROW(reg0, reg1, m_ind0, n_ind) \ +// Downscale store macro for fringe cases (s16 -> s8) +#define CVT_STORE_S16_S8_2ROW(reg0, reg1, m_ind0, m_ind1, n_ind) \ /* Convert the s16 to s8 */ \ reg0 = _mm256_packs_epi16( reg0, reg1 ); \ - reg0 = _mm256_permute4x64_epi64( reg0, 0XD8 ); \ + CVT_STORE_S16_US8_2ROW_HELPER(reg0, m_ind0, 
m_ind1, n_ind, int8_t) \ + +// Downscale store macro for fringe cases (s16 -> u8) +#define CVT_STORE_S16_U8_2ROW(reg0, reg1, m_ind0, m_ind1, n_ind) \ + /* Convert the s16 to u8 */ \ + reg0 = _mm256_packus_epi16( reg0, reg1 ); \ + CVT_STORE_S16_US8_2ROW_HELPER(reg0, m_ind0, m_ind1, n_ind, uint8_t) \ + +// Downscale store helper macro for fringe cases +#define CVT_STORE_S16_US8_1ROW(reg, m_ind0, n_ind, C_type) \ + reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \ \ /* Extract the first 128 bits of the register*/ \ - temp[0] = _mm256_extractf128_si256( reg0, 0 ); \ + temp[0] = _mm256_extractf128_si256( reg, 0 ); \ \ _mm_storeu_si128 \ ( \ - ( __m128i* )( ( int8_t* )post_ops_attr.buf_downscale + \ + ( __m128i* )( ( C_type* )post_ops_attr.buf_downscale + \ ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind0 ) ) + \ post_ops_attr.post_op_c_j + ( n_ind * 16 ) ), \ temp[0] \ ); \ -// Downscale store macro for n < 16 fringe cases -#define CVT_STORE_S16_S8_2ROW_NLT16(reg0, reg1, buf0, buf1) \ +// Downscale store (s16 -> s8) macro for fringe cases +#define CVT_STORE_S16_S8_1ROW(reg0, reg1, m_ind0, n_ind) \ /* Convert the s16 to s8 */ \ reg0 = _mm256_packs_epi16( reg0, reg1 ); \ - reg0 = _mm256_permute4x64_epi64( reg0, 0XD8 ); \ + CVT_STORE_S16_US8_1ROW(reg0, m_ind0, n_ind, int8_t) \ + +// Downscale store (s16 -> u8) macro for fringe cases +#define CVT_STORE_S16_U8_1ROW(reg0, reg1, m_ind0, n_ind) \ + /* Convert the s16 to u8 */ \ + reg0 = _mm256_packus_epi16( reg0, reg1 ); \ + CVT_STORE_S16_US8_1ROW(reg0, m_ind0, n_ind, uint8_t) \ + +// Downscale store helper macro for n < 16 fringe cases +#define CVT_STORE_S16_US8_2ROW_NLT16(reg, buf0, buf1) \ + reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \ \ /* Extract the first 128 bits of the register*/ \ - temp[0] = _mm256_extractf128_si256( reg0, 0 ); \ + temp[0] = _mm256_extractf128_si256( reg, 0 ); \ /* Extract the second 128 bits of the register*/ \ - temp[1] = _mm256_extractf128_si256( reg0, 1 ); \ + temp[1] = _mm256_extractf128_si256( reg, 1 ); \ \ _mm_storeu_si128( ( __m128i* )buf0, temp[0] ); \ _mm_storeu_si128( ( __m128i* )buf1, temp[1] ); \ -// Downscale store macro for n < 16 fringe cases -#define CVT_STORE_S16_S8_1ROW_NLT16(reg0, reg1, buf0) \ +// Downscale store (int16 -> s8) macro for n < 16 fringe cases +#define CVT_STORE_S16_S8_2ROW_NLT16(reg0, reg1, buf0, buf1) \ /* Convert the s16 to s8 */ \ reg0 = _mm256_packs_epi16( reg0, reg1 ); \ - reg0 = _mm256_permute4x64_epi64( reg0, 0XD8 ); \ + CVT_STORE_S16_US8_2ROW_NLT16(reg0, buf0, buf1) \ + +// Downscale store (int16 -> u8) macro for n < 16 fringe cases +#define CVT_STORE_S16_U8_2ROW_NLT16(reg0, reg1, buf0, buf1) \ + /* Convert the s16 to s8 */ \ + reg0 = _mm256_packus_epi16( reg0, reg1 ); \ + CVT_STORE_S16_US8_2ROW_NLT16(reg0, buf0, buf1) \ + +// Downscale store helper macro for n < 16 fringe cases +#define CVT_STORE_S16_US8_1ROW_NLT16(reg, buf0) \ + reg = _mm256_permute4x64_epi64( reg, 0XD8 ); \ \ /* Extract the first 128 bits of the register*/ \ - temp[0] = _mm256_extractf128_si256( reg0, 0 ); \ + temp[0] = _mm256_extractf128_si256( reg, 0 ); \ \ _mm_storeu_si128( ( __m128i* )buf0, temp[0] ); \ -#define CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf_,m_ind,bytes) \ +// Downscale store (s16 -> s8) macro for n < 16 fringe cases +#define CVT_STORE_S16_S8_1ROW_NLT16(reg0, reg1, buf0) \ + /* Convert the s16 to s8 */ \ + reg0 = _mm256_packs_epi16( reg0, reg1 ); \ + CVT_STORE_S16_US8_1ROW_NLT16(reg0, buf0) \ + +// Downscale store (s16 -> u8) macro for n < 16 fringe cases +#define 
CVT_STORE_S16_U8_1ROW_NLT16(reg0, reg1, buf0) \ + /* Convert the s16 to u8 */ \ + reg0 = _mm256_packus_epi16( reg0, reg1 ); \ + CVT_STORE_S16_US8_1ROW_NLT16(reg0, buf0) \ + +#define CVT_STORE_S16_US8_NLT16_MEMCP_HELPER(buf_,m_ind,bytes, C_type) \ memcpy \ ( \ - ( ( int8_t* )post_ops_attr.buf_downscale + \ + ( ( C_type* )post_ops_attr.buf_downscale + \ ( post_ops_attr.rs_c_downscale * ( post_ops_attr.post_op_c_i + m_ind ) ) + \ post_ops_attr.post_op_c_j ), buf_, bytes \ ); \ +#define CVT_STORE_S16_S8_NLT16_MEMCP_UTIL(buf_,m_ind,bytes) \ + CVT_STORE_S16_US8_NLT16_MEMCP_HELPER(buf_,m_ind,bytes, int8_t) \ + +#define CVT_STORE_S16_U8_NLT16_MEMCP_UTIL(buf_,m_ind,bytes) \ + CVT_STORE_S16_US8_NLT16_MEMCP_HELPER(buf_,m_ind,bytes, uint8_t) \ + //-------------------------------------------------------------------------- /* GeLU (x) = 0.5* x * (1 + tanh ( 0.797884 * ( x + ( 0.044715 * x^3 ) ) ) ) */ #define GELU_TANH_S16_AVX2(reg, y1, y2, r, r2, x, z, dn, x_tanh, q) \ diff --git a/kernels/zen/util/CMakeLists.txt b/kernels/zen/util/CMakeLists.txt deleted file mode 100644 index 502ebd1ac2..0000000000 --- a/kernels/zen/util/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## - -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_thresh_funcs_zen.c - ) diff --git a/kernels/zen2/bli_kernels_zen2.h b/kernels/zen2/bli_kernels_zen2.h index db3bf2c26c..49833b0822 100644 --- a/kernels/zen2/bli_kernels_zen2.h +++ b/kernels/zen2/bli_kernels_zen2.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen4/1/CMakeLists.txt b/kernels/zen4/1/CMakeLists.txt deleted file mode 100644 index 9bfb5d650e..0000000000 --- a/kernels/zen4/1/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.## - -add_library(zen4_1 - OBJECT - ${CMAKE_CURRENT_SOURCE_DIR}/bli_amaxv_zen_int_avx512.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_scalv_zen_int_avx512.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotv_zen_int_avx512.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyv_zen_int_avx512.c - ) - -target_compile_options(zen4_1 PRIVATE /arch:AVX2 /arch:AVX512) -if(BUILD_SHARED_LIBS) - target_compile_definitions(zen4_1 PUBLIC -DBLIS_IS_BUILDING_LIBRARY) -endif() diff --git a/kernels/zen4/1/bli_amaxv_zen_int_avx512.c b/kernels/zen4/1/bli_amaxv_zen_int_avx512.c index 85c3f0d356..ebb6290ad0 100644 --- a/kernels/zen4/1/bli_amaxv_zen_int_avx512.c +++ b/kernels/zen4/1/bli_amaxv_zen_int_avx512.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -278,7 +278,7 @@ void bli_samaxv_zen_int_avx512( mask.v = _mm512_sub_ps(max_vec_1.v, x_vec_1.v); // Type cast mask from IEEE754 (float) to integer type // This operation will not need a new register, its just to convince - // the compiler. But its accounted as seperate register in the + // the compiler. 
But its accounted as separate register in the // above calculations intMask = _mm512_castps_si512(mask.v); // Extract the signbit and build the mask. @@ -312,7 +312,7 @@ void bli_samaxv_zen_int_avx512( mask.v = _mm512_sub_ps(max_vec_2.v, x_vec_2.v); // Type cast mask from IEEE754 (float) to integer type // This operation will not need a new register, its just to convince - // the compiler. But its accounted as seperate register in the + // the compiler. But its accounted as separate register in the // above calculations intMask = _mm512_castps_si512(mask.v); // Extract the signbit and build the mask. @@ -345,7 +345,7 @@ void bli_samaxv_zen_int_avx512( mask.v = _mm512_sub_ps(max_vec_3.v, x_vec_3.v); // Type cast mask from IEEE754 (float) to integer type // This operation will not need a new register, its just to convince - // the compiler. But its accounted as seperate register in the + // the compiler. But its accounted as separate register in the // above calculations intMask = _mm512_castps_si512(mask.v); // Extract the signbit and build the mask. @@ -397,7 +397,7 @@ void bli_samaxv_zen_int_avx512( mask.v = _mm512_sub_ps(max_vec_2.v, max_vec_3.v); // Type cast mask from IEEE754 (float) to integer type // This operation will not need a new register, its just to convince - // the compiler. But its accounted as seperate register in the + // the compiler. But its accounted as separate register in the // above calculations intMask = _mm512_castps_si512(mask.v); // Extract the signbit and build the mask. @@ -423,7 +423,7 @@ void bli_samaxv_zen_int_avx512( mask.v = _mm512_sub_ps(max_vec_1.v, max_vec_2.v); // Type cast mask from IEEE754 (float) to integer type // This operation will not need a new register, its just to convince - // the compiler. But its accounted as seperate register in the + // the compiler. But its accounted as separate register in the // above calculations intMask = _mm512_castps_si512(mask.v); // Extract the signbit and build the mask. diff --git a/kernels/zen4/1/bli_axpyv_zen_int_avx512.c b/kernels/zen4/1/bli_axpyv_zen_int_avx512.c index 23b1f2f039..181a5a38ee 100644 --- a/kernels/zen4/1/bli_axpyv_zen_int_avx512.c +++ b/kernels/zen4/1/bli_axpyv_zen_int_avx512.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc. + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen4/1/bli_dotv_zen_int_avx512.c b/kernels/zen4/1/bli_dotv_zen_int_avx512.c index 681e4bda5b..4d9708e751 100644 --- a/kernels/zen4/1/bli_dotv_zen_int_avx512.c +++ b/kernels/zen4/1/bli_dotv_zen_int_avx512.c @@ -334,8 +334,13 @@ void bli_ddotv_zen_int_avx512 x0 += 2 * n_elem_per_reg; y0 += 2 * n_elem_per_reg; } + rhov[0] = _mm512_add_pd(rhov[0], rhov[2]); + rhov[1] = _mm512_add_pd(rhov[1], rhov[3]); - for (; (i + 7) < n; i += 8) + rhov[0] = _mm512_add_pd(rhov[0], rhov[4]); + rhov[0] = _mm512_add_pd(rhov[0], rhov[1]); + + if((i + 7) < n) { xv[0] = _mm512_loadu_pd(x0); @@ -345,57 +350,33 @@ void bli_ddotv_zen_int_avx512 x0 += n_elem_per_reg; y0 += n_elem_per_reg; + i += 8; } - - __m256d temp[2]; - temp[0] = _mm256_setzero_pd(); - - for (; (i + 3) < n; i += 4) + if(i < n) { - __m256d x_vec = _mm256_loadu_pd(x0); - - __m256d y_vec = _mm256_loadu_pd(y0); - - temp[0] = _mm256_fmadd_pd(x_vec, y_vec, temp[0]); - - x0 += 4; - y0 += 4; - } - - __m128d temp_128[2]; - temp_128[0] = _mm_setzero_pd(); + // calculate mask based on remainder elements of vector + // which are not in multiple of 8. + // Here bitmask is prepared based on remainder elements + // to load only required elements from memory into + // vector register. + //for example if n-i=3 case bitmask is prepared as following. + //1 is shifted by n-i(3), mask becomes 0b1000. + //substracting 1 from it makes mask 0b111 which states that + //3 elements from memory are to be loaded into vector register. + __mmask8 mask = (1 << (n-i)) - 1; + rhov[1] = _mm512_setzero_pd(); + + xv[0] = _mm512_mask_loadu_pd(rhov[1], mask, x0); + + yv[0] = _mm512_mask_loadu_pd(rhov[1], mask, y0); - for (; (i + 1) < n; i += 2) - { - __m128d x_vec = _mm_loadu_pd(x0 + 0 * n_elem_per_reg); - - __m128d y_vec = _mm_loadu_pd(y0 + 0 * n_elem_per_reg); - - temp_128[0] = _mm_fmadd_pd(x_vec, y_vec, temp_128[0]); + rhov[0] = _mm512_fmadd_pd(xv[0], yv[0], rhov[0]); - x0 += 2; - y0 += 2; + x0 += (n-i); + y0 += (n-i); + i += (n-i); } - - // Add the results from above to finish the sum. - rhov[0] = _mm512_add_pd(rhov[0], rhov[2]); - rhov[1] = _mm512_add_pd(rhov[1], rhov[3]); - - rhov[0] = _mm512_add_pd(rhov[0], rhov[1]); - rhov[0] = _mm512_add_pd(rhov[0], rhov[4]); - - temp[1] = _mm512_extractf64x4_pd(rhov[0], 0); - temp[0] = _mm256_add_pd(temp[0], temp[1]); - - temp[1] = _mm512_extractf64x4_pd(rhov[0], 1); - temp[0] = _mm256_add_pd(temp[0], temp[1]); - - temp_128[1] = _mm256_extractf64x2_pd(temp[0], 0); - temp_128[0] = _mm_add_pd(temp_128[0], temp_128[1]); - temp_128[1] = _mm256_extractf64x2_pd(temp[0], 1); - temp_128[0] = _mm_add_pd(temp_128[0], temp_128[1]); - - rho0 = temp_128[0][0] + temp_128[0][1]; + rho0 = _mm512_reduce_add_pd(rhov[0]); } for (; i < n; ++i) diff --git a/kernels/zen4/1/bli_scalv_zen_int_avx512.c b/kernels/zen4/1/bli_scalv_zen_int_avx512.c index 2dd355b268..febd6aa8e9 100644 --- a/kernels/zen4/1/bli_scalv_zen_int_avx512.c +++ b/kernels/zen4/1/bli_scalv_zen_int_avx512.c @@ -269,6 +269,29 @@ void bli_dscalv_zen_int_avx512 cntx_t *restrict cntx ) { + // If the vector dimension is zero, or if alpha is unit, return early. + if (bli_zero_dim1(n) || PASTEMAC(d, eq1)(*alpha)) + return; + + // If alpha is zero, use setv. 
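(Aside: the reworked remainder handling in bli_ddotv_zen_int_avx512 above replaces the AVX2/SSE cleanup loops with one masked AVX-512 step. A compact sketch, with an illustrative helper name and the maskz load variant standing in for the mask load used in the kernel: for r = n - i leftover doubles, the low r bits of an __mmask8 select exactly those elements and _mm512_reduce_add_pd collapses the accumulator.)

#include <immintrin.h>

/* Hypothetical tail helper: r < 8 remaining elements, acc holds partial sums. */
static double ddot_tail( const double* x, const double* y, int r, __m512d acc )
{
    __mmask8 m  = ( __mmask8 )( ( 1 << r ) - 1 );   /* r = 3 -> 0b00000111        */
    __m512d  xv = _mm512_maskz_loadu_pd( m, x );    /* unselected lanes read as 0 */
    __m512d  yv = _mm512_maskz_loadu_pd( m, y );
    acc = _mm512_fmadd_pd( xv, yv, acc );
    return _mm512_reduce_add_pd( acc );
}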
+ if (PASTEMAC(d, eq0)(*alpha)) + { + double *zero = bli_d0; + if (cntx == NULL) cntx = bli_gks_query_cntx(); + dsetv_ker_ft f = bli_cntx_get_l1v_ker_dt(BLIS_DOUBLE, BLIS_SETV_KER, cntx); + + f + ( + BLIS_NO_CONJUGATE, + n, + zero, + x, incx, + cntx + ); + + return; + } + dim_t i = 0; double *restrict x0; @@ -417,3 +440,150 @@ void bli_dscalv_zen_int_avx512 x0 += incx; } } + +/* + Functionality + ------------- + + This function scales a double complex vector by an element of the + type double. + + x := conjalpha(alpha) * x + + Function Signature + ------------------- + + * 'conjalpha' - Variable specified if alpha needs to be conjugated + * 'n' - Length of the array passed + * 'alpha' - Pointer to the element by which the vector is to be scaled + * 'x' - Double complex pointer pointing to an array + * 'incx' - Stride to point to the next element in the array + * 'cntx' - BLIS context object + + Exception + ---------- + + None + + Deviation from BLAS + -------------------- + + None + + Undefined behaviour + ------------------- + + 1. The kernel results in undefined behaviour when n <= 0 and incx <= 1. The expectation + is that these are standard BLAS exceptions and should be handled in a higher layer. +*/ +void bli_zdscalv_zen_int_avx512 + ( + conj_t conjalpha, + dim_t n, + dcomplex* restrict alpha, + dcomplex* restrict x, inc_t incx, + cntx_t* restrict cntx + ) +{ + /* + This kernel only performs the computation + when alpha is double from the BLAS layer + alpha is passed as double complex to adhere + to function pointer definition in BLIS + */ + const double alphac = (*alpha).real; + + dim_t i = 0; + + double *restrict x0 = (double *)x; + + if (incx == 1) + { + __m512d alphav, xv[4]; + const dim_t n_elem_per_reg = 8; // number of elements per register + + alphav = _mm512_set1_pd(alphac); + + for (; (i + 15) < n; i += 16) + { + xv[0] = _mm512_loadu_pd(x0); + xv[1] = _mm512_loadu_pd(x0 + n_elem_per_reg); + xv[2] = _mm512_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3] = _mm512_loadu_pd(x0 + 3 * n_elem_per_reg); + + xv[0] = _mm512_mul_pd(alphav, xv[0]); + xv[1] = _mm512_mul_pd(alphav, xv[1]); + xv[2] = _mm512_mul_pd(alphav, xv[2]); + xv[3] = _mm512_mul_pd(alphav, xv[3]); + + _mm512_storeu_pd(x0, xv[0]); + _mm512_storeu_pd(x0 + n_elem_per_reg, xv[1]); + _mm512_storeu_pd(x0 + 2 * n_elem_per_reg, xv[2]); + _mm512_storeu_pd(x0 + 3 * n_elem_per_reg, xv[3]); + + x0 += 4 * n_elem_per_reg; + } + + for (; (i + 7) < n; i += 8) + { + xv[0] = _mm512_loadu_pd(x0); + xv[1] = _mm512_loadu_pd(x0 + n_elem_per_reg); + + xv[0] = _mm512_mul_pd(alphav, xv[0]); + xv[1] = _mm512_mul_pd(alphav, xv[1]); + + _mm512_storeu_pd(x0, xv[0]); + _mm512_storeu_pd(x0 + n_elem_per_reg, xv[1]); + + x0 += 2 * n_elem_per_reg; + } + + for (; (i + 3) < n; i += 4) + { + xv[0] = _mm512_loadu_pd(x0); + + xv[0] = _mm512_mul_pd(alphav, xv[0]); + + _mm512_storeu_pd(x0, xv[0]); + + x0 += n_elem_per_reg; + } + + for (; (i + 1) < n; i += 2) + { + __m256d xv = _mm256_loadu_pd(x0); + + __m256d alphav = _mm256_set1_pd(alphac); + + xv = _mm256_mul_pd(alphav, xv); + + _mm256_storeu_pd(x0, xv); + + x0 += 4; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from AVX to SSE instructions (which may occur as soon + // as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). 
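(Aside: for context, what the new bli_zdscalv_zen_int_avx512 path computes when incx == 1 - scaling a dcomplex vector by a real alpha is simply an in-place scale of 2*n doubles, which is why plain _mm512_mul_pd suffices. A scalar sketch with an illustrative name; dcomplex and dim_t are the BLIS types used above.)

#include "blis.h"

/* Hypothetical scalar reference for the unit-stride case. */
static void zdscal_unit_stride_ref( dim_t n, double alpha, dcomplex* x )
{
    double* xd = ( double* )x;
    for ( dim_t i = 0; i < 2 * n; ++i )
        xd[ i ] *= alpha;
}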
+ _mm256_zeroupper(); + } + + /* In double complex data type the computation of + unit stride elements can still be vectorized using SSE*/ + __m128d alpha_reg, x_vec; + + alpha_reg = _mm_set1_pd((*alpha).real); + + for (; i < n; ++i) + { + x_vec = _mm_loadu_pd(x0); + + x_vec = _mm_mul_pd(x_vec, alpha_reg); + + _mm_storeu_pd(x0, x_vec); + + x0 += 2 * incx; + } +} diff --git a/kernels/zen4/1m/CMakeLists.txt b/kernels/zen4/1m/CMakeLists.txt deleted file mode 100644 index 9dfbefc458..0000000000 --- a/kernels/zen4/1m/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.## - -add_library(zen4_1m - OBJECT - ${CMAKE_CURRENT_SOURCE_DIR}/bli_packm_zen4_asm_d8xk.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_packm_zen4_asm_d16xk.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_packm_zen4_asm_d24xk.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_packm_zen4_asm_d32xk.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_packm_zen4_asm_z12xk.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_packm_zen4_asm_z4xk.c - ) - -target_compile_options(zen4_1m PRIVATE /U__PRFCHW__ /arch:AVX2 /arch:AVX512) -if(BUILD_SHARED_LIBS) - target_compile_definitions(zen4_1m PUBLIC -DBLIS_IS_BUILDING_LIBRARY) -endif() diff --git a/kernels/zen4/1m/bli_packm_zen4_asm_d16xk.c b/kernels/zen4/1m/bli_packm_zen4_asm_d16xk.c index 5ecc5403f7..c311d4ebf2 100644 --- a/kernels/zen4/1m/bli_packm_zen4_asm_d16xk.c +++ b/kernels/zen4/1m/bli_packm_zen4_asm_d16xk.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen4/1m/bli_packm_zen4_asm_d24xk.c b/kernels/zen4/1m/bli_packm_zen4_asm_d24xk.c index ee9e128e41..4c7151513e 100644 --- a/kernels/zen4/1m/bli_packm_zen4_asm_d24xk.c +++ b/kernels/zen4/1m/bli_packm_zen4_asm_d24xk.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen4/1m/bli_packm_zen4_asm_d32xk.c b/kernels/zen4/1m/bli_packm_zen4_asm_d32xk.c index 1ff964069a..60df4bca4e 100644 --- a/kernels/zen4/1m/bli_packm_zen4_asm_d32xk.c +++ b/kernels/zen4/1m/bli_packm_zen4_asm_d32xk.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen4/1m/bli_packm_zen4_asm_z12xk.c b/kernels/zen4/1m/bli_packm_zen4_asm_z12xk.c index 3145801e11..fc33908cc8 100644 --- a/kernels/zen4/1m/bli_packm_zen4_asm_z12xk.c +++ b/kernels/zen4/1m/bli_packm_zen4_asm_z12xk.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc.All rights reserved. + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen4/1m/bli_packm_zen4_asm_z4xk.c b/kernels/zen4/1m/bli_packm_zen4_asm_z4xk.c index 02f2776c17..3716d83e36 100644 --- a/kernels/zen4/1m/bli_packm_zen4_asm_z4xk.c +++ b/kernels/zen4/1m/bli_packm_zen4_asm_z4xk.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc.All rights reserved. + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen4/3/CMakeLists.txt b/kernels/zen4/3/CMakeLists.txt index 0b38920998..6573f85ed8 100644 --- a/kernels/zen4/3/CMakeLists.txt +++ b/kernels/zen4/3/CMakeLists.txt @@ -10,6 +10,11 @@ add_library(zen4_3 ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_zen4_asm_8x24.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_small_AVX512.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_zen4_asm_12x4.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_zero_zmm.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_zen4_asm_4x12.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemmtrsm_l_4x12.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemmtrsm_u_4x12.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_avx512_k1.c ) target_compile_options(zen4_3 PRIVATE /arch:AVX2 /arch:AVX512) diff --git a/kernels/zen4/3/bli_dgemm_avx512_k1.c b/kernels/zen4/3/bli_dgemm_avx512_k1.c new file mode 100644 index 0000000000..e3c15c78c5 --- /dev/null +++ b/kernels/zen4/3/bli_dgemm_avx512_k1.c @@ -0,0 +1,6556 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include +#include "blis.h" +#include "immintrin.h" + + +#define D_MR 24 +#define D_NR 8 + +err_t bli_dgemm_24x8_avx512_k1_nn +( + dim_t m, + dim_t n, + dim_t k, + double* alpha, + double* a, const inc_t lda, + double* b, const inc_t ldb, + double* beta, + double* c, const inc_t ldc +) +{ + err_t ret_status = BLIS_FAILURE; + double alpha_val, beta_val; + + beta_val = *beta; + alpha_val = *alpha; + + dim_t m_remainder = (m % D_MR); + dim_t n_remainder = (n % D_NR); + + //scratch registers + __m512d zmm0, zmm1, zmm2, zmm3; + __m512d zmm4, zmm5, zmm6, zmm7; + __m512d zmm8, zmm9, zmm10, zmm11; + __m512d zmm12, zmm13, zmm14, zmm15; + __m512d zmm16, zmm17, zmm18, zmm19; + __m512d zmm20, zmm21, zmm22, zmm23; + __m512d zmm24, zmm25, zmm26, zmm27; + __m512d zmm28, zmm29, zmm30, zmm31; + + if(alpha_val != 0.0 && beta_val != 0.0) + { + /* Compute C = alpha*A*B + beta*c */ + for(dim_t j = 0; (j + (D_NR-1) < n ); j += D_NR) + { + double* temp_b = b + j*ldb; + double* temp_a = a; + double* temp_c = c + j*ldc; + + for(dim_t i = 0; i < ( m - D_MR+1); i += D_MR) + { + //Clear out vector registers to hold fma result. + //zmm6 to zmm29 holds fma result. + //zmm0, zmm1, zmm2 are used to load 24 elements from + //A matrix. + //zmm30 and zmm31 are alternatively used to broadcast element + //from B matrix. + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm20 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm23 = _mm512_setzero_pd(); + zmm24 = _mm512_setzero_pd(); + zmm25 = _mm512_setzero_pd(); + zmm26 = _mm512_setzero_pd(); + zmm27 = _mm512_setzero_pd(); + zmm28 = _mm512_setzero_pd(); + zmm29 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 24x8 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + _mm_prefetch((char*)( temp_a + 192), _MM_HINT_T0); + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + //Broadcast element from next column of B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + //Compute A*B. + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + //Compute A*B. + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + //Broadcast element from B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + //Compute A*B. 
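(Aside: the new bli_dgemm_24x8_avx512_k1_nn kernel in this file specializes GEMM for k = 1, where the product collapses to a scaled rank-1 update over a column-major C. A scalar sketch with an illustrative helper name, matching the addressing used in the kernel:)

#include "blis.h"

/* Hypothetical scalar reference: C(:,j) = alpha * a(:) * b(j*ldb) + beta * C(:,j). */
static void dgemm_k1_ref( dim_t m, dim_t n, double alpha,
                          const double* a,
                          const double* b, inc_t ldb,
                          double beta, double* c, inc_t ldc )
{
    for ( dim_t j = 0; j < n; ++j )
        for ( dim_t i = 0; i < m; ++i )
            c[ i + j * ldc ] = alpha * a[ i ] * b[ j * ldb ]
                             + beta  * c[ i + j * ldc ];
}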
+ zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + //Compute A*B. + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + //Broadcast element from B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + //Compute A*B. + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6)); + //Compute A*B. + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + //Broadcast element from B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 7)); + //Compute A*B. + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm30, zmm25); + zmm26 = _mm512_fmadd_pd(zmm2, zmm30, zmm26); + //Compute A*B. + zmm27 = _mm512_fmadd_pd(zmm0, zmm31, zmm27); + zmm28 = _mm512_fmadd_pd(zmm1, zmm31, zmm28); + zmm29 = _mm512_fmadd_pd(zmm2, zmm31, zmm29); + + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. + //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm23 = _mm512_mul_pd(zmm0, zmm23); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + zmm25 = _mm512_mul_pd(zmm0, zmm25); + zmm26 = _mm512_mul_pd(zmm0, zmm26); + zmm27 = _mm512_mul_pd(zmm0, zmm27); + zmm28 = _mm512_mul_pd(zmm0, zmm28); + zmm29 = _mm512_mul_pd(zmm0, zmm29); + + //Broadcast Beta into zmm31 + zmm31 = _mm512_set1_pd(beta_val); + + //zmm0, zmm1, zmm2 are used to load 24 elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + + //zmm0, zmm1, zmm2 are used to load 24 elements from + //matrix C. + zmm3 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm4 = _mm512_loadu_pd((double const *)(temp_c + ldc + 8)); + zmm5 = _mm512_loadu_pd((double const *)(temp_c + ldc + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm9 = _mm512_fmadd_pd(zmm3, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm4, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm5, zmm31, zmm11); + + //zmm0, zmm1, zmm2 are used to load 24 elements from + //matrix C. 
+ zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm31, zmm14); + + //zmm0, zmm1, zmm2 are used to load 24 elements from + //matrix C. + zmm3 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3)); + zmm4 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3 + 8)); + zmm5 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm15 = _mm512_fmadd_pd(zmm3, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm4, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm5, zmm31, zmm17); + + //zmm0, zmm1, zmm2 are used to load 24 elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm31, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm31, zmm20); + + //zmm0, zmm1, zmm2 are used to load 24 elements from + //matrix C. + zmm3 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5)); + zmm4 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5 + 8)); + zmm5 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm21 = _mm512_fmadd_pd(zmm3, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm4, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm5, zmm31, zmm23); + + //zmm0, zmm1, zmm2 are used to load 24 elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 6)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 6 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 6 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm24 = _mm512_fmadd_pd(zmm0, zmm31, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm31, zmm25); + zmm26 = _mm512_fmadd_pd(zmm2, zmm31, zmm26); + + //zmm0, zmm1, zmm2 are used to load 24 elements from + //matrix C. + zmm3 = _mm512_loadu_pd((double const *)(temp_c + ldc * 7)); + zmm4 = _mm512_loadu_pd((double const *)(temp_c + ldc * 7 + 8)); + zmm5 = _mm512_loadu_pd((double const *)(temp_c + ldc * 7 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm27 = _mm512_fmadd_pd(zmm3, zmm31, zmm27); + zmm28 = _mm512_fmadd_pd(zmm4, zmm31, zmm28); + zmm29 = _mm512_fmadd_pd(zmm5, zmm31, zmm29); + + //Store the result back to Matrix C. + //Result is available in zmm6 to zmm29. 
+ _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + //C matrix 2nd column + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_storeu_pd((double *)(temp_c + ldc + 16), zmm11); + //C matrix 3rd column + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 16), zmm14); + //C matrix 4th column + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 16), zmm17); + //C matrix 5th column + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 16), zmm20); + //C matrix 6th column + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 8), zmm22); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 16), zmm23); + //C matrix 7th column + _mm512_storeu_pd((double *)(temp_c + ldc*6), zmm24); + _mm512_storeu_pd((double *)(temp_c + ldc*6 + 8), zmm25); + _mm512_storeu_pd((double *)(temp_c + ldc*6 + 16), zmm26); + //C matrix 8th column + _mm512_storeu_pd((double *)(temp_c + ldc*7), zmm27); + _mm512_storeu_pd((double *)(temp_c + ldc*7 + 8), zmm28); + _mm512_storeu_pd((double *)(temp_c + ldc*7 + 16), zmm29); + + //Update temp_c and temp_a pointer to + //respective offset. + temp_c += D_MR; + temp_a += D_MR; + } + + dim_t m_rem = m_remainder; + //Handles the edge case for m_remainder from 17 to 23. + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + //Clear out vector registers to hold fma result. + //zmm6 to zmm29 holds fma result. + //zmm0, zmm1, zmm2 are used to load elements from + //A matrix. + //zmm30 and zmm31 are alternatively used to broadcast element + //from B matrix. + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm20 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm23 = _mm512_setzero_pd(); + zmm24 = _mm512_setzero_pd(); + zmm25 = _mm512_setzero_pd(); + zmm26 = _mm512_setzero_pd(); + zmm27 = _mm512_setzero_pd(); + zmm28 = _mm512_setzero_pd(); + zmm29 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >16x8 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + //Broadcast element from next column of B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + //Compute A*B. + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + //Compute A*B. + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + //Compute A*B. + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + //Compute A*B. + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + //Compute A*B. + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6)); + //Compute A*B. + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 7)); + //Compute A*B. + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm30, zmm25); + zmm26 = _mm512_fmadd_pd(zmm2, zmm30, zmm26); + //Compute A*B. + zmm27 = _mm512_fmadd_pd(zmm0, zmm31, zmm27); + zmm28 = _mm512_fmadd_pd(zmm1, zmm31, zmm28); + zmm29 = _mm512_fmadd_pd(zmm2, zmm31, zmm29); + + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. + //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm23 = _mm512_mul_pd(zmm0, zmm23); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + zmm25 = _mm512_mul_pd(zmm0, zmm25); + zmm26 = _mm512_mul_pd(zmm0, zmm26); + zmm27 = _mm512_mul_pd(zmm0, zmm27); + zmm28 = _mm512_mul_pd(zmm0, zmm28); + zmm29 = _mm512_mul_pd(zmm0, zmm29); + + //Broadcast Beta into zmm31 + zmm31 = _mm512_set1_pd(beta_val); + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. 
+ zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 2 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm31, zmm14); + + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 3 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 4 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm31, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm31, zmm20); + + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 5 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 6)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 6 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 6 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm24 = _mm512_fmadd_pd(zmm0, zmm31, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm31, zmm25); + zmm26 = _mm512_fmadd_pd(zmm2, zmm31, zmm26); + + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. 
+ zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 7)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 7 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 7 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm27 = _mm512_fmadd_pd(zmm0, zmm31, zmm27); + zmm28 = _mm512_fmadd_pd(zmm1, zmm31, zmm28); + zmm29 = _mm512_fmadd_pd(zmm2, zmm31, zmm29); + + //Store the result back to Matrix C. + //Result is available in zmm6 to zmm29. + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8); + //C matrix 2nd column + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc + 16), k0, zmm11); + //C matrix 3rd column + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 2 + 16), k0, zmm14); + //C matrix 4th column + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 3 + 16), k0, zmm17); + //C matrix 5th column + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*4 + 16), k0, zmm20); + //C matrix 6th column + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 8), zmm22); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*5 + 16), k0, zmm23); + //C matrix 7th column + _mm512_storeu_pd((double *)(temp_c + ldc*6), zmm24); + _mm512_storeu_pd((double *)(temp_c + ldc*6 + 8), zmm25); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*6 + 16), k0, zmm26); + //C matrix 8th column + _mm512_storeu_pd((double *)(temp_c + ldc*7), zmm27); + _mm512_storeu_pd((double *)(temp_c + ldc*7 + 8), zmm28); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*7 + 16), k0, zmm29); + } + //Handles the edge cases where m_remainder is from 9 to 16 + else if(m_rem > 8) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + //Clear out vector registers to hold fma result. + //zmm6 to zmm28 holds fma result. + //zmm0, zmm1 are used to load elements from + //A matrix. + //zmm30 and zmm31 are alternatively used to broadcast element + //from B matrix. + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm24 = _mm512_setzero_pd(); + zmm25 = _mm512_setzero_pd(); + zmm27 = _mm512_setzero_pd(); + zmm28 = _mm512_setzero_pd(); + zmm1 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >8x8 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8)); + + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + //Broadcast element from next column of B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + //Compute A*B. + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6)); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 7)); + + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm30, zmm25); + + zmm27 = _mm512_fmadd_pd(zmm0, zmm31, zmm27); + zmm28 = _mm512_fmadd_pd(zmm1, zmm31, zmm28); + + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. + //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + zmm25 = _mm512_mul_pd(zmm0, zmm25); + zmm27 = _mm512_mul_pd(zmm0, zmm27); + zmm28 = _mm512_mul_pd(zmm0, zmm28); + + //Broadcast Beta into zmm31 + zmm31 = _mm512_set1_pd(beta_val); + //zmm0, zmm1 are used to load 24 elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + 8)); + //Compute C * Beta + fma result(AB*Alpha) + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + + //zmm0, zmm1 are used to load 24 elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc + 8)); + //Compute C * Beta + fma result(AB*Alpha) + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + //zmm0, zmm1 are used to load 24 elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 2 + 8)); + //Compute C * Beta + fma result(AB*Alpha) + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + + //zmm0, zmm1 are used to load 24 elements from + //matrix C. 
+ zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 3 + 8)); + //Compute C * Beta + fma result(AB*Alpha) + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + + //zmm0, zmm1 are used to load 24 elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 4 + 8)); + //Compute C * Beta + fma result(AB*Alpha) + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm31, zmm19); + + //zmm0, zmm1 are used to load 24 elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 5 + 8)); + //Compute C * Beta + fma result(AB*Alpha) + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + + //zmm0, zmm1 are used to load 24 elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 6)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 6 + 8)); + //Compute C * Beta + fma result(AB*Alpha) + zmm24 = _mm512_fmadd_pd(zmm0, zmm31, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm31, zmm25); + + //zmm0, zmm1 are used to load 24 elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 7)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 7 + 8)); + //Compute C * Beta + fma result(AB*Alpha) + zmm27 = _mm512_fmadd_pd(zmm0, zmm31, zmm27); + zmm28 = _mm512_fmadd_pd(zmm1, zmm31, zmm28); + + //Store the result back to Matrix C. + //Result is available in zmm6 to zmm28. + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7); + //C matrix 2nd column + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc + 8), k0, zmm10); + //C matrix 3rd column + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2 + 8), k0, zmm13); + //C matrix 4th column + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 3 + 8), k0, zmm16); + //C matrix 5th column + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4 + 8), k0, zmm19); + //C matrix 6th column + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*5 + 8), k0, zmm22); + //C matrix 7th column + _mm512_storeu_pd((double *)(temp_c + ldc*6), zmm24); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*6 + 8), k0, zmm25); + //C matrix 8th column + _mm512_storeu_pd((double *)(temp_c + ldc*7), zmm27); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*7 + 8), k0, zmm28); + } + //Handles the edge case where m_remainder is from 1 to 8 + else if(m_rem > 0) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + //Clear out vector registers to hold fma result. + //zmm6 to zmm27 holds fma result. + //zmm0 are used to load 8 elements from + //A matrix. + //zmm30 and zmm31 are alternatively used to broadcast element + //from B matrix. 
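+            //Illustrative note on the mask: e.g. for m & 7 == 3, mask = 0xff >> 5 = 0x07,
+            //enabling only the lowest 3 lanes; when m & 7 == 0 the shift yields 0 and the
+            //mask is reset to 0xff so that all 8 remaining rows are processed.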
+ zmm6 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm24 = _mm512_setzero_pd(); + zmm27 = _mm512_setzero_pd(); + zmm0 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >1x8 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a)); + + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + //Broadcast element from next column of B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + //Compute A*B. + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6)); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 7)); + + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + zmm27 = _mm512_fmadd_pd(zmm0, zmm31, zmm27); + + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. + //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + zmm27 = _mm512_mul_pd(zmm0, zmm27); + + //Broadcast Beta into zmm31 + zmm31 = _mm512_set1_pd(beta_val); + //zmm0 used to load 8 elements from + //matrix C. + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c)); + //Compute C * Beta + fma result(AB*Alpha) + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + + //zmm0 used to load 8 elements from + //matrix C. + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc )); + //Compute C * Beta + fma result(AB*Alpha) + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + //zmm0 used to load 8 elements from + //matrix C. + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 2)); + //Compute C * Beta + fma result(AB*Alpha) + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + + //zmm0 used to load 8 elements from + //matrix C. + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 3)); + //Compute C * Beta + fma result(AB*Alpha) + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + + //zmm0 used to load 8 elements from + //matrix C. + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 4)); + //Compute C * Beta + fma result(AB*Alpha) + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + + //zmm0 used to load 8 elements from + //matrix C. + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 5)); + //Compute C * Beta + fma result(AB*Alpha) + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + + //zmm0 used to load 8 elements from + //matrix C. 
+ zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 6)); + //Compute C * Beta + fma result(AB*Alpha) + zmm24 = _mm512_fmadd_pd(zmm0, zmm31, zmm24); + + //zmm0 used to load 8 elements from + //matrix C. + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 7)); + //Compute C * Beta + fma result(AB*Alpha) + zmm27 = _mm512_fmadd_pd(zmm0, zmm31, zmm27); + + //Store the result back to Matrix C. + _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6); + //C matrix 2nd column + _mm512_mask_storeu_pd((double *)(temp_c + ldc), k0, zmm9); + //C matrix 3rd column + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2), k0, zmm12); + //C matrix 4th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*3), k0, zmm15); + //C matrix 5th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4), k0, zmm18); + //C matrix 6th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*5), k0, zmm21); + //C matrix 7th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*6), k0, zmm24); + //C matrix 8th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*7), k0, zmm27); + } + } + + switch(n_remainder) + { + case 7: + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + //Clear out vector registers to hold fma result. + //zmm6 to zmm26 holds fma result. + //zmm0, zmm1, zmm2 are used to load 24 elements from + //A matrix. + //zmm30 and zmm31 are alternatively used to broadcast element + //from B matrix. + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm20 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm23 = _mm512_setzero_pd(); + zmm24 = _mm512_setzero_pd(); + zmm25 = _mm512_setzero_pd(); + zmm26 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 24x7 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + _mm_prefetch((char*)( temp_a + 192), _MM_HINT_T0); + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + //Broadcast element from B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + //Compute A*B. + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + //Compute A*B. + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + //Compute A*B. 
+ zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + //Compute A*B. + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + //Compute A*B. + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6)); + //Compute A*B. + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm30, zmm25); + zmm26 = _mm512_fmadd_pd(zmm2, zmm30, zmm26); + + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. + //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm23 = _mm512_mul_pd(zmm0, zmm23); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + zmm25 = _mm512_mul_pd(zmm0, zmm25); + zmm26 = _mm512_mul_pd(zmm0, zmm26); + + //Broadcast Beta into zmm31 + zmm31 = _mm512_set1_pd(beta_val); + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm31, zmm14); + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. 
+ zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm31, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm31, zmm20); + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 6)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 6 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 6 + 16)); + //Compute C * Beta + fma result(AB*Alpha) + zmm24 = _mm512_fmadd_pd(zmm0, zmm31, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm31, zmm25); + zmm26 = _mm512_fmadd_pd(zmm2, zmm31, zmm26); + + //Store the result back to Matrix C. + //Result is available in zmm6 to zmm26. + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + //C matrix 2nd column + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_storeu_pd((double *)(temp_c + ldc + 16), zmm11); + //C matrix 3rd column + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 16), zmm14); + //C matrix 4th column + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 16), zmm17); + //C matrix 5th column + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 16), zmm20); + //C matrix 6th column + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 8), zmm22); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 16), zmm23); + //C matrix 7th column + _mm512_storeu_pd((double *)(temp_c + ldc*6), zmm24); + _mm512_storeu_pd((double *)(temp_c + ldc*6 + 8), zmm25); + _mm512_storeu_pd((double *)(temp_c + ldc*6 + 16), zmm26); + + temp_c += D_MR; + temp_a += D_MR; + } + dim_t m_rem = m_remainder; + //Handles the edge case where m_remainder is from 17 to 23 + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + //Clear out vector registers to hold fma result. + //zmm6 to zmm26 holds fma result. 
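+            //With 17 to 23 rows remaining, the first two 8-element vectors are full;
+            //only the third (tail) vector is loaded and stored through the k0 mask.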
+ //zmm0, zmm1, zmm2 are used to load elements from + //A matrix. + //zmm30 and zmm31 are alternatively used to broadcast element + //from B matrix. + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm20 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm23 = _mm512_setzero_pd(); + zmm24 = _mm512_setzero_pd(); + zmm25 = _mm512_setzero_pd(); + zmm26 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with (>16)x7 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + //Broadcast element from next column of B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + //Compute A*B. + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + //Compute A*B. + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + //Compute A*B. + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + //Compute A*B. + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + //Compute A*B. + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6)); + //Compute A*B. + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm30, zmm25); + zmm26 = _mm512_fmadd_pd(zmm2, zmm30, zmm26); + + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. 
+ //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm23 = _mm512_mul_pd(zmm0, zmm23); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + zmm25 = _mm512_mul_pd(zmm0, zmm25); + zmm26 = _mm512_mul_pd(zmm0, zmm26); + + //Broadcast Beta into zmm31 + zmm31 = _mm512_set1_pd(beta_val); + + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + //Compute C * Beta + fma result(AB*Alpha) + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + 16)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + //Compute C * Beta + fma result(AB*Alpha) + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc + 16)); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + //Compute C * Beta + fma result(AB*Alpha) + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 2 + 16)); + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm31, zmm14); + + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + //Compute C * Beta + fma result(AB*Alpha) + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 3 + 16)); + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + //Compute C * Beta + fma result(AB*Alpha) + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 4 + 16)); + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm31, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm31, zmm20); + + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. 
+ //Compute C * Beta + fma result(AB*Alpha) + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 5 + 16)); + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + + //zmm0, zmm1, zmm2 are used to load elements from + //matrix C. + //Compute C * Beta + fma result(AB*Alpha) + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 6)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 6 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 6 + 16)); + zmm24 = _mm512_fmadd_pd(zmm0, zmm31, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm31, zmm25); + zmm26 = _mm512_fmadd_pd(zmm2, zmm31, zmm26); + + //Store the result back to Matrix C. + //Result is available in zmm6 to zmm26. + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8); + //C matrix 2nd column + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc + 16), k0, zmm11); + //C matrix 3rd column + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 2 + 16), k0, zmm14); + //C matrix 4th column + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 3 + 16), k0, zmm17); + //C matrix 5th column + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*4 + 16), k0, zmm20); + //C matrix 6th column + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 8), zmm22); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*5 + 16), k0, zmm23); + //C matrix 7th column + _mm512_storeu_pd((double *)(temp_c + ldc*6), zmm24); + _mm512_storeu_pd((double *)(temp_c + ldc*6 + 8), zmm25); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*6 + 16), k0, zmm26); + + } + //Handles the edge case where m_remadiner is from 9 to 16. + else if(m_rem > 8) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm24 = _mm512_setzero_pd(); + zmm25 = _mm512_setzero_pd(); + zmm1 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with (>8)x7 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + //Compute A*B. + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + //Compute A*B. + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + //Compute A*B. + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + //Compute A*B. + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + //Compute A*B. + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6)); + //Compute A*B. + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + //Compute A*B. + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm30, zmm25); + + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. + //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + zmm25 = _mm512_mul_pd(zmm0, zmm25); + + //Broadcast Beta into zmm31 + zmm31 = _mm512_set1_pd(beta_val); + //zmm0, zmm1 are used to load elements from + //matrix C. + //Compute C * Beta + fma result(AB*Alpha) + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + 8)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + + //zmm0, zmm1 are used to load elements from + //matrix C. + //Compute C * Beta + fma result(AB*Alpha) + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc + 8)); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + //zmm0, zmm1 are used to load elements from + //matrix C. + //Compute C * Beta + fma result(AB*Alpha) + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 2 + 8)); + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + + //zmm0, zmm1 are used to load elements from + //matrix C. + //Compute C * Beta + fma result(AB*Alpha) + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 3 + 8)); + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + + //zmm0, zmm1 are used to load elements from + //matrix C. 
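+            //m_remainder is between 9 and 16 here: the first 8 rows of each C column
+            //are handled unmasked and the +8 tail is loaded/stored through the k0 mask.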
+ //Compute C * Beta + fma result(AB*Alpha) + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 4 + 8)); + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm31, zmm19); + + //zmm0, zmm1 are used to load elements from + //matrix C. + //Compute C * Beta + fma result(AB*Alpha) + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 5 + 8)); + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + + //zmm0, zmm1 are used to load elements from + //matrix C. + //Compute C * Beta + fma result(AB*Alpha) + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 6)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 6 + 8)); + zmm24 = _mm512_fmadd_pd(zmm0, zmm31, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm31, zmm25); + + //Store the result back to Matrix C. + //Result is available in zmm6 to zmm25. + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7); + //C matrix 2nd column + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc + 8), k0, zmm10); + //C matrix 3rd column + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2 + 8), k0, zmm13); + //C matrix 4th column + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 3 + 8), k0, zmm16); + //C matrix 5th column + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4 + 8), k0, zmm19); + //C matrix 6th column + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*5 + 8), k0, zmm22); + //C matrix 7th column + _mm512_storeu_pd((double *)(temp_c + ldc*6), zmm24); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*6 + 8), k0, zmm25); + } + else if(m_rem > 0) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm24 = _mm512_setzero_pd(); + zmm0 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with (>1)x7 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6)); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. + //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + //Broadcast Beta into zmm31 + zmm31 = _mm512_set1_pd(beta_val); + //zmm0 are used to load elements from + //matrix C. + //Compute C * Beta + fma result(AB*Alpha) + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc )); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 2)); + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 3)); + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 4)); + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 5)); + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 6)); + zmm24 = _mm512_fmadd_pd(zmm0, zmm31, zmm24); + + //Store the result back to Matrix C. + //Result is available in zmm6 to zmm24. 
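+            //Only 1 to 8 rows remain, so a single masked store per column of C is
+            //sufficient; lanes outside the k0 mask are left untouched.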
+ _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6); + //C matrix 2nd column + _mm512_mask_storeu_pd((double *)(temp_c + ldc), k0, zmm9); + //C matrix 3rd column + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2), k0, zmm12); + //C matrix 4th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*3), k0, zmm15); + //C matrix 5th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4), k0, zmm18); + //C matrix 6th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*5), k0, zmm21); + //C matrix 7th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*6), k0, zmm24); + } + break; + } + case 6: + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm20 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm23 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 24x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + _mm_prefetch((char*)( temp_a + 192), _MM_HINT_T0); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = 
_mm512_mul_pd(zmm0, zmm17); + + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm23 = _mm512_mul_pd(zmm0, zmm23); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + 16)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc + 16)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 16)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm31, zmm14); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3 + 16)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4 + 16)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm31, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm31, zmm20); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5 + 16)); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_storeu_pd((double *)(temp_c + ldc + 16), zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 16), zmm14); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 16), zmm17); + + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 16), zmm20); + + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 8), zmm22); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 16), zmm23); + + temp_c += D_MR; + temp_a += D_MR; + } + dim_t m_rem = m_remainder; + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + zmm6 = 
_mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm20 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm23 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >16x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm23 = _mm512_mul_pd(zmm0, zmm23); + + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + 16)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c 
+ ldc + 16)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 2 + 16)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm31, zmm14); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 3 + 16)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 4 + 16)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm31, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm31, zmm20); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 5 + 16)); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc + 16), k0, zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 2 + 16), k0, zmm14); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 3 + 16), k0, zmm17); + + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*4 + 16), k0, zmm20); + + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 8), zmm22); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*5 + 16), k0, zmm23); + + } + else if(m_rem > 8) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm1 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >8x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
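+                        d. The first 8 remaining rows use full zmm loads/stores; the rows
+                           beyond them use the k0 mask built from (m & 7) (for example,
+                           m_rem = 11 gives 0xff >> 5 = 0x07), and the mask == 0 guard
+                           keeps all 8 lanes when m_rem is exactly 16.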
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + 8)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc + 8)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 2 + 8)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 3 + 8)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 4 + 8)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm31, zmm19); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 5)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 5 + 8)); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc + 8), k0, zmm10); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2 + 8), k0, zmm13); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 3 + 8), k0, zmm16); + + _mm512_storeu_pd((double *)(temp_c + 
ldc*4), zmm18); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4 + 8), k0, zmm19); + + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*5 + 8), k0, zmm22); + } + else if(m_rem > 0) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm0 = _mm512_setzero_pd(); + + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >1x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc )); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 2)); + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 3)); + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 4)); + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 5)); + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + + _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + ldc), k0, zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2), k0, zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*3), k0, zmm15); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4), k0, zmm18); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*5), k0, zmm21); + } + break; + } + case 5: + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = 
_mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm20 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 24x5 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + _mm_prefetch((char*)( temp_a + 192), _MM_HINT_T0); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + 16)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc + 16)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 16)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm31, zmm14); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3 + 16)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + 
zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4 + 16)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm31, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm31, zmm20); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_storeu_pd((double *)(temp_c + ldc + 16), zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 16), zmm14); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 16), zmm17); + + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 16), zmm20); + + temp_c += D_MR; + temp_a += D_MR; + } + dim_t m_rem = m_remainder; + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm20 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 8x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
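+                        d. This branch handles 17 to 23 remaining rows of the 5 remaining
+                           columns: the first two zmm vectors of A and C are loaded in full
+                           and only the third vector is masked with k0 built from (m & 7).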
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + 16)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc + 16)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 2 + 16)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm31, zmm14); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 3 + 16)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 4 + 16)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm31, zmm19); + zmm20 = 
_mm512_fmadd_pd(zmm2, zmm31, zmm20); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc + 16), k0, zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 2 + 16), k0, zmm14); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 3 + 16), k0, zmm17); + + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*4 + 16), k0, zmm20); + + } + else if(m_rem > 8) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm1 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >8x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + 8)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc + 8)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = 
_mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 2 + 8)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 3 + 8)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 4)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 4 + 8)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm31, zmm19); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc + 8), k0, zmm10); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2 + 8), k0, zmm13); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 3 + 8), k0, zmm16); + + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4 + 8), k0, zmm19); + + } + else if(m_rem > 0) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm0 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >1x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
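+                        d. With at most 8 rows left, one masked zmm per remaining column of
+                           C is enough; masked loads/stores touch only the valid lanes, and
+                           the mask == 0 guard keeps all 8 lanes when exactly 8 rows remain.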
+ */ + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc )); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 2)); + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 3)); + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 4)); + zmm18 = _mm512_fmadd_pd(zmm0, zmm31, zmm18); + + _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + ldc), k0, zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2), k0, zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*3), k0, zmm15); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4), k0, zmm18); + } + break; + } + case 4: + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 24x4 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
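+                    d. Each iteration keeps 12 accumulators live (zmm6-zmm17, three
+                       8-element vectors per remaining column of C); the A vectors
+                       zmm0-zmm2 are reused across columns while B elements are
+                       broadcast into zmm30/zmm31.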
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + _mm_prefetch((char*)( temp_a + 192), _MM_HINT_T0); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + 16)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc + 16)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 16)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm31, zmm14); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3 + 16)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_storeu_pd((double *)(temp_c + ldc + 16), zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 16), zmm14); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 16), 
zmm17); + + temp_c += D_MR; + temp_a += D_MR; + } + dim_t m_rem = m_remainder; + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >16x4 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + 16)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc + 16)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 2 + 16)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm31, zmm14); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3)); + zmm1 = _mm512_loadu_pd((double 
const *)(temp_c + ldc * 3 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 3 + 16)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc + 16), k0, zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 2 + 16), k0, zmm14); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 3 + 16), k0, zmm17); + + } + else if(m_rem > 8) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm1 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >8x4 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + 8)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc + 8)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 2 + 8)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + 
+ zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 3)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 3 + 8)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc + 8), k0, zmm10); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2 + 8), k0, zmm13); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 3 + 8), k0, zmm16); + + } + else if(m_rem > 0) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm0 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >1x4 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc )); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 2)); + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 3)); + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + + _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + ldc), k0, zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2), k0, zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*3), k0, zmm15); + } + break; + } + case 3: + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 8x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. 
Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + _mm_prefetch((char*)( temp_a + 192), _MM_HINT_T0); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + 16)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc + 16)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 16)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm31, zmm14); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_storeu_pd((double *)(temp_c + ldc + 16), zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 16), zmm14); + + temp_c += D_MR; + temp_a += D_MR; + } + dim_t m_rem = m_remainder; + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 8x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. 
Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + 16)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc + 16)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2 + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc * 2 + 16)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm31, zmm14); + + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc + 16), k0, zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 2 + 16), k0, zmm14); + + } + else if(m_rem > 8) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm1 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >8x3 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
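+                        d. The first 8 rows of each of the 3 remaining columns are stored
+                           with full-width stores; the trailing rows go through masked
+                           stores with k0, and the mask == 0 guard covers m_rem == 16.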
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + 8)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc + 8)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc * 2)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc * 2 + 8)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm31, zmm13); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc + 8), k0, zmm10); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2 + 8), k0, zmm13); + + } + else if(m_rem > 0) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm0 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >1x3 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
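+                        d. A single masked vector per column covers the last 1 to 8 rows;
+                           for example, 5 leftover rows give mask 0xff >> 3 = 0x1f.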
+ */ + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc )); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc * 2)); + zmm12 = _mm512_fmadd_pd(zmm0, zmm31, zmm12); + + _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + ldc), k0, zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2), k0, zmm12); + } + break; + } + case 2: + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 24x2 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
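+                    d. Only 6 accumulators (zmm6-zmm11) are needed for the 2 remaining
+                       columns; A is still consumed 24 rows (D_MR) at a time, with a
+                       prefetch of A issued 192 elements ahead.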
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + _mm_prefetch((char*)( temp_a + 192), _MM_HINT_T0); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + 16)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + ldc + 16)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_storeu_pd((double *)(temp_c + ldc + 16), zmm11); + + temp_c += D_MR; + temp_a += D_MR; + } + dim_t m_rem = m_remainder; + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >16x2 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + 16)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + ldc + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + ldc + 16)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc + 16), k0, zmm11); + + } + else if(m_rem > 8) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm1 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >8x2 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + 8)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c + ldc )); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + ldc + 8)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc + 8), k0, zmm10); + + } + else if(m_rem > 0) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm0 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >1x2 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c + ldc )); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + ldc), k0, zmm9); + } + break; + } + case 1: + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 24x1 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + _mm_prefetch((char*)( temp_a + 192), _MM_HINT_T0); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_c + 16)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + + temp_c += D_MR; + temp_a += D_MR; + } + dim_t m_rem = m_remainder; + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >16x1 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm31 = _mm512_set1_pd(beta_val); + + zmm0 = _mm512_loadu_pd((double const *)(temp_c)); + zmm1 = _mm512_loadu_pd((double const *)(temp_c + 8)); + zmm2 = _mm512_mask_loadu_pd(zmm2, k0, (double const *)(temp_c + 16)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm31, zmm8); + + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8); + + } + else if(m_rem > 8) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm1 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >8x1 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */
+ zmm0 = _mm512_loadu_pd((double const *)(temp_a));
+ zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8));
+
+ zmm30 = _mm512_set1_pd(*(double const *)(temp_b));
+
+ zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6);
+ zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7);
+
+ zmm0 = _mm512_set1_pd(alpha_val);
+
+ zmm6 = _mm512_mul_pd(zmm0, zmm6);
+ zmm7 = _mm512_mul_pd(zmm0, zmm7);
+
+ zmm31 = _mm512_set1_pd(beta_val);
+
+ zmm0 = _mm512_loadu_pd((double const *)(temp_c));
+ zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_c + 8));
+
+ zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6);
+ zmm7 = _mm512_fmadd_pd(zmm1, zmm31, zmm7);
+
+
+ _mm512_storeu_pd((double *)(temp_c), zmm6);
+ _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7);
+ }
+ else if(m_rem > 0)
+ {
+ uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder
+ if (mask == 0) mask = 0xff;
+ __mmask8 k0 = _load_mask8(&mask);
+ zmm6 = _mm512_setzero_pd();
+ zmm0 = _mm512_setzero_pd();
+ /*
+ a. Perform alpha*A*B using temp_a, temp_b and alpha_val,
+ where alpha_val is not zero.
+ b. This loop operates with >1x1 block size
+ along n dimension for every D_NR columns of temp_b while
+ computing all D_MR rows of temp_a.
+ c. Same approach is used in remaining fringe cases.
+ */
+ zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a));
+
+ zmm30 = _mm512_set1_pd(*(double const *)(temp_b));
+ zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6);
+
+ zmm0 = _mm512_set1_pd(alpha_val);
+
+ zmm6 = _mm512_mul_pd(zmm0, zmm6);
+
+ zmm31 = _mm512_set1_pd(beta_val);
+
+ zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_c));
+ zmm6 = _mm512_fmadd_pd(zmm0, zmm31, zmm6);
+
+ _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6);
+ }
+ break;
+ }
+ default:
+ {
+ break;
+ }
+ }
+ ret_status = BLIS_SUCCESS;
+ }
+ else if(alpha_val != 0.0 && beta_val == 0.0)
+ {
+ /* Compute C = alpha*A*B, since beta_val is zero */
+ for(dim_t j = 0; (j + (D_NR-1) < n ); j += D_NR)
+ {
+ double* temp_b = b + j*ldb;
+ double* temp_a = a;
+ double* temp_c = c + j*ldc;
+
+ for(dim_t i = 0; i < ( m - D_MR+1); i += D_MR)
+ {
+ //Clear out vector registers to hold fma result.
+ //zmm6 to zmm29 holds fma result.
+ //zmm0, zmm1, zmm2 are used to load 24 elements from
+ //A matrix.
+ //zmm30 and zmm31 are alternatively used to broadcast element
+ //from B matrix.
+ zmm6 = _mm512_setzero_pd();
+ zmm7 = _mm512_setzero_pd();
+ zmm8 = _mm512_setzero_pd();
+ zmm9 = _mm512_setzero_pd();
+ zmm10 = _mm512_setzero_pd();
+ zmm11 = _mm512_setzero_pd();
+ zmm12 = _mm512_setzero_pd();
+ zmm13 = _mm512_setzero_pd();
+ zmm14 = _mm512_setzero_pd();
+ zmm15 = _mm512_setzero_pd();
+ zmm16 = _mm512_setzero_pd();
+ zmm17 = _mm512_setzero_pd();
+ zmm18 = _mm512_setzero_pd();
+ zmm19 = _mm512_setzero_pd();
+ zmm20 = _mm512_setzero_pd();
+ zmm21 = _mm512_setzero_pd();
+ zmm22 = _mm512_setzero_pd();
+ zmm23 = _mm512_setzero_pd();
+ zmm24 = _mm512_setzero_pd();
+ zmm25 = _mm512_setzero_pd();
+ zmm26 = _mm512_setzero_pd();
+ zmm27 = _mm512_setzero_pd();
+ zmm28 = _mm512_setzero_pd();
+ zmm29 = _mm512_setzero_pd();
+ /*
+ a. Perform alpha*A*B using temp_a, temp_b and alpha_val,
+ where alpha_val is not zero.
+ b. This loop operates with 24x8 block size
+ along n dimension for every D_NR columns of temp_b while
+ computing all D_MR rows of temp_a.
+ c. Same approach is used in remaining fringe cases.
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + _mm_prefetch((char*)( temp_a + 192), _MM_HINT_T0); + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + //Broadcast element from next column of B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + //Compute A*B. + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + //Compute A*B. + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + //Broadcast element from B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + //Compute A*B. + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + //Compute A*B. + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + //Broadcast element from B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + //Compute A*B. + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6)); + //Compute A*B. + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + //Broadcast element from B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 7)); + //Compute A*B. + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm30, zmm25); + zmm26 = _mm512_fmadd_pd(zmm2, zmm30, zmm26); + //Compute A*B. + zmm27 = _mm512_fmadd_pd(zmm0, zmm31, zmm27); + zmm28 = _mm512_fmadd_pd(zmm1, zmm31, zmm28); + zmm29 = _mm512_fmadd_pd(zmm2, zmm31, zmm29); + + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. + //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm23 = _mm512_mul_pd(zmm0, zmm23); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + zmm25 = _mm512_mul_pd(zmm0, zmm25); + zmm26 = _mm512_mul_pd(zmm0, zmm26); + zmm27 = _mm512_mul_pd(zmm0, zmm27); + zmm28 = _mm512_mul_pd(zmm0, zmm28); + zmm29 = _mm512_mul_pd(zmm0, zmm29); + + //Store the result back to Matrix C. + //Result is available in zmm6 to zmm29. 
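+ //beta_val is zero in this branch, so C is never loaded;
+ //the stores below simply overwrite C with alpha*(A*B).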
+ _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + //C matrix 2nd column + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_storeu_pd((double *)(temp_c + ldc + 16), zmm11); + //C matrix 3rd column + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 16), zmm14); + //C matrix 4th column + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 16), zmm17); + //C matrix 5th column + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 16), zmm20); + //C matrix 6th column + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 8), zmm22); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 16), zmm23); + //C matrix 7th column + _mm512_storeu_pd((double *)(temp_c + ldc*6), zmm24); + _mm512_storeu_pd((double *)(temp_c + ldc*6 + 8), zmm25); + _mm512_storeu_pd((double *)(temp_c + ldc*6 + 16), zmm26); + //C matrix 8th column + _mm512_storeu_pd((double *)(temp_c + ldc*7), zmm27); + _mm512_storeu_pd((double *)(temp_c + ldc*7 + 8), zmm28); + _mm512_storeu_pd((double *)(temp_c + ldc*7 + 16), zmm29); + + //Update temp_c and temp_a pointer to + //respective offset. + temp_c += D_MR; + temp_a += D_MR; + } + + dim_t m_rem = m_remainder; + //Handles the edge case for m_remainder from 17 to 23. + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + //Clear out vector registers to hold fma result. + //zmm6 to zmm29 holds fma result. + //zmm0, zmm1, zmm2 are used to load elements from + //A matrix. + //zmm30 and zmm31 are alternatively used to broadcast element + //from B matrix. + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm20 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm23 = _mm512_setzero_pd(); + zmm24 = _mm512_setzero_pd(); + zmm25 = _mm512_setzero_pd(); + zmm26 = _mm512_setzero_pd(); + zmm27 = _mm512_setzero_pd(); + zmm28 = _mm512_setzero_pd(); + zmm29 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >16x8 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + //Broadcast element from next column of B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + //Compute A*B. + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + //Compute A*B. + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + //Compute A*B. + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + //Compute A*B. + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + //Compute A*B. + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6)); + //Compute A*B. + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 7)); + //Compute A*B. + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm30, zmm25); + zmm26 = _mm512_fmadd_pd(zmm2, zmm30, zmm26); + //Compute A*B. + zmm27 = _mm512_fmadd_pd(zmm0, zmm31, zmm27); + zmm28 = _mm512_fmadd_pd(zmm1, zmm31, zmm28); + zmm29 = _mm512_fmadd_pd(zmm2, zmm31, zmm29); + + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. + //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm23 = _mm512_mul_pd(zmm0, zmm23); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + zmm25 = _mm512_mul_pd(zmm0, zmm25); + zmm26 = _mm512_mul_pd(zmm0, zmm26); + zmm27 = _mm512_mul_pd(zmm0, zmm27); + zmm28 = _mm512_mul_pd(zmm0, zmm28); + zmm29 = _mm512_mul_pd(zmm0, zmm29); + + //Store the result back to Matrix C. + //Result is available in zmm6 to zmm29. 
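+ //The first two 8-element vectors of each column are stored in
+ //full; the third is stored under mask k0 so only the valid
+ //remainder rows of C are written.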
+ _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8); + //C matrix 2nd column + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc + 16), k0, zmm11); + //C matrix 3rd column + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 2 + 16), k0, zmm14); + //C matrix 4th column + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 3 + 16), k0, zmm17); + //C matrix 5th column + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*4 + 16), k0, zmm20); + //C matrix 6th column + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 8), zmm22); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*5 + 16), k0, zmm23); + //C matrix 7th column + _mm512_storeu_pd((double *)(temp_c + ldc*6), zmm24); + _mm512_storeu_pd((double *)(temp_c + ldc*6 + 8), zmm25); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*6 + 16), k0, zmm26); + //C matrix 8th column + _mm512_storeu_pd((double *)(temp_c + ldc*7), zmm27); + _mm512_storeu_pd((double *)(temp_c + ldc*7 + 8), zmm28); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*7 + 16), k0, zmm29); + } + //Handles the edge cases where m_remainder is from 9 to 16 + else if(m_rem > 8) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + //Clear out vector registers to hold fma result. + //zmm6 to zmm28 holds fma result. + //zmm0, zmm1 are used to load elements from + //A matrix. + //zmm30 and zmm31 are alternatively used to broadcast element + //from B matrix. + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm24 = _mm512_setzero_pd(); + zmm25 = _mm512_setzero_pd(); + zmm27 = _mm512_setzero_pd(); + zmm28 = _mm512_setzero_pd(); + zmm1 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >8x8 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8)); + + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + //Broadcast element from next column of B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + //Compute A*B. 
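+ //zmm30 and zmm31 are broadcast alternately: the broadcast for the
+ //next column is issued before the FMAs that consume the previous
+ //one, which appears intended to overlap broadcast latency with
+ //the FMA chain.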
+ zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6)); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 7)); + + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm30, zmm25); + + zmm27 = _mm512_fmadd_pd(zmm0, zmm31, zmm27); + zmm28 = _mm512_fmadd_pd(zmm1, zmm31, zmm28); + + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. + //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + zmm25 = _mm512_mul_pd(zmm0, zmm25); + zmm27 = _mm512_mul_pd(zmm0, zmm27); + zmm28 = _mm512_mul_pd(zmm0, zmm28); + + //Store the result back to Matrix C. + //Result is available in zmm6 to zmm28. + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7); + //C matrix 2nd column + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc + 8), k0, zmm10); + //C matrix 3rd column + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2 + 8), k0, zmm13); + //C matrix 4th column + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 3 + 8), k0, zmm16); + //C matrix 5th column + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4 + 8), k0, zmm19); + //C matrix 6th column + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*5 + 8), k0, zmm22); + //C matrix 7th column + _mm512_storeu_pd((double *)(temp_c + ldc*6), zmm24); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*6 + 8), k0, zmm25); + //C matrix 8th column + _mm512_storeu_pd((double *)(temp_c + ldc*7), zmm27); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*7 + 8), k0, zmm28); + } + //Handles the edge case where m_remainder is from 1 to 8 + else if(m_rem > 0) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + //Clear out vector registers to hold fma result. + //zmm6 to zmm27 holds fma result. + //zmm0 are used to load 8 elements from + //A matrix. + //zmm30 and zmm31 are alternatively used to broadcast element + //from B matrix. 
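+ //With at most eight valid rows, a single masked load of A (zmm0)
+ //feeds one accumulator per column: zmm6, zmm9, zmm12, ..., zmm27.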
+ zmm6 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm24 = _mm512_setzero_pd(); + zmm27 = _mm512_setzero_pd(); + zmm0 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >1x8 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a)); + + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + //Broadcast element from next column of B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + //Compute A*B. + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6)); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 7)); + + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + zmm27 = _mm512_fmadd_pd(zmm0, zmm31, zmm27); + + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. + //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + zmm27 = _mm512_mul_pd(zmm0, zmm27); + + //Store the result back to Matrix C. + _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6); + //C matrix 2nd column + _mm512_mask_storeu_pd((double *)(temp_c + ldc), k0, zmm9); + //C matrix 3rd column + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2), k0, zmm12); + //C matrix 4th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*3), k0, zmm15); + //C matrix 5th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4), k0, zmm18); + //C matrix 6th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*5), k0, zmm21); + //C matrix 7th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*6), k0, zmm24); + //C matrix 8th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*7), k0, zmm27); + } + } + + switch(n_remainder) + { + case 7: + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + //Clear out vector registers to hold fma result. + //zmm6 to zmm26 holds fma result. + //zmm0, zmm1, zmm2 are used to load 24 elements from + //A matrix. + //zmm30 and zmm31 are alternatively used to broadcast element + //from B matrix. 
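+ //Three 8-element row panels are kept per column, so the seven
+ //remaining columns of C use 21 accumulators in total.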
+ zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm20 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm23 = _mm512_setzero_pd(); + zmm24 = _mm512_setzero_pd(); + zmm25 = _mm512_setzero_pd(); + zmm26 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 24x7 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + _mm_prefetch((char*)( temp_a + 192), _MM_HINT_T0); + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + //Broadcast element from B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + //Compute A*B. + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + //Compute A*B. + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + //Compute A*B. + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + //Compute A*B. + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + //Compute A*B. + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6)); + //Compute A*B. + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm30, zmm25); + zmm26 = _mm512_fmadd_pd(zmm2, zmm30, zmm26); + + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. 
+ //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm23 = _mm512_mul_pd(zmm0, zmm23); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + zmm25 = _mm512_mul_pd(zmm0, zmm25); + zmm26 = _mm512_mul_pd(zmm0, zmm26); + + //Store the result back to Matrix C. + //Result is available in zmm6 to zmm26. + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + //C matrix 2nd column + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_storeu_pd((double *)(temp_c + ldc + 16), zmm11); + //C matrix 3rd column + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 16), zmm14); + //C matrix 4th column + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 16), zmm17); + //C matrix 5th column + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 16), zmm20); + //C matrix 6th column + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 8), zmm22); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 16), zmm23); + //C matrix 7th column + _mm512_storeu_pd((double *)(temp_c + ldc*6), zmm24); + _mm512_storeu_pd((double *)(temp_c + ldc*6 + 8), zmm25); + _mm512_storeu_pd((double *)(temp_c + ldc*6 + 16), zmm26); + + temp_c += D_MR; + temp_a += D_MR; + } + dim_t m_rem = m_remainder; + //Handles the edge case where m_remainder is from 17 to 23 + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + //Clear out vector registers to hold fma result. + //zmm6 to zmm26 holds fma result. + //zmm0, zmm1, zmm2 are used to load elements from + //A matrix. + //zmm30 and zmm31 are alternatively used to broadcast element + //from B matrix. + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm20 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm23 = _mm512_setzero_pd(); + zmm24 = _mm512_setzero_pd(); + zmm25 = _mm512_setzero_pd(); + zmm26 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. 
This loop operates with (>16)x7 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + //Broadcast element from B matrix in zmm30 + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + //Broadcast element from next column of B matrix in zmm31 + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + //Compute A*B. + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + //Compute A*B. + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + //Compute A*B. + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + //Compute A*B. + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + //Compute A*B. + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6)); + //Compute A*B. + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm30, zmm25); + zmm26 = _mm512_fmadd_pd(zmm2, zmm30, zmm26); + + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. + //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm23 = _mm512_mul_pd(zmm0, zmm23); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + zmm25 = _mm512_mul_pd(zmm0, zmm25); + zmm26 = _mm512_mul_pd(zmm0, zmm26); + + //Store the result back to Matrix C. + //Result is available in zmm6 to zmm26. 
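+ //Assuming m_remainder is m % D_MR, (m & 7) equals m_remainder - 16
+ //on this path (a value in 1..7), so mask k0 is never zero and the
+ //masked stores below cover exactly the tail rows of each column.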
+ _mm512_storeu_pd((double *)(temp_c), zmm6);
+ _mm512_storeu_pd((double *)(temp_c + 8), zmm7);
+ _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8);
+ //C matrix 2nd column
+ _mm512_storeu_pd((double *)(temp_c + ldc), zmm9);
+ _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10);
+ _mm512_mask_storeu_pd ((double *)(temp_c + ldc + 16), k0, zmm11);
+ //C matrix 3rd column
+ _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12);
+ _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13);
+ _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 2 + 16), k0, zmm14);
+ //C matrix 4th column
+ _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15);
+ _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16);
+ _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 3 + 16), k0, zmm17);
+ //C matrix 5th column
+ _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18);
+ _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19);
+ _mm512_mask_storeu_pd ((double *)(temp_c + ldc*4 + 16), k0, zmm20);
+ //C matrix 6th column
+ _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21);
+ _mm512_storeu_pd((double *)(temp_c + ldc*5 + 8), zmm22);
+ _mm512_mask_storeu_pd ((double *)(temp_c + ldc*5 + 16), k0, zmm23);
+ //C matrix 7th column
+ _mm512_storeu_pd((double *)(temp_c + ldc*6), zmm24);
+ _mm512_storeu_pd((double *)(temp_c + ldc*6 + 8), zmm25);
+ _mm512_mask_storeu_pd ((double *)(temp_c + ldc*6 + 16), k0, zmm26);
+
+ }
+ //Handles the edge case where m_remainder is from 9 to 16.
+ else if(m_rem > 8)
+ {
+ uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder
+ if (mask == 0) mask = 0xff;
+ __mmask8 k0 = _load_mask8(&mask);
+ zmm6 = _mm512_setzero_pd();
+ zmm7 = _mm512_setzero_pd();
+ zmm9 = _mm512_setzero_pd();
+ zmm10 = _mm512_setzero_pd();
+ zmm12 = _mm512_setzero_pd();
+ zmm13 = _mm512_setzero_pd();
+ zmm15 = _mm512_setzero_pd();
+ zmm16 = _mm512_setzero_pd();
+ zmm18 = _mm512_setzero_pd();
+ zmm19 = _mm512_setzero_pd();
+ zmm21 = _mm512_setzero_pd();
+ zmm22 = _mm512_setzero_pd();
+ zmm24 = _mm512_setzero_pd();
+ zmm25 = _mm512_setzero_pd();
+ zmm1 = _mm512_setzero_pd();
+
+ /*
+ a. Perform alpha*A*B using temp_a, temp_b and alpha_val,
+ where alpha_val is not zero.
+ b. This loop operates with (>8)x7 block size
+ along n dimension for every D_NR columns of temp_b while
+ computing all D_MR rows of temp_a.
+ c. Same approach is used in remaining fringe cases.
+ */
+ zmm0 = _mm512_loadu_pd((double const *)(temp_a));
+ zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8));
+
+ zmm30 = _mm512_set1_pd(*(double const *)(temp_b));
+ zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1));
+ //Compute A*B.
+ zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6);
+ zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7);
+
+ zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2));
+ //Compute A*B.
+ zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9);
+ zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10);
+
+ zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3));
+ //Compute A*B.
+ zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12);
+ zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13);
+
+ zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4));
+ //Compute A*B.
+ zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15);
+ zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16);
+
+ zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5));
+ //Compute A*B.
+ zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18);
+ zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19);
+
+ zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6));
+ //Compute A*B.
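+ //zmm30 now holds the element from the seventh and last column
+ //(ldb * 6) handled by this n_remainder == 7 case; the FMAs below
+ //finish columns six and seven.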
+ zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + //Compute A*B. + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + zmm25 = _mm512_fmadd_pd(zmm1, zmm30, zmm25); + + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. + //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + zmm25 = _mm512_mul_pd(zmm0, zmm25); + + //Store the result back to Matrix C. + //Result is available in zmm6 to zmm25. + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7); + //C matrix 2nd column + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc + 8), k0, zmm10); + //C matrix 3rd column + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2 + 8), k0, zmm13); + //C matrix 4th column + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 3 + 8), k0, zmm16); + //C matrix 5th column + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4 + 8), k0, zmm19); + //C matrix 6th column + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*5 + 8), k0, zmm22); + //C matrix 7th column + _mm512_storeu_pd((double *)(temp_c + ldc*6), zmm24); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*6 + 8), k0, zmm25); + } + else if(m_rem > 0) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm24 = _mm512_setzero_pd(); + zmm0 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with (>1)x7 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 6)); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + + zmm24 = _mm512_fmadd_pd(zmm0, zmm30, zmm24); + //Broadcast Alpha into zmm0 + zmm0 = _mm512_set1_pd(alpha_val); + //Scale fma result with Alpha. + //Alpha * AB + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm24 = _mm512_mul_pd(zmm0, zmm24); + + //Store the result back to Matrix C. + //Result is available in zmm6 to zmm24. + _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6); + //C matrix 2nd column + _mm512_mask_storeu_pd((double *)(temp_c + ldc), k0, zmm9); + //C matrix 3rd column + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2), k0, zmm12); + //C matrix 4th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*3), k0, zmm15); + //C matrix 5th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4), k0, zmm18); + //C matrix 6th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*5), k0, zmm21); + //C matrix 7th column + _mm512_mask_storeu_pd((double *)(temp_c + ldc*6), k0, zmm24); + } + break; + } + case 6: + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm20 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm23 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 24x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + _mm_prefetch((char*)( temp_a + 192), _MM_HINT_T0); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm23 = _mm512_mul_pd(zmm0, zmm23); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_storeu_pd((double *)(temp_c + ldc + 16), zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 16), zmm14); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 16), zmm17); + + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 16), zmm20); + + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 8), zmm22); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 16), zmm23); + + temp_c += D_MR; + temp_a += D_MR; + } + dim_t m_rem = m_remainder; + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = 
_mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm20 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm23 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >16x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + zmm23 = _mm512_fmadd_pd(zmm2, zmm31, zmm23); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + zmm23 = _mm512_mul_pd(zmm0, zmm23); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc + 16), k0, zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 2 + 16), k0, zmm14); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + 
_mm512_mask_storeu_pd ((double *)(temp_c + ldc * 3 + 16), k0, zmm17); + + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*4 + 16), k0, zmm20); + + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); + _mm512_storeu_pd((double *)(temp_c + ldc*5 + 8), zmm22); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*5 + 16), k0, zmm23); + + } + else if(m_rem > 8) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm22 = _mm512_setzero_pd(); + zmm1 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >8x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + zmm22 = _mm512_fmadd_pd(zmm1, zmm31, zmm22); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + + zmm21 = _mm512_mul_pd(zmm0, zmm21); + zmm22 = _mm512_mul_pd(zmm0, zmm22); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc + 8), k0, zmm10); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2 + 8), k0, zmm13); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 3 + 8), k0, zmm16); + + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4 + 8), k0, zmm19); + + _mm512_storeu_pd((double *)(temp_c + ldc*5), zmm21); 
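+                    /* Note: in this m-fringe branch each column of C is written with one
+                       full 8-wide store for rows 0..7 and one masked store for the
+                       remaining rows, so elements of C past m_remainder are never touched. */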
+ _mm512_mask_storeu_pd((double *)(temp_c + ldc*5 + 8), k0, zmm22); + } + else if(m_rem > 0) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm21 = _mm512_setzero_pd(); + zmm0 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >1x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 5)); + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + + zmm21 = _mm512_fmadd_pd(zmm0, zmm31, zmm21); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm21 = _mm512_mul_pd(zmm0, zmm21); + + _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + ldc), k0, zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2), k0, zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*3), k0, zmm15); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4), k0, zmm18); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*5), k0, zmm21); + } + break; + } + case 5: + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm20 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 24x5 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
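                        d. Note on structure: the enclosing switch (apparently keyed on
                           n_remainder) selects this case for 5 remaining columns of B,
                           and each lower case simply drops one more column of B
                           broadcasts, FMAs and C stores while keeping the same 24-row
                           blocking along m.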
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + _mm_prefetch((char*)( temp_a + 192), _MM_HINT_T0); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_storeu_pd((double *)(temp_c + ldc + 16), zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 16), zmm14); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 16), zmm17); + + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 16), zmm20); + + temp_c += D_MR; + temp_a += D_MR; + } + dim_t m_rem = m_remainder; + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm20 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. 
This loop operates with 8x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + zmm20 = _mm512_fmadd_pd(zmm2, zmm30, zmm20); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + zmm20 = _mm512_mul_pd(zmm0, zmm20); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc + 16), k0, zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 2 + 16), k0, zmm14); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 3 + 16), k0, zmm17); + + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_storeu_pd((double *)(temp_c + ldc*4 + 8), zmm19); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc*4 + 16), k0, zmm20); + + } + else if(m_rem > 8) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm19 = _mm512_setzero_pd(); + zmm1 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. 
This loop operates with >8x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + zmm19 = _mm512_fmadd_pd(zmm1, zmm30, zmm19); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + + zmm18 = _mm512_mul_pd(zmm0, zmm18); + zmm19 = _mm512_mul_pd(zmm0, zmm19); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc + 8), k0, zmm10); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2 + 8), k0, zmm13); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 3 + 8), k0, zmm16); + + _mm512_storeu_pd((double *)(temp_c + ldc*4), zmm18); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4 + 8), k0, zmm19); + + } + else if(m_rem > 0) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm18 = _mm512_setzero_pd(); + zmm0 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >1x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
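                        d. For reference, the tail mask used by these m-fringe branches
                           can be sketched as follows (illustrative names only, not
                           taken from this kernel):

                               dim_t   tail = m & 7;                         // rows past the last full 8-wide vector
                               uint8_t mask = (uint8_t)(0xff >> (8 - tail)); // enable the low 'tail' lanes
                               if (mask == 0) mask = 0xff;                   // tail of exactly 8 rows -> full vector

                           Lane i of the resulting __mmask8 is set only when i < tail
                           (or all lanes when the tail is a full vector), so the masked
                           loads and stores touch at most m_remainder rows of A and C.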
+ */ + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 4)); + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + + zmm18 = _mm512_fmadd_pd(zmm0, zmm30, zmm18); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm18 = _mm512_mul_pd(zmm0, zmm18); + + _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + ldc), k0, zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2), k0, zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*3), k0, zmm15); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*4), k0, zmm18); + } + break; + } + case 4: + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 24x4 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + _mm_prefetch((char*)( temp_a + 192), _MM_HINT_T0); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_storeu_pd((double *)(temp_c + ldc + 16), zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 16), zmm14); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 16), zmm17); + + temp_c += D_MR; + temp_a += D_MR; + } + dim_t m_rem = m_remainder; + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm17 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >16x4 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
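                        d. Each accumulator here is cleared with _mm512_setzero_pd() and
                           then written by a single fmadd against that zero, so within
                           this block the stored value is simply the A*B product; the
                           fmadd form appears to be kept for uniformity with the other
                           code paths of this kernel.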
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + zmm17 = _mm512_fmadd_pd(zmm2, zmm31, zmm17); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + zmm17 = _mm512_mul_pd(zmm0, zmm17); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc + 16), k0, zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 2 + 16), k0, zmm14); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_storeu_pd((double *)(temp_c + ldc * 3 + 8), zmm16); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 3 + 16), k0, zmm17); + + } + else if(m_rem > 8) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm16 = _mm512_setzero_pd(); + zmm1 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >8x4 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
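                        d. The "if (mask == 0) mask = 0xff;" guard above handles the case
                           where m_remainder is an exact multiple of 8 (here m_remainder
                           == 16), since m & 7 would then be 0 and the shifted mask would
                           disable every lane; the >16 branch omits the guard because a
                           remainder of 17..23 can never make m & 7 zero.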
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + zmm16 = _mm512_fmadd_pd(zmm1, zmm31, zmm16); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + + zmm15 = _mm512_mul_pd(zmm0, zmm15); + zmm16 = _mm512_mul_pd(zmm0, zmm16); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc + 8), k0, zmm10); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2 + 8), k0, zmm13); + + _mm512_storeu_pd((double *)(temp_c + ldc*3), zmm15); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 3 + 8), k0, zmm16); + + } + else if(m_rem > 0) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm15 = _mm512_setzero_pd(); + zmm0 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >1x4 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
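                        d. In scalar terms this fringe block computes, for
                           0 <= i < m_remainder and 0 <= j < 4 (sketch only):

                               temp_c[i + j*ldc] = alpha_val * temp_a[i] * temp_b[j*ldb];

                           with the i dimension carried in a single masked 8-wide vector
                           per column of C.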
+ */ + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 3)); + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + + zmm15 = _mm512_fmadd_pd(zmm0, zmm31, zmm15); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm15 = _mm512_mul_pd(zmm0, zmm15); + + _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + ldc), k0, zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2), k0, zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc*3), k0, zmm15); + } + break; + } + case 3: + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 8x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + _mm_prefetch((char*)( temp_a + 192), _MM_HINT_T0); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_storeu_pd((double *)(temp_c + ldc + 16), zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 16), zmm14); + + temp_c += D_MR; + temp_a += D_MR; + } + dim_t m_rem = m_remainder; + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // 
calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm14 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 8x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + zmm14 = _mm512_fmadd_pd(zmm2, zmm30, zmm14); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + zmm14 = _mm512_mul_pd(zmm0, zmm14); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc + 16), k0, zmm11); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_storeu_pd((double *)(temp_c + ldc * 2 + 8), zmm13); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc * 2 + 16), k0, zmm14); + + } + else if(m_rem > 8) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm13 = _mm512_setzero_pd(); + zmm1 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >8x3 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
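                        d. zmm30 and zmm31 are used as alternating broadcast registers
                           for consecutive columns of B, presumably so that the broadcast
                           for the next column can issue while the FMAs consuming the
                           previous one are still in flight.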
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + zmm13 = _mm512_fmadd_pd(zmm1, zmm30, zmm13); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + + zmm12 = _mm512_mul_pd(zmm0, zmm12); + zmm13 = _mm512_mul_pd(zmm0, zmm13); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc + 8), k0, zmm10); + + _mm512_storeu_pd((double *)(temp_c + ldc * 2), zmm12); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2 + 8), k0, zmm13); + + } + else if(m_rem > 0) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm12 = _mm512_setzero_pd(); + zmm0 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >1x3 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 2)); + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm12 = _mm512_fmadd_pd(zmm0, zmm30, zmm12); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm12 = _mm512_mul_pd(zmm0, zmm12); + + _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + ldc), k0, zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc * 2), k0, zmm12); + } + break; + } + case 2: + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 24x2 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
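                        d. The _mm_prefetch of temp_a + 192 below fetches A data
                           192 doubles (eight 24-row iterations, 1536 bytes) ahead of
                           the current block into L1, presumably to hide the latency of
                           the A loads in later iterations.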
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + _mm_prefetch((char*)( temp_a + 192), _MM_HINT_T0); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_storeu_pd((double *)(temp_c + ldc + 16), zmm11); + + temp_c += D_MR; + temp_a += D_MR; + } + dim_t m_rem = m_remainder; + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm11 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >16x2 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + zmm11 = _mm512_fmadd_pd(zmm2, zmm31, zmm11); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + zmm11 = _mm512_mul_pd(zmm0, zmm11); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_storeu_pd((double *)(temp_c + ldc + 8), zmm10); + _mm512_mask_storeu_pd ((double *)(temp_c + ldc + 16), k0, zmm11); + + } + else if(m_rem > 8) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm10 = _mm512_setzero_pd(); + zmm1 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. 
This loop operates with >8x2 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + zmm10 = _mm512_fmadd_pd(zmm1, zmm31, zmm10); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + + zmm9 = _mm512_mul_pd(zmm0, zmm9); + zmm10 = _mm512_mul_pd(zmm0, zmm10); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7); + + _mm512_storeu_pd((double *)(temp_c + ldc), zmm9); + _mm512_mask_storeu_pd((double *)(temp_c + ldc + 8), k0, zmm10); + + } + else if(m_rem > 0) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm9 = _mm512_setzero_pd(); + zmm0 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >1x2 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm31 = _mm512_set1_pd(*(double const *)(temp_b + ldb * 1)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + + zmm9 = _mm512_fmadd_pd(zmm0, zmm31, zmm9); + + zmm0 = _mm512_set1_pd(alpha_val); + + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm9 = _mm512_mul_pd(zmm0, zmm9); + + _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + ldc), k0, zmm9); + } + break; + } + case 1: + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 24x1 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
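                        d. With a single remaining column of B, only one broadcast
                           register (zmm30) is needed and the 24 rows of the block are
                           carried in three 8-wide accumulators; in scalar terms each
                           iteration stores alpha_val * temp_a[i] * temp_b[0] for
                           i = 0..23.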
+ */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_loadu_pd((double const *)(temp_a + 16)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm0 = _mm512_set1_pd(alpha_val); + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_storeu_pd((double *)(temp_c + 16), zmm8); + + temp_c += D_MR; + temp_a += D_MR; + } + dim_t m_rem = m_remainder; + if(m_rem > 16) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm8 = _mm512_setzero_pd(); + zmm2 = _mm512_setzero_pd(); + + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >16x1 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_loadu_pd((double const *)(temp_a + 8)); + zmm2 = _mm512_mask_loadu_pd (zmm2, k0, (double const *)(temp_a + 16)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + zmm8 = _mm512_fmadd_pd(zmm2, zmm30, zmm8); + + zmm0 = _mm512_set1_pd(alpha_val); + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + zmm8 = _mm512_mul_pd(zmm0, zmm8); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_storeu_pd((double *)(temp_c + 8), zmm7); + _mm512_mask_storeu_pd ((double *)(temp_c + 16), k0, zmm8); + + } + else if(m_rem > 8) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm7 = _mm512_setzero_pd(); + zmm1 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >8x1 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + zmm0 = _mm512_loadu_pd((double const *)(temp_a)); + zmm1 = _mm512_mask_loadu_pd(zmm1, k0, (double const *)(temp_a + 8)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + zmm7 = _mm512_fmadd_pd(zmm1, zmm30, zmm7); + + zmm0 = _mm512_set1_pd(alpha_val); + zmm6 = _mm512_mul_pd(zmm0, zmm6); + zmm7 = _mm512_mul_pd(zmm0, zmm7); + + _mm512_storeu_pd((double *)(temp_c), zmm6); + _mm512_mask_storeu_pd((double *)(temp_c + 8), k0, zmm7); + } + else if(m_rem > 0) + { + uint8_t mask = (0xff >> (0x8 - (m & 7))); // calculate mask based on m_remainder + if (mask == 0) mask = 0xff; + __mmask8 k0 = _load_mask8(&mask); + zmm6 = _mm512_setzero_pd(); + zmm0 = _mm512_setzero_pd(); + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with >1x1 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
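                        d. Worked example (assuming D_MR is 24 and m_remainder = m % 24,
                           as the 24-row blocking here suggests): for m = 27 the main
                           loop covers rows 0..23, m_remainder = 3, m & 7 = 3, so
                           mask = 0x07 and only rows 24..26 are loaded, scaled by alpha
                           and stored.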
+ */ + zmm0 = _mm512_mask_loadu_pd(zmm0, k0, (double const *)(temp_a)); + + zmm30 = _mm512_set1_pd(*(double const *)(temp_b)); + zmm6 = _mm512_fmadd_pd(zmm0, zmm30, zmm6); + + zmm0 = _mm512_set1_pd(alpha_val); + zmm6 = _mm512_mul_pd(zmm0, zmm6); + + _mm512_mask_storeu_pd((double *)(temp_c), k0, zmm6); + } + break; + } + default: + { + break; + } + } + ret_status = BLIS_SUCCESS; + } + else + { + ;//return failure; + } + return ret_status; + +} diff --git a/kernels/zen4/3/bli_dgemm_zen4_asm_32x6.c b/kernels/zen4/3/bli_dgemm_zen4_asm_32x6.c index c20c0ab898..cab5ea0ce5 100644 --- a/kernels/zen4/3/bli_dgemm_zen4_asm_32x6.c +++ b/kernels/zen4/3/bli_dgemm_zen4_asm_32x6.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc.All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -474,7 +474,12 @@ void bli_dgemm_zen4_asm_32x6( [offsetPtr] "m" (offsetPtr) : // register clobber list "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", - "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", + "r13", "r14", "r15", "k0", "k1", "k2", "k3", "k4", "xmm1", + "xmm2", "ymm2", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", + "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "ymm16", + "ymm17", "ymm18", "ymm19", "ymm20", "ymm21", "ymm22", "ymm23", + "ymm24", "ymm25", "ymm26", "ymm27", "ymm28", "ymm29", "ymm30", + "ymm31", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", diff --git a/kernels/zen4/3/bli_dgemm_zen4_asm_8x24.c b/kernels/zen4/3/bli_dgemm_zen4_asm_8x24.c index 1f133dfc15..887f27889c 100644 --- a/kernels/zen4/3/bli_dgemm_zen4_asm_8x24.c +++ b/kernels/zen4/3/bli_dgemm_zen4_asm_8x24.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc.All rights reserved. + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -218,7 +218,7 @@ static int64_t offsets[24] __attribute__((aligned(64))) = /* * number of accumulation registers = 24/8 * 8 = 24 zmm8 to zmm31 * number of registers used for load B = 24/8 = 3 zmm0 to zmm2 - * number of regusters used for broadcast A = 2 zmm6 and zmm7 + * number of registers used for broadcast A = 2 zmm6 and zmm7 */ void bli_dgemm_zen4_asm_8x24( dim_t k_, @@ -703,10 +703,14 @@ void bli_dgemm_zen4_asm_8x24( [offsetPtr] "m" (offsetPtr) : // register clobber list "rax", "rbx", "rcx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", - "r13", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", - "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", - "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", - "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", - "zmm30", "zmm31", "memory" + "r13", "k0", "k1", "k2", "k3", "xmm1", "xmm2", + "ymm2", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", + "ymm14", "ymm15", "ymm16", "ymm17", "ymm18", "ymm19", "ymm20", + "ymm21", "ymm22", "ymm23", "ymm24", "ymm25", "ymm26", "ymm27", + "ymm28", "ymm29", "ymm30", "ymm31", "zmm0", "zmm1", "zmm2", "zmm3", + "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", + "zmm12", "zmm13", "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", + "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", + "zmm28", "zmm29", "zmm30", "zmm31", "memory" ) } diff --git a/kernels/zen4/3/bli_gemmtrsm_l_zen4_8x24.c b/kernels/zen4/3/bli_gemmtrsm_l_zen4_8x24.c index 139edc7ddb..d5a10aa209 100644 --- a/kernels/zen4/3/bli_gemmtrsm_l_zen4_8x24.c +++ b/kernels/zen4/3/bli_gemmtrsm_l_zen4_8x24.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc.All rights reserved. + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -156,7 +156,7 @@ static int64_t offsets[24] __attribute__((aligned(64))) = /* * number of accumulation registers = 24/8 * 8 = 24 zmm8 to zmm31 * number of registers used for load B = 24/8 = 3 zmm0 to zmm2 - * number of regusters used for broadcast A = 2 zmm6 and zmm7 + * number of registers used for broadcast A = 2 zmm6 and zmm7 */ void bli_dgemmtrsm_l_zen4_asm_8x24 ( @@ -812,7 +812,11 @@ void bli_dgemmtrsm_l_zen4_asm_8x24 [offsetPtr] "m" (offsetPtr) : // register clobber list "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", - "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", + "r13", "r14", "r15", "k0", "k1", "k2", "k3", "ymm8", "ymm9", + "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "ymm16", + "ymm17", "ymm18", "ymm19", "ymm20", "ymm21", "ymm22", "ymm23", + "ymm24", "ymm25", "ymm26", "ymm27", "ymm28", "ymm29", "ymm30", + "ymm31", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", diff --git a/kernels/zen4/3/bli_gemmtrsm_l_zen_16x14.c b/kernels/zen4/3/bli_gemmtrsm_l_zen_16x14.c index 08edcb574f..9633fc5bf7 100644 --- a/kernels/zen4/3/bli_gemmtrsm_l_zen_16x14.c +++ b/kernels/zen4/3/bli_gemmtrsm_l_zen_16x14.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -411,7 +411,7 @@ void bli_dgemmtrsm_l_zen_asm_16x14 /* C prefetch Loop Note: This loop runs 14 times, - These 14 iterations are done seperately so that c11 can be prefetched here. + These 14 iterations are done separately so that c11 can be prefetched here. */ ADD(R11, RSI) ADD(IMM(14), RSI) diff --git a/kernels/zen4/3/bli_gemmtrsm_u_zen4_8x24.c b/kernels/zen4/3/bli_gemmtrsm_u_zen4_8x24.c index d1ea0109d7..e9dae78ba7 100644 --- a/kernels/zen4/3/bli_gemmtrsm_u_zen4_8x24.c +++ b/kernels/zen4/3/bli_gemmtrsm_u_zen4_8x24.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc.All rights reserved. + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -156,7 +156,7 @@ static int64_t offsets[24] __attribute__((aligned(64))) = /* * number of accumulation registers = 24/8 * 8 = 24 zmm8 to zmm31 * number of registers used for load B = 24/8 = 3 zmm0 to zmm2 - * number of regusters used for broadcast A = 2 zmm6 and zmm7 + * number of registers used for broadcast A = 2 zmm6 and zmm7 */ void bli_dgemmtrsm_u_zen4_asm_8x24 ( @@ -817,7 +817,11 @@ void bli_dgemmtrsm_u_zen4_asm_8x24 [offsetPtr] "m" (offsetPtr) : // register clobber list "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", - "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", + "r13", "r14", "r15", "k0", "k1", "k2", "k3", "ymm8", "ymm9", + "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", "ymm16", + "ymm17", "ymm18", "ymm19", "ymm20", "ymm21", "ymm22", "ymm23", + "ymm24", "ymm25", "ymm26", "ymm27", "ymm28", "ymm29", "ymm30", + "ymm31", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", diff --git a/kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c b/kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c index 401c6e7d23..a57d8dacc4 100644 --- a/kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c +++ b/kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -407,7 +407,7 @@ void bli_dgemmtrsm_u_zen_asm_16x14 /* C prefetch Loop Note: This loop runs 14 times, - These 14 iterations are done seperately so that c11 can be prefetched here. + These 14 iterations are done separately so that c11 can be prefetched here. 
*/ ADD(R11, RSI) ADD(IMM(14), RSI) diff --git a/kernels/zen4/3/bli_trsm_small_AVX512.c b/kernels/zen4/3/bli_trsm_small_AVX512.c index 82431bd6a2..3d10c3a9e4 100644 --- a/kernels/zen4/3/bli_trsm_small_AVX512.c +++ b/kernels/zen4/3/bli_trsm_small_AVX512.c @@ -729,7 +729,7 @@ err_t bli_trsm_small_mt_AVX512 // region - GEMM DTRSM for right variants #define BLIS_DTRSM_SMALL_GEMM_8nx8m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) \ - /*K loop is broken into two seperate loops + /*K loop is broken into two separate loops each loop computes k/2 iterations */ \ \ int itr = (k_iter / 2); /*itr count for first loop*/\ @@ -900,7 +900,7 @@ err_t bli_trsm_small_mt_AVX512 */ #define BLIS_DTRSM_SMALL_GEMM_8nx4m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) \ - /*K loop is broken into two seperate loops + /*K loop is broken into two separate loops each loop computes k/2 iterations */ \ \ int itr = (k_iter / 2); /*itr count for first loop*/\ @@ -979,7 +979,7 @@ err_t bli_trsm_small_mt_AVX512 #define BLIS_DTRSM_SMALL_GEMM_8nx3m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) \ - /*K loop is broken into two seperate loops + /*K loop is broken into two separate loops each loop computes k/2 iterations */ \ \ int itr = (k_iter / 2); /*itr count for first loop*/\ @@ -1062,7 +1062,7 @@ err_t bli_trsm_small_mt_AVX512 ymm16 = _mm256_add_pd(ymm16, ymm31); #define BLIS_DTRSM_SMALL_GEMM_8nx2m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) \ - /*K loop is broken into two seperate loops + /*K loop is broken into two separate loops each loop computes k/2 iterations */ \ \ int itr = (k_iter / 2); /*itr count for first loop*/\ @@ -1142,7 +1142,7 @@ err_t bli_trsm_small_mt_AVX512 ymm16 = _mm256_add_pd(ymm16, ymm31); #define BLIS_DTRSM_SMALL_GEMM_8nx1m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) \ - /*K loop is broken into two seperate loops + /*K loop is broken into two separate loops each loop computes k/2 iterations */ \ \ int itr = (k_iter / 2); /*itr count for first loop*/\ @@ -2034,12 +2034,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 bli_rntm_init_from_global(&rntm); bli_rntm_set_num_threads_only(1, &rntm); - bli_membrk_rntm_set_membrk(&rntm); + bli_pba_rntm_set_pba(&rntm); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if ((d_nr * n * sizeof(double)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -2047,7 +2047,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 if (required_packing_A == 1) { // Get the buffer from the pool. 
- bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); // acquire memory for A01 pack @@ -4306,7 +4306,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 if ((required_packing_A == 1) && bli_mem_is_alloc(&local_mem_buf_A_s)) { - bli_membrk_release(&rntm, + bli_pba_release(&rntm, &local_mem_buf_A_s); } return BLIS_SUCCESS; @@ -4364,12 +4364,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB_AVX512 bli_rntm_init_from_global(&rntm); bli_rntm_set_num_threads_only(1, &rntm); - bli_membrk_rntm_set_membrk(&rntm); + bli_pba_rntm_set_pba(&rntm); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if ((d_nr * n * sizeof(double)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -4377,7 +4377,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB_AVX512 if (required_packing_A) { // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); // acquire memory for A01 pack @@ -6606,7 +6606,7 @@ else if ( n_remainder == 2) if ((required_packing_A) && bli_mem_is_alloc(&local_mem_buf_A_s)) { - bli_membrk_release(&rntm, + bli_pba_release(&rntm, &local_mem_buf_A_s); } return BLIS_SUCCESS; @@ -7278,12 +7278,12 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB_AVX512 bli_rntm_init_from_global(&rntm); bli_rntm_set_num_threads_only(1, &rntm); - bli_membrk_rntm_set_membrk(&rntm); + bli_pba_rntm_set_pba(&rntm); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if ((d_mr * m * sizeof(double)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -7291,7 +7291,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB_AVX512 if (required_packing_A == 1) { // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -9193,7 +9193,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB_AVX512 if ((required_packing_A == 1) && bli_mem_is_alloc(&local_mem_buf_A_s)) { - bli_membrk_release(&rntm, &local_mem_buf_A_s); + bli_pba_release(&rntm, &local_mem_buf_A_s); } return BLIS_SUCCESS; } @@ -9245,12 +9245,12 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB_AVX512 bli_rntm_init_from_global(&rntm); bli_rntm_set_num_threads_only(1, &rntm); - bli_membrk_rntm_set_membrk(&rntm); + bli_pba_rntm_set_pba(&rntm); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_pba_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + bli_rntm_pba(&rntm))); if ((d_mr * m * sizeof(double)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; @@ -9258,7 +9258,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB_AVX512 if (required_packing_A == 1) { // Get the buffer from the pool. 
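/*
    Aside, not part of the hunk above: every small-TRSM path touched by the
    bli_membrk_* -> bli_pba_* rename in this file follows the same packed
    block allocator sequence.  A minimal standalone sketch of that sequence,
    using only the calls visible in these hunks (bli_mem_buffer() is assumed
    here as the usual BLIS accessor for the raw pointer), is:

        rntm_t rntm;
        mem_t  local_mem_buf_A_s = { 0 };

        bli_rntm_init_from_global( &rntm );
        bli_rntm_set_num_threads_only( 1, &rntm );
        bli_pba_rntm_set_pba( &rntm );                    // attach the pba to this rntm

        siz_t buffer_size = bli_pool_block_size(
              bli_pba_pool( bli_packbuf_index( BLIS_BITVAL_BUFFER_FOR_A_BLOCK ),
                            bli_rntm_pba( &rntm ) ) );    // size of one pool block

        bli_pba_acquire_m( &rntm, buffer_size,
                           BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
                           &local_mem_buf_A_s );          // check out a block for packing A

        double* D_A_pack = bli_mem_buffer( &local_mem_buf_A_s );
        // ... pack the triangular panel into D_A_pack and run the solve ...

        if ( bli_mem_is_alloc( &local_mem_buf_A_s ) )
            bli_pba_release( &rntm, &local_mem_buf_A_s ); // return the block to the pool
*/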
- bli_membrk_acquire_m(&rntm, + bli_pba_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, &local_mem_buf_A_s); @@ -11006,7 +11006,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB_AVX512 if ((required_packing_A == 1) && bli_mem_is_alloc(&local_mem_buf_A_s)) { - bli_membrk_release(&rntm, &local_mem_buf_A_s); + bli_pba_release(&rntm, &local_mem_buf_A_s); } return BLIS_SUCCESS; } diff --git a/kernels/zen4/3/bli_zero_zmm.c b/kernels/zen4/3/bli_zero_zmm.c new file mode 100644 index 0000000000..67ff9a62de --- /dev/null +++ b/kernels/zen4/3/bli_zero_zmm.c @@ -0,0 +1,62 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" +#include "bli_x86_asm_macros.h" +void bli_zero_zmm() +{ + + BEGIN_ASM() + VXORPD(ZMM(16), ZMM(16), ZMM(16)) + VXORPD(ZMM(17), ZMM(17), ZMM(17)) + VXORPD(ZMM(18), ZMM(18), ZMM(18)) + VXORPD(ZMM(19), ZMM(19), ZMM(19)) + VXORPD(ZMM(20), ZMM(20), ZMM(20)) + VXORPD(ZMM(21), ZMM(21), ZMM(21)) + VXORPD(ZMM(22), ZMM(22), ZMM(22)) + VXORPD(ZMM(23), ZMM(23), ZMM(23)) + VXORPD(ZMM(24), ZMM(24), ZMM(24)) + VXORPD(ZMM(25), ZMM(25), ZMM(25)) + VXORPD(ZMM(26), ZMM(26), ZMM(26)) + VXORPD(ZMM(27), ZMM(27), ZMM(27)) + VXORPD(ZMM(28), ZMM(28), ZMM(28)) + VXORPD(ZMM(29), ZMM(29), ZMM(29)) + VXORPD(ZMM(30), ZMM(30), ZMM(30)) + VXORPD(ZMM(31), ZMM(31), ZMM(31)) + + END_ASM (::: + "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", + "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", + "zmm29", "zmm30", "zmm31", "memory" + ) +} diff --git a/kernels/zen4/3/bli_zgemm_zen4_asm_4x12.c b/kernels/zen4/3/bli_zgemm_zen4_asm_4x12.c new file mode 100644 index 0000000000..fd0181c1d1 --- /dev/null +++ b/kernels/zen4/3/bli_zgemm_zen4_asm_4x12.c @@ -0,0 +1,565 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "bli_x86_asm_macros.h" + +#define A_L1_PREFETCH_DIST 6 +#define B_L1_PREFETCH_DIST 6 +#define TAIL_NITER 7 +// #define PREFETCH_A +#define PREFETCH_B +// #define PREFETCH_A_NEXT +#define PREFETCH_B_NEXT +#define PREFETCH_C // perfetch c in middle loop over 4 iterations of k + + +#ifdef PREFETCH_A + #define PREFETCH_A_L1(n) \ + PREFETCH(0, MEM(RAX, A_L1_PREFETCH_DIST*4*16 + 4*n*16)) +#else + #define PREFETCH_A_L1(n) +#endif + +#ifdef PREFETCH_B + #define PREFETCH_B_L1(n, k) \ + PREFETCH(0, MEM(RBX, B_L1_PREFETCH_DIST*12*16 + (12*n+(4*k))*16)) +#else + #define PREFETCH_B_L1(n, k) +#endif + + +/* + * A Registers: ZMM3, ZMM4, ZMM29, ZMM30 + * B Registers: ZMM0, ZMM1, ZMM2 + * C Registers: ZMM[8-28] + */ + +#define LOOP_ALIGN ALIGN32 + +#define SUBITER(n) \ +\ + PREFETCH_A_L1(n)\ + VBROADCASTSD(ZMM(3), MEM(RAX, (8*n+2)*8)) \ + VFMADD231PD(ZMM(5) , ZMM(0), ZMM(29)) \ + VFMADD231PD(ZMM(6) , ZMM(1), ZMM(29)) \ + VFMADD231PD(ZMM(7) , ZMM(2), ZMM(29)) \ + VBROADCASTSD(ZMM(4), MEM(RAX, (8*n+3)*8)) \ + VFMADD231PD(ZMM(8) , ZMM(0), ZMM(30)) \ + VFMADD231PD(ZMM(9) , ZMM(1), ZMM(30)) \ + VFMADD231PD(ZMM(10), ZMM(2), ZMM(30)) \ + \ + PREFETCH_B_L1(n, 0)\ + VBROADCASTSD(ZMM(29), MEM(RAX, (8*n+4)*8)) \ + VFMADD231PD(ZMM(11), ZMM(0), ZMM(3)) \ + VFMADD231PD(ZMM(12), ZMM(1), ZMM(3)) \ + VFMADD231PD(ZMM(13), ZMM(2), ZMM(3)) \ + VBROADCASTSD(ZMM(30), MEM(RAX, (8*n+5)*8)) \ + VFMADD231PD(ZMM(14), ZMM(0), ZMM(4)) \ + VFMADD231PD(ZMM(15), ZMM(1), ZMM(4)) \ + VFMADD231PD(ZMM(16), ZMM(2), ZMM(4)) \ + \ + PREFETCH_B_L1(n, 1)\ + VBROADCASTSD(ZMM(3), MEM(RAX, (8*n+6)*8)) \ + VFMADD231PD(ZMM(17), ZMM(0), ZMM(29)) \ + VFMADD231PD(ZMM(18), ZMM(1), ZMM(29)) \ + VFMADD231PD(ZMM(19), ZMM(2), ZMM(29)) \ + VBROADCASTSD(ZMM(4), MEM(RAX, (8*n+7)*8)) \ + VFMADD231PD(ZMM(20), ZMM(0), ZMM(30)) \ + VFMADD231PD(ZMM(21), ZMM(1), ZMM(30)) \ + VFMADD231PD(ZMM(22), ZMM(2), ZMM(30)) \ + \ + PREFETCH_B_L1(n, 2)\ + VBROADCASTSD(ZMM(29), MEM(RAX, (8*n+8)*8)) \ + VFMADD231PD(ZMM(23), ZMM(0), ZMM(3)) \ + VFMADD231PD(ZMM(24), ZMM(1), ZMM(3)) \ + VFMADD231PD(ZMM(25), ZMM(2), 
ZMM(3)) \ + VBROADCASTSD(ZMM(30), MEM(RAX, (8*n+9)*8)) \ + VFMADD231PD(ZMM(26), ZMM(0), ZMM(4)) \ + VFMADD231PD(ZMM(27), ZMM(1), ZMM(4)) \ + VFMADD231PD(ZMM(28), ZMM(2), ZMM(4)) \ + \ + VMOVAPD(ZMM(0), MEM(RBX, (12*n+0)*16)) \ + VMOVAPD(ZMM(1), MEM(RBX, (12*n+4)*16)) \ + VMOVAPD(ZMM(2), MEM(RBX, (12*n+8)*16)) + +#define SCALE_REG(a, b, c) \ + VPERMILPD(ZMM(3), a, IMM(0x55)) \ + VMULPD(a, a, b) \ + VMULPD(ZMM(3), ZMM(3), c) \ + VFMADDSUB132PD(a, ZMM(3), ZMM(31)) \ + +#define STORE_C_ROW(R1, R2, R3) \ + VMOVUPD(ZMM(0), MEM(RCX)) \ + SCALE_REG(ZMM(0), ZMM(1), ZMM(2)) \ + VADDPD(ZMM(0), ZMM(0), ZMM(R1)) \ + VMOVUPD(MEM(RCX), ZMM(0)) \ + \ + VMOVUPD(ZMM(0), MEM(RCX, R10, 4)) \ + SCALE_REG(ZMM(0), ZMM(1), ZMM(2)) \ + VADDPD(ZMM(0), ZMM(0), ZMM(R2)) \ + VMOVUPD(MEM(RCX, R10, 4), ZMM(0)) \ + \ + VMOVUPD(ZMM(0), MEM(RCX, R10, 8)) \ + SCALE_REG(ZMM(0), ZMM(1), ZMM(2)) \ + VADDPD(ZMM(0), ZMM(0), ZMM(R3)) \ + VMOVUPD(MEM(RCX, R10, 8), ZMM(0)) \ + +#define LOAD_ROW_GEN() \ + VMOVUPD(XMM(0), MEM(RDX)) \ + VMOVUPD(XMM(27), MEM(RDX, R10, 1)) \ + VMOVUPD(XMM(28), MEM(RDX, R10, 2)) \ + VMOVUPD(XMM(29), MEM(RDX, R11, 1)) \ + VINSERTF64X2(ZMM(0), ZMM(0), XMM(27), IMM(0x1)) \ + VINSERTF64X2(ZMM(0), ZMM(0), XMM(28), IMM(0x2)) \ + VINSERTF64X2(ZMM(0), ZMM(0), XMM(29), IMM(0x3)) \ + +#define STORE_ROW_GEN() \ + VEXTRACTF64X2(XMM(27), ZMM(0), IMM(0x1)) \ + VEXTRACTF64X2(XMM(28), ZMM(0), IMM(0x2)) \ + VEXTRACTF64X2(XMM(29), ZMM(0), IMM(0x3)) \ + VMOVUPD(MEM(RDX) , XMM(0)) \ + VMOVUPD(MEM(RDX, R10, 1), XMM(27)) \ + VMOVUPD(MEM(RDX, R10, 2), XMM(28)) \ + VMOVUPD(MEM(RDX, R11, 1), XMM(29)) \ + +#define STORE_C_COL_GEN(R1, R2, R3) \ + MOV(RDX, RCX) \ + LEA(RCX, MEM(RCX, R12, 1)) \ + LOAD_ROW_GEN() \ + SCALE_REG(ZMM(0), ZMM(1), ZMM(2)) \ + VADDPD(ZMM(0), ZMM(0), ZMM(R1)) \ + STORE_ROW_GEN() \ + LEA(RDX, MEM(RDX, R10, 4)) \ + \ + LOAD_ROW_GEN() \ + SCALE_REG(ZMM(0), ZMM(1), ZMM(2)) \ + VADDPD(ZMM(0), ZMM(0), ZMM(R2)) \ + STORE_ROW_GEN() \ + LEA(RDX, MEM(RDX, R10, 4)) \ + \ + LOAD_ROW_GEN() \ + SCALE_REG(ZMM(0), ZMM(1), ZMM(2)) \ + VADDPD(ZMM(0), ZMM(0), ZMM(R3)) \ + STORE_ROW_GEN() \ + +/**********************************************************/ +/* Kernel : bli_zgemm_zen4_asm_4x12 */ +/* It performs C = C * beta + alpha * A * B */ +/* It is row preferred kernel, A and B are packed */ +/* C could be Row/Col/Gen Stored Matrix */ +/* Registers are allocated as below */ +/* Broadcast A : ZMM(3, 4, 29, 30) */ +/* load B : ZMM(0, 1, 2) */ +/* Accumulation of B(real,imag)*Areal : */ +/* ZMM(5-7 , 11-13, 17-19, 23-25) */ +/* Accumulation of B(real,imag)*Aimag : */ +/* ZMM(8-10, 14-16, 20-22, 26-28) */ +/* Computation of A(real,imag)*B(real,imag): */ +/* ZMM(5-7 , 11-13, 17-19, 23-25) */ +/**********************************************************/ +void bli_zgemm_zen4_asm_4x12( + dim_t k_, + dcomplex* restrict alpha, + dcomplex* restrict a, + dcomplex* restrict b, + dcomplex* restrict beta, + dcomplex* restrict c, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + const int64_t k = k_; + /*rowstride * size of one dcomplex element*/ + const int64_t rs_c = rs_c_*16; + /*colstride * size of one dcomplex element*/ + const int64_t cs_c = cs_c_*16; + + + char beta_mul_type = BLIS_MUL_DEFAULT; + if(beta->imag == 0.0 && beta->real == 0.0 ) + { + beta_mul_type = BLIS_MUL_ZERO; + } + double one = 1; // used for FMADDSUB instruction + double *one_addr = &one; + + BEGIN_ASM() + + VXORPD(XMM(5) , XMM(5) , XMM(5) ) + VXORPD(XMM(6) , XMM(6) , XMM(6) ) + VXORPD(XMM(7) 
, XMM(7) , XMM(7) ) + VXORPD(XMM(8) , XMM(8) , XMM(8) ) + VXORPD(XMM(9) , XMM(9) , XMM(9) ) + VXORPD(XMM(10), XMM(10), XMM(10)) + VXORPD(XMM(11), XMM(11), XMM(11)) + VXORPD(XMM(12), XMM(12), XMM(12)) + VXORPD(XMM(13), XMM(13), XMM(13)) + VXORPD(XMM(14), XMM(14), XMM(14)) + VXORPD(XMM(15), XMM(15), XMM(15)) + VXORPD(XMM(16), XMM(16), XMM(16)) + VXORPD(XMM(17), XMM(17), XMM(17)) + VXORPD(XMM(18), XMM(18), XMM(18)) + VXORPD(XMM(19), XMM(19), XMM(19)) + VXORPD(XMM(20), XMM(20), XMM(20)) + VXORPD(XMM(21), XMM(21), XMM(21)) + VXORPD(XMM(22), XMM(22), XMM(22)) + VXORPD(XMM(23), XMM(23), XMM(23)) + VXORPD(XMM(24), XMM(24), XMM(24)) + VXORPD(XMM(25), XMM(25), XMM(25)) + VXORPD(XMM(26), XMM(26), XMM(26)) + VXORPD(XMM(27), XMM(27), XMM(27)) + VXORPD(XMM(28), XMM(28), XMM(28)) + + MOV(RSI, VAR(k)) //loop index + MOV(RAX, VAR(a)) //load address of a + MOV(RBX, VAR(b)) //load address of b + MOV(RCX, VAR(c)) //load address of c + + #ifdef PREFETCH_C + LEA(R9, MEM(RCX, 63)) // c for prefetch, first cache line + LEA(R8, MEM(R9, 128)) // c for prefetch, second cache line + #endif + + + VMOVAPD(ZMM(0), MEM(RBX, 0*16)) //pre-load b + VMOVAPD(ZMM(1), MEM(RBX, 4*16)) //pre-load b + VMOVAPD(ZMM(2), MEM(RBX, 8*16)) //pre-load b + VBROADCASTSD(ZMM(29), MEM(RAX, 0)) + VBROADCASTSD(ZMM(30), MEM(RAX, 8)) + LEA(RBX, MEM(RBX, 12*16)) //adjust b for pre-load + + MOV(R12, VAR(rs_c)) + MOV(R10, VAR(cs_c)) + + #if defined PREFETCH_A_NEXT || defined PREFETCH_B_NEXT + MOV(RDI, RSI) + IMUL(RDI, IMM(16*4)) // rdi = k * 16*4 + #endif + + #ifdef PREFETCH_A_NEXT + LEA(R14, MEM(RAX, RDI, 1)) // r14(a_next) = A + (k*16*4) + #endif + + #ifdef PREFETCH_B_NEXT + IMUL(RDI, IMM(3)) // rdi = k * 16*12 + LEA(R15, MEM(RBX, RDI, 1)) // r15(b_next) = B + (k*16*12) + #endif + + MOV(RDI, RSI) + AND(RSI, IMM(3)) + SAR(RDI, IMM(2)) + /************************************************************/ + /* Operation: */ + /* SUBITER = (Ar, Ai)*(Br, Bi) = Ar*(Br, Bi) , Ai*(Br, Bi) */ + /* C_PREFETCH loop count: */ + /* LOOP1: k/4 - TAIL_NITER - 4 */ + /* LOOP2: 4 */ + /* LOOP4: TAIL_NITER */ + /* TAIL_LOOP: k%4 */ + /* */ + /* No prefetch loop count: */ + /* LOOP1: k/4 */ + /* TAIL_LOOP: k%4 */ + /************************************************************/ + #ifdef PREFETCH_C + /* prefetch c over 4 iterations of k*/ + SUB(RDI, IMM(4+TAIL_NITER)) + #endif + JLE(K_PREFETCH_C) + + LOOP_ALIGN + LABEL(LOOP1) + #ifdef PREFETCH_A_NEXT + PREFETCH(1, MEM(R14)) + #endif + SUBITER(0) + #ifdef PREFETCH_B_NEXT + PREFETCH(1, MEM(R15)) + #endif + SUBITER(1) + #ifdef PREFETCH_A_NEXT + PREFETCH(2, MEM(R14, 64)) + #endif + SUB(RDI, IMM(1)) + SUBITER(2) + #ifdef PREFETCH_B_NEXT + PREFETCH(2, MEM(R15, 64)) + #endif + SUBITER(3) + + LEA(RAX, MEM(RAX,4*4*16)) + LEA(RBX, MEM(RBX,4*12*16)) + #ifdef PREFETCH_A_NEXT + LEA(R14, MEM(R14,128)) + #endif + #ifdef PREFETCH_B_NEXT + LEA(R15, MEM(R15,64)) + #endif + + JNZ(LOOP1) + + LABEL(K_PREFETCH_C) + +#ifdef PREFETCH_C + ADD(RDI, IMM(4)) + JLE(K_TAIL_NITER) + + LOOP_ALIGN + LABEL(LOOP2) + SUBITER(0) + PREFETCH(0, MEM(R9)) + SUBITER(1) + PREFETCH(0, MEM(R9, 64)) + SUB(RDI, IMM(1)) + PREFETCH(0, MEM(R9,128)) + SUBITER(2) + SUBITER(3) + + LEA(RAX, MEM(RAX,4*4*16)) + LEA(RBX, MEM(RBX,4*12*16)) + LEA(R9, MEM(R9,R12,1)) + JNZ(LOOP2) + + LABEL(K_TAIL_NITER) + + ADD(RDI, IMM(0+TAIL_NITER)) + JLE(TAIL) + + LOOP_ALIGN + LABEL(LOOP4) + + SUBITER(0) + SUBITER(1) + SUB(RDI, IMM(1)) + SUBITER(2) + SUBITER(3) + + LEA(RAX, MEM(RAX,4*4*16)) + LEA(RBX, MEM(RBX,4*12*16)) + + JNZ(LOOP4) + +#endif //PREFETCH_C + + LABEL(TAIL) + + TEST(RSI, RSI) + 
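/*
    For reference, the RDI/RSI bookkeeping above implements the loop counts
    described in the earlier comment block.  A plain-C sketch of just the
    partitioning (assuming PREFETCH_C is defined and k/4 > 4 + TAIL_NITER):

        int64_t unrolled = k / 4;                        // SAR(RDI, IMM(2))
        int64_t tail     = k % 4;                        // AND(RSI, IMM(3)) -> TAIL_LOOP
        int64_t loop1    = unrolled - 4 - TAIL_NITER;    // bulk 4x-unrolled FMA iterations
        int64_t loop2    = 4;                            // iterations that also prefetch C
        int64_t loop4    = TAIL_NITER;                   // trailing unrolled iterations

        // loop1 + loop2 + loop4 == unrolled, and each unrolled iteration
        // issues SUBITER(0..3), i.e. four steps of k.
*/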
JZ(POSTACCUM) + + LOOP_ALIGN + LABEL(TAIL_LOOP) + + SUB(RSI, IMM(1)) + SUBITER(0) + LEA(RAX, MEM(RAX,4*16)) + LEA(RBX, MEM(RBX,12*16)) + + JNZ(TAIL_LOOP) + + LABEL(POSTACCUM) + + VPERMILPD(ZMM8 , ZMM8 , IMM(0x55)) + VPERMILPD(ZMM9 , ZMM9 , IMM(0x55)) + VPERMILPD(ZMM10, ZMM10, IMM(0x55)) + VPERMILPD(ZMM14, ZMM14, IMM(0x55)) + VPERMILPD(ZMM15, ZMM15, IMM(0x55)) + VPERMILPD(ZMM16, ZMM16, IMM(0x55)) + VPERMILPD(ZMM20, ZMM20, IMM(0x55)) + VPERMILPD(ZMM21, ZMM21, IMM(0x55)) + VPERMILPD(ZMM22, ZMM22, IMM(0x55)) + VPERMILPD(ZMM26, ZMM26, IMM(0x55)) + VPERMILPD(ZMM27, ZMM27, IMM(0x55)) + VPERMILPD(ZMM28, ZMM28, IMM(0x55)) + + MOV(R8, VAR(one_addr)) + VBROADCASTSD(ZMM(31), MEM(R8)) + VFMADDSUB132PD(ZMM(5) , ZMM(8) , ZMM(31)) + VFMADDSUB132PD(ZMM(6) , ZMM(9) , ZMM(31)) + VFMADDSUB132PD(ZMM(7) , ZMM(10), ZMM(31)) + + VFMADDSUB132PD(ZMM(11), ZMM(14), ZMM(31)) + VFMADDSUB132PD(ZMM(12), ZMM(15), ZMM(31)) + VFMADDSUB132PD(ZMM(13), ZMM(16), ZMM(31)) + + VFMADDSUB132PD(ZMM(17), ZMM(20), ZMM(31)) + VFMADDSUB132PD(ZMM(18), ZMM(21), ZMM(31)) + VFMADDSUB132PD(ZMM(19), ZMM(22), ZMM(31)) + + VFMADDSUB132PD(ZMM(23), ZMM(26), ZMM(31)) + VFMADDSUB132PD(ZMM(24), ZMM(27), ZMM(31)) + VFMADDSUB132PD(ZMM(25), ZMM(28), ZMM(31)) + + MOV(RAX, VAR(alpha)) + VBROADCASTSD(ZMM(0), MEM(RAX)) + VBROADCASTSD(ZMM(1), MEM(RAX, 8)) + + SCALE_REG(ZMM(5) , ZMM(0), ZMM(1)) + SCALE_REG(ZMM(6) , ZMM(0), ZMM(1)) + SCALE_REG(ZMM(7) , ZMM(0), ZMM(1)) + + SCALE_REG(ZMM(11), ZMM(0), ZMM(1)) + SCALE_REG(ZMM(12), ZMM(0), ZMM(1)) + SCALE_REG(ZMM(13), ZMM(0), ZMM(1)) + + SCALE_REG(ZMM(17), ZMM(0), ZMM(1)) + SCALE_REG(ZMM(18), ZMM(0), ZMM(1)) + SCALE_REG(ZMM(19), ZMM(0), ZMM(1)) + + SCALE_REG(ZMM(23), ZMM(0), ZMM(1)) + SCALE_REG(ZMM(24), ZMM(0), ZMM(1)) + SCALE_REG(ZMM(25), ZMM(0), ZMM(1)) + + MOV(RBX, VAR(beta)) + VBROADCASTSD(ZMM(1), MEM(RBX)) + VBROADCASTSD(ZMM(2), MEM(RBX, 8)) + + + MOV(AL, VAR(beta_mul_type)) + CMP(AL, IMM(0)) + JE(.ZBETAZERO) + + CMP(R10, IMM(16)) //CS == 1 IMPLIES ROW STORED + JNZ(.ZCOLSTORED) + + LABEL(.ZROWSTORED) + STORE_C_ROW(5 , 6 , 7 ) ADD(RCX, R12) + STORE_C_ROW(11, 12, 13) ADD(RCX, R12) + STORE_C_ROW(17, 18, 19) ADD(RCX, R12) + STORE_C_ROW(23, 24, 25) + JMP(.ZDONE) + + LABEL(.ZCOLSTORED) + LEA(R11, MEM(R10, R10, 2)) + STORE_C_COL_GEN(5, 6, 7) + STORE_C_COL_GEN(11, 12, 13) + STORE_C_COL_GEN(17, 18, 19) + STORE_C_COL_GEN(23, 24, 25) + JMP(.ZDONE) + + LABEL(.ZBETAZERO) + CMP(R10, IMM(16)) + JZ(.ZROWSTORBZ) + + LABEL(.ZCOLSTORBZ) + LEA(R11, MEM(R10, R10, 2)) + MOV(RDX, RCX) + ADD(RCX, R12) + VMOVUPD(ZMM(0), ZMM(5)) STORE_ROW_GEN() + LEA(RDX, MEM(RDX, R10, 4)) + VMOVUPD(ZMM(0), ZMM(6)) STORE_ROW_GEN() + LEA(RDX, MEM(RDX, R10, 4)) + VMOVUPD(ZMM(0), ZMM(7)) STORE_ROW_GEN() + + MOV(RDX, RCX) + LEA(RCX, MEM(RCX, R12, 1)) + VMOVUPD(ZMM(0), ZMM(11)) STORE_ROW_GEN() + LEA(RDX, MEM(RDX, R10, 4)) + VMOVUPD(ZMM(0), ZMM(12)) STORE_ROW_GEN() + LEA(RDX, MEM(RDX, R10, 4)) + VMOVUPD(ZMM(0), ZMM(13)) STORE_ROW_GEN() + + MOV(RDX, RCX) + LEA(RCX, MEM(RCX, R12, 1)) + VMOVUPD(ZMM(0), ZMM(17)) STORE_ROW_GEN() + LEA(RDX, MEM(RDX, R10, 4)) + VMOVUPD(ZMM(0), ZMM(18)) STORE_ROW_GEN() + LEA(RDX, MEM(RDX, R10, 4)) + VMOVUPD(ZMM(0), ZMM(19)) STORE_ROW_GEN() + + MOV(RDX, RCX) + VMOVUPD(ZMM(0), ZMM(23)) STORE_ROW_GEN() + LEA(RDX, MEM(RDX, R10, 4)) + VMOVUPD(ZMM(0), ZMM(24)) STORE_ROW_GEN() + LEA(RDX, MEM(RDX, R10, 4)) + VMOVUPD(ZMM(0), ZMM(25)) STORE_ROW_GEN() + + JMP(.ZDONE) + + + LABEL(.ZROWSTORBZ) + VMOVUPD(MEM(RCX ), ZMM(5)) + VMOVUPD(MEM(RCX, R10, 4), ZMM(6)) + VMOVUPD(MEM(RCX, R10, 8), ZMM(7)) + LEA(RCX, MEM(RCX, R12, 1)) + + VMOVUPD(MEM(RCX ), 
ZMM(11)) + VMOVUPD(MEM(RCX, R10, 4), ZMM(12)) + VMOVUPD(MEM(RCX, R10, 8), ZMM(13)) + LEA(RCX, MEM(RCX, R12, 1)) + + VMOVUPD(MEM(RCX ), ZMM(17)) + VMOVUPD(MEM(RCX, R10, 4), ZMM(18)) + VMOVUPD(MEM(RCX, R10, 8), ZMM(19)) + LEA(RCX, MEM(RCX, R12, 1)) + + VMOVUPD(MEM(RCX ), ZMM(23)) + VMOVUPD(MEM(RCX, R10, 4), ZMM(24)) + VMOVUPD(MEM(RCX, R10, 8), ZMM(25)) + + LABEL(.ZDONE) + + VZEROUPPER() + + END_ASM + ( + : // output operands (none) + : // input operands + [beta_mul_type] "m" (beta_mul_type), + [k] "m" (k), + [a] "m" (a), + [b] "m" (b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [one_addr] "m" (one_addr) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", + "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", + "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", + "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", + "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", + "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", + "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", + "xmm14", "xmm15", "xmm16", "xmm17", "xmm18", "xmm19", "xmm20", + "xmm21", "xmm22", "xmm23", "xmm24", "xmm25", "xmm26", + "xmm27", "xmm28", "xmm29", "xmm30", "xmm31", + "memory" + ) +} diff --git a/kernels/zen4/3/bli_zgemmtrsm_l_4x12.c b/kernels/zen4/3/bli_zgemmtrsm_l_4x12.c new file mode 100644 index 0000000000..5fe475421e --- /dev/null +++ b/kernels/zen4/3/bli_zgemmtrsm_l_4x12.c @@ -0,0 +1,705 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" +#include "bli_x86_asm_macros.h" + +#define A_L1_PREFETCH_DIST 6 +#define B_L1_PREFETCH_DIST 6 +#define TAIL_NITER 7 +// #define PREFETCH_A +#define PREFETCH_B +// #define PREFETCH_A_NEXT +#define PREFETCH_B_NEXT +#define PREFETCH_C // perfetch c in middle loop over 4 iterations of k + + +#ifdef PREFETCH_A + #define PREFETCH_A_L1(n) \ + PREFETCH(0, MEM(RAX, A_L1_PREFETCH_DIST*4*16 + 4*n*16)) +#else + #define PREFETCH_A_L1(n) +#endif + +#ifdef PREFETCH_B + #define PREFETCH_B_L1(n, k) \ + PREFETCH(0, MEM(RBX, B_L1_PREFETCH_DIST*12*16 + (12*n+(4*k))*16)) +#else + #define PREFETCH_B_L1(n, k) +#endif + + +/* + * A Registers: ZMM3, ZMM4, ZMM29, ZMM30 + * B Registers: ZMM0, ZMM1, ZMM2 + * C Registers: ZMM[8-28] + */ + +#define LOOP_ALIGN ALIGN32 + +#define SUBITER(n) \ +\ + PREFETCH_A_L1(n)\ + VBROADCASTSD(ZMM(3), MEM(RAX, (8*n+2)*8)) \ + VFMADD231PD(ZMM(5) , ZMM(0), ZMM(29)) \ + VFMADD231PD(ZMM(6) , ZMM(1), ZMM(29)) \ + VFMADD231PD(ZMM(7) , ZMM(2), ZMM(29)) \ + VBROADCASTSD(ZMM(4), MEM(RAX, (8*n+3)*8)) \ + VFMADD231PD(ZMM(8) , ZMM(0), ZMM(30)) \ + VFMADD231PD(ZMM(9) , ZMM(1), ZMM(30)) \ + VFMADD231PD(ZMM(10), ZMM(2), ZMM(30)) \ + \ + PREFETCH_B_L1(n, 0)\ + VBROADCASTSD(ZMM(29), MEM(RAX, (8*n+4)*8)) \ + VFMADD231PD(ZMM(11), ZMM(0), ZMM(3)) \ + VFMADD231PD(ZMM(12), ZMM(1), ZMM(3)) \ + VFMADD231PD(ZMM(13), ZMM(2), ZMM(3)) \ + VBROADCASTSD(ZMM(30), MEM(RAX, (8*n+5)*8)) \ + VFMADD231PD(ZMM(14), ZMM(0), ZMM(4)) \ + VFMADD231PD(ZMM(15), ZMM(1), ZMM(4)) \ + VFMADD231PD(ZMM(16), ZMM(2), ZMM(4)) \ + \ + PREFETCH_B_L1(n, 1)\ + VBROADCASTSD(ZMM(3), MEM(RAX, (8*n+6)*8)) \ + VFMADD231PD(ZMM(17), ZMM(0), ZMM(29)) \ + VFMADD231PD(ZMM(18), ZMM(1), ZMM(29)) \ + VFMADD231PD(ZMM(19), ZMM(2), ZMM(29)) \ + VBROADCASTSD(ZMM(4), MEM(RAX, (8*n+7)*8)) \ + VFMADD231PD(ZMM(20), ZMM(0), ZMM(30)) \ + VFMADD231PD(ZMM(21), ZMM(1), ZMM(30)) \ + VFMADD231PD(ZMM(22), ZMM(2), ZMM(30)) \ + \ + PREFETCH_B_L1(n, 2)\ + VBROADCASTSD(ZMM(29), MEM(RAX, (8*n+8)*8)) \ + VFMADD231PD(ZMM(23), ZMM(0), ZMM(3)) \ + VFMADD231PD(ZMM(24), ZMM(1), ZMM(3)) \ + VFMADD231PD(ZMM(25), ZMM(2), ZMM(3)) \ + VBROADCASTSD(ZMM(30), MEM(RAX, (8*n+9)*8)) \ + VFMADD231PD(ZMM(26), ZMM(0), ZMM(4)) \ + VFMADD231PD(ZMM(27), ZMM(1), ZMM(4)) \ + VFMADD231PD(ZMM(28), ZMM(2), ZMM(4)) \ + \ + VMOVAPD(ZMM(0), MEM(RBX, (12*n+0)*16)) \ + VMOVAPD(ZMM(1), MEM(RBX, (12*n+4)*16)) \ + VMOVAPD(ZMM(2), MEM(RBX, (12*n+8)*16)) + +#define SCALE_REG(a, b, c, out) \ + VPERMILPD(ZMM(3), a, IMM(0x55)) \ + VMULPD(out, a, b) \ + VMULPD(ZMM(3), ZMM(3), c) \ + VFMADDSUB132PD(out, ZMM(3), ZMM(31)) \ + +#define DIVIDE_COMPLEX(R1, c, d, csq_dsq) \ + VPERMILPD(ZMM(3), R1, IMM(0x55)) \ + VMULPD(R1, R1, c) \ + VMULPD(ZMM(3), ZMM(3), d) \ + VMULPD(ZMM(3), ZMM(3), ZMM(2)) \ + VFMADDSUB132PD(R1, ZMM(3), ZMM(31)) \ + VDIVPD(R1, R1, csq_dsq) \ + +#define STORE_REG_GEN(reg) \ + VEXTRACTF64X2(XMM(27), ZMM(reg), IMM(0x1)) \ + VEXTRACTF64X2(XMM(28), ZMM(reg), IMM(0x2)) \ + VEXTRACTF64X2(XMM(29), ZMM(reg), IMM(0x3)) \ + VMOVUPD(MEM(RDX) , XMM(reg)) \ + VMOVUPD(MEM(RDX, R10, 1), XMM(27)) \ + VMOVUPD(MEM(RDX, R10, 2), XMM(28)) \ + VMOVUPD(MEM(RDX, R11, 1), XMM(29)) \ + + +/**********************************************************/ +/* Kernel : bli_zgemmtrsm_l_zen4_asm_4x12 */ +/* It performs C = C * beta + alpha * A * B */ +/* It is row preferred kernel, A and B are packed */ +/* C could be Row/Col/Gen Stored Matrix */ +/* Registers are allocated as below */ +/* Broadcast A : ZMM(3, 4, 29, 30) */ +/* load B : ZMM(0, 1, 2) */ +/* Accumulation of B(real,imag)*Areal : */ +/* ZMM(5-7 , 
11-13, 17-19, 23-25) */ +/* Accumulation of B(real,imag)*Aimag : */ +/* ZMM(8-10, 14-16, 20-22, 26-28) */ +/* Computation of A(real,imag)*B(real,imag): */ +/* ZMM(5-7 , 11-13, 17-19, 23-25) */ +/**********************************************************/ +void bli_zgemmtrsm_l_zen4_asm_4x12( + dim_t k_, + dcomplex* restrict alpha, + dcomplex* restrict a10, + dcomplex* restrict a11, + dcomplex* restrict b01, + dcomplex* restrict b11, + dcomplex* restrict c11, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + const int64_t k = k_; + /*rowstride * size of one dcomplex element*/ + const int64_t rs_c = rs_c_*16; + /*colstride * size of one dcomplex element*/ + const int64_t cs_c = cs_c_*16; + double one = 1; // used for FMADDSUB instruction + double neg_one = -1; // used for complex division + double *one_addr = &one; + double *neg_one_addr = &neg_one; + + BEGIN_ASM() + + VXORPD(XMM(5) , XMM(5) , XMM(5) ) + VXORPD(XMM(6) , XMM(6) , XMM(6) ) + VXORPD(XMM(7) , XMM(7) , XMM(7) ) + VXORPD(XMM(8) , XMM(8) , XMM(8) ) + VXORPD(XMM(9) , XMM(9) , XMM(9) ) + VXORPD(XMM(10), XMM(10), XMM(10)) + VXORPD(XMM(11), XMM(11), XMM(11)) + VXORPD(XMM(12), XMM(12), XMM(12)) + VXORPD(XMM(13), XMM(13), XMM(13)) + VXORPD(XMM(14), XMM(14), XMM(14)) + VXORPD(XMM(15), XMM(15), XMM(15)) + VXORPD(XMM(16), XMM(16), XMM(16)) + VXORPD(XMM(17), XMM(17), XMM(17)) + VXORPD(XMM(18), XMM(18), XMM(18)) + VXORPD(XMM(19), XMM(19), XMM(19)) + VXORPD(XMM(20), XMM(20), XMM(20)) + VXORPD(XMM(21), XMM(21), XMM(21)) + VXORPD(XMM(22), XMM(22), XMM(22)) + VXORPD(XMM(23), XMM(23), XMM(23)) + VXORPD(XMM(24), XMM(24), XMM(24)) + VXORPD(XMM(25), XMM(25), XMM(25)) + VXORPD(XMM(26), XMM(26), XMM(26)) + VXORPD(XMM(27), XMM(27), XMM(27)) + VXORPD(XMM(28), XMM(28), XMM(28)) + + MOV(RSI, VAR(k)) //loop index + MOV(RAX, VAR(a10)) //load address of a + MOV(RBX, VAR(b01)) //load address of b + MOV(RCX, VAR(b11)) //load address of c + MOV(R9, VAR(c11)) //load address of c + MOV(R11, VAR(neg_one_addr)) + + #ifdef PREFETCH_C + LEA(R9, MEM(R9, 63)) // c for prefetch, first cache line + LEA(R8, MEM(R9, 128)) // c for prefetch, second cache line + #endif + + + VMOVAPD(ZMM(0), MEM(RBX, 0*16)) //pre-load b + VMOVAPD(ZMM(1), MEM(RBX, 4*16)) //pre-load b + VMOVAPD(ZMM(2), MEM(RBX, 8*16)) //pre-load b + VBROADCASTSD(ZMM(29), MEM(RAX, 0)) + VBROADCASTSD(ZMM(30), MEM(RAX, 8)) + LEA(RBX, MEM(RBX, 12*16)) //adjust b for pre-load + + MOV(R12, VAR(rs_c)) + MOV(R10, VAR(cs_c)) + + #if defined PREFETCH_A_NEXT || defined PREFETCH_B_NEXT + MOV(RDI, RSI) + IMUL(RDI, IMM(16*4)) // rdi = k * 16*4 + #endif + + #ifdef PREFETCH_A_NEXT + LEA(R14, MEM(RAX, RDI, 1)) // r14(a_next) = A + (k*16*4) + #endif + + #ifdef PREFETCH_B_NEXT + IMUL(RDI, IMM(3)) // rdi = k * 16*12 + LEA(R15, MEM(RBX, RDI, 1)) // r15(b_next) = B + (k*16*12) + #endif + + MOV(RDI, RSI) + AND(RSI, IMM(3)) + SAR(RDI, IMM(2)) + /************************************************************/ + /* Operation: */ + /* SUBITER = (Ar, Ai)*(Br, Bi) = Ar*(Br, Bi) , Ai*(Br, Bi) */ + /* C_PREFETCH loop count: */ + /* LOOP1: k/4 - TAIL_NITER - 4 */ + /* LOOP2: 4 */ + /* LOOP4: TAIL_NITER */ + /* TAIL_LOOP: k%4 */ + /* */ + /* No prefetch loop count: */ + /* LOOP1: k/4 */ + /* TAIL_LOOP: k%4 */ + /************************************************************/ + #ifdef PREFETCH_C + /* prefetch c over 4 iterations of k*/ + SUB(RDI, IMM(4+TAIL_NITER)) + #endif + JLE(K_PREFETCH_C) + + LOOP_ALIGN + LABEL(LOOP1) + #ifdef PREFETCH_A_NEXT + PREFETCH(1, 
MEM(R14)) + #endif + SUBITER(0) + #ifdef PREFETCH_B_NEXT + PREFETCH(1, MEM(R15)) + #endif + SUBITER(1) + #ifdef PREFETCH_A_NEXT + PREFETCH(2, MEM(R14, 64)) + #endif + SUB(RDI, IMM(1)) + SUBITER(2) + #ifdef PREFETCH_B_NEXT + PREFETCH(2, MEM(R15, 64)) + #endif + SUBITER(3) + + LEA(RAX, MEM(RAX,4*4*16)) + LEA(RBX, MEM(RBX,4*12*16)) + #ifdef PREFETCH_A_NEXT + LEA(R14, MEM(R14,128)) + #endif + #ifdef PREFETCH_B_NEXT + LEA(R15, MEM(R15,64)) + #endif + + JNZ(LOOP1) + + LABEL(K_PREFETCH_C) + +#ifdef PREFETCH_C + ADD(RDI, IMM(4)) + JLE(K_TAIL_NITER) + + LOOP_ALIGN + LABEL(LOOP2) + SUBITER(0) + PREFETCH(0, MEM(R9)) + SUBITER(1) + PREFETCH(0, MEM(R9, 64)) + SUB(RDI, IMM(1)) + PREFETCH(0, MEM(R9,128)) + SUBITER(2) + SUBITER(3) + + LEA(RAX, MEM(RAX,4*4*16)) + LEA(RBX, MEM(RBX,4*12*16)) + LEA(R9, MEM(R9,R12,1)) + JNZ(LOOP2) + + LABEL(K_TAIL_NITER) + + ADD(RDI, IMM(0+TAIL_NITER)) + JLE(TAIL) + + LOOP_ALIGN + LABEL(LOOP4) + + SUBITER(0) + SUBITER(1) + SUB(RDI, IMM(1)) + SUBITER(2) + SUBITER(3) + + LEA(RAX, MEM(RAX,4*4*16)) + LEA(RBX, MEM(RBX,4*12*16)) + + JNZ(LOOP4) + +#endif //PREFETCH_C + + LABEL(TAIL) + + TEST(RSI, RSI) + JZ(POSTACCUM) + + LOOP_ALIGN + LABEL(TAIL_LOOP) + + SUB(RSI, IMM(1)) + SUBITER(0) + LEA(RAX, MEM(RAX,4*16)) + LEA(RBX, MEM(RBX,12*16)) + + JNZ(TAIL_LOOP) + + LABEL(POSTACCUM) + + /******************************************************/ + /* Permute imag component register. Shuffle even */ + /* and odd components */ + /* SRC: ZMM8 =(Ai0*Br0, Ai0*Bi0, Ai0*Br1, Ai0*Bi1, ..)*/ + /* DST: ZMM8 =(Ai0*Bi0, Ai0*Br0, Ai0*Bi1, Ai0*Br1, ..)*/ + /******************************************************/ + VPERMILPD(ZMM8 , ZMM8 , IMM(0x55)) + VPERMILPD(ZMM9 , ZMM9 , IMM(0x55)) + VPERMILPD(ZMM10, ZMM10, IMM(0x55)) + VPERMILPD(ZMM14, ZMM14, IMM(0x55)) + VPERMILPD(ZMM15, ZMM15, IMM(0x55)) + VPERMILPD(ZMM16, ZMM16, IMM(0x55)) + VPERMILPD(ZMM20, ZMM20, IMM(0x55)) + VPERMILPD(ZMM21, ZMM21, IMM(0x55)) + VPERMILPD(ZMM22, ZMM22, IMM(0x55)) + VPERMILPD(ZMM26, ZMM26, IMM(0x55)) + VPERMILPD(ZMM27, ZMM27, IMM(0x55)) + VPERMILPD(ZMM28, ZMM28, IMM(0x55)) + + /*******************************************************/ + /* SRC: ZMM5 = (Ar0*Br0, Ar0*Bi0, Ar0*Br1, Ar0*Bi1, ..)*/ + /* SRC: ZMM8 = (Ai0*Bi0, Ai0*Br0, Ai0*Bi1, Ai0*Br1, ..)*/ + /* DST: ZMM8 =(Ar0*Br0-Ai0*Bi0, Ai0*Br0+Ar0*Bi0, */ + /* Ar0*Br1-Ai0*Bi1, Ai0*Br1+Ar0*Bi1, ..) 
*/ + /*******************************************************/ + MOV(R8, VAR(one_addr)) + VBROADCASTSD(ZMM(31), MEM(R8)) + VFMADDSUB132PD(ZMM(5) , ZMM(8) , ZMM(31)) + VFMADDSUB132PD(ZMM(6) , ZMM(9) , ZMM(31)) + VFMADDSUB132PD(ZMM(7) , ZMM(10), ZMM(31)) + + VFMADDSUB132PD(ZMM(11), ZMM(14), ZMM(31)) + VFMADDSUB132PD(ZMM(12), ZMM(15), ZMM(31)) + VFMADDSUB132PD(ZMM(13), ZMM(16), ZMM(31)) + + VFMADDSUB132PD(ZMM(17), ZMM(20), ZMM(31)) + VFMADDSUB132PD(ZMM(18), ZMM(21), ZMM(31)) + VFMADDSUB132PD(ZMM(19), ZMM(22), ZMM(31)) + + VFMADDSUB132PD(ZMM(23), ZMM(26), ZMM(31)) + VFMADDSUB132PD(ZMM(24), ZMM(27), ZMM(31)) + VFMADDSUB132PD(ZMM(25), ZMM(28), ZMM(31)) + + MOV(RAX, VAR(alpha)) + VBROADCASTSD(ZMM(0), MEM(RAX)) + VBROADCASTSD(ZMM(1), MEM(RAX, 8)) + MOV(RDX, RCX) + MOV(RDI, IMM(12*16)) + + VMOVUPD(ZMM(14), MEM(RDX, 0*16)) + VMOVUPD(ZMM(15), MEM(RDX, 4*16)) + VMOVUPD(ZMM(16), MEM(RDX, 8*16)) + ADD(RDX, RDI) + + /*****************************/ + /* gemm_output -= C * alpha */ + /*****************************/ + SCALE_REG(ZMM(14) , ZMM(0), ZMM(1), ZMM(14)) + VSUBPD(ZMM(5), ZMM(14), ZMM(5)) + VMOVUPD(ZMM(14), MEM(RDX, 0*16)) + + SCALE_REG(ZMM(15) , ZMM(0), ZMM(1), ZMM(15)) + VSUBPD(ZMM(6), ZMM(15), ZMM(6)) + VMOVUPD(ZMM(15), MEM(RDX, 4*16)) + + SCALE_REG(ZMM(16) , ZMM(0), ZMM(1), ZMM(16)) + VSUBPD(ZMM(7), ZMM(16), ZMM(7)) + VMOVUPD(ZMM(16), MEM(RDX, 8*16)) + ADD(RDX, RDI) + + + SCALE_REG(ZMM(14) , ZMM(0), ZMM(1), ZMM(14)) + VSUBPD(ZMM(11), ZMM(14), ZMM(11)) + VMOVUPD(ZMM(14), MEM(RDX, 0*16)) + + SCALE_REG(ZMM(15) , ZMM(0), ZMM(1), ZMM(15)) + VSUBPD(ZMM(12), ZMM(15), ZMM(12)) + VMOVUPD(ZMM(15), MEM(RDX, 4*16)) + + SCALE_REG(ZMM(16) , ZMM(0), ZMM(1), ZMM(16)) + VSUBPD(ZMM(13), ZMM(16), ZMM(13)) + VMOVUPD(ZMM(16), MEM(RDX, 8*16)) + ADD(RDX, RDI) + + + SCALE_REG(ZMM(14) , ZMM(0), ZMM(1), ZMM(14)) + VSUBPD(ZMM(17), ZMM(14), ZMM(17)) + VMOVUPD(ZMM(14), MEM(RDX, 0*16)) + + SCALE_REG(ZMM(15) , ZMM(0), ZMM(1), ZMM(15)) + VSUBPD(ZMM(18), ZMM(15), ZMM(18)) + VMOVUPD(ZMM(15), MEM(RDX, 4*16)) + + SCALE_REG(ZMM(16) , ZMM(0), ZMM(1), ZMM(16)) + VSUBPD(ZMM(19), ZMM(16), ZMM(19)) + VMOVUPD(ZMM(16), MEM(RDX, 8*16)) + + + SCALE_REG(ZMM(14) , ZMM(0), ZMM(1), ZMM(14)) + VSUBPD(ZMM(23), ZMM(14), ZMM(23)) + VMOVUPD(ZMM(14), MEM(RDX, 0*16)) + + SCALE_REG(ZMM(15) , ZMM(0), ZMM(1), ZMM(15)) + VSUBPD(ZMM(24), ZMM(15), ZMM(24)) + VMOVUPD(ZMM(15), MEM(RDX, 4*16)) + + SCALE_REG(ZMM(16) , ZMM(0), ZMM(1), ZMM(16)) + VSUBPD(ZMM(25), ZMM(16), ZMM(25)) + VMOVUPD(ZMM(16), MEM(RDX, 8*16)) + + + //REGION - TRSM + + MOV(RAX, VAR(a11)) + //iteration 0 ----------------------------------- + VBROADCASTSD(ZMM(0), MEM(RAX, (0+0*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (0+0*4)*16+8)) + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + /*****************/ + /* C = C * A11 */ + /*****************/ + SCALE_REG(ZMM(5), ZMM(0), ZMM(1), ZMM(5)) + SCALE_REG(ZMM(6), ZMM(0), ZMM(1), ZMM(6)) + SCALE_REG(ZMM(7), ZMM(0), ZMM(1), ZMM(7)) + #else + /**************************************************************/ + /* C = C / A11 */ + /* */ + /* Let C / A11 = (a + ib) / (c + id) = */ + /* ((ac + bd) / (c^2 + d^2)) + i ((bc - ad) / (c^2+d^2)) */ + /**************************************************************/ + VBROADCASTSD(ZMM(2), MEM(R11)) // -1 + VMULPD(ZMM(8), ZMM(0), ZMM(0)) // c*c + VFMADD231PD(ZMM(8), ZMM(1), ZMM(1)) // c*c + d*d + + DIVIDE_COMPLEX(ZMM(5), ZMM(0), ZMM(1), ZMM(8)) + DIVIDE_COMPLEX(ZMM(6), ZMM(0), ZMM(1), ZMM(8)) + DIVIDE_COMPLEX(ZMM(7), ZMM(0), ZMM(1), ZMM(8)) + #endif + VMOVUPD(MEM(RCX, 0*16), ZMM(5)) + VMOVUPD(MEM(RCX, 4*16), ZMM(6)) + 
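/*
    Scalar model of the DIVIDE_COMPLEX step used just above when
    BLIS_ENABLE_TRSM_PREINVERSION is not defined.  For one b11 element
    x = a + i*b and the diagonal entry a11 = c + i*d, the comment's formula

        x / a11 = ((a*c + b*d) + i*(b*c - a*d)) / (c*c + d*d)

    reads in (sketch) C as:

        double csq_dsq = c*c + d*d;             // VMULPD + VFMADD231PD into ZMM(8)
        double re = ( a*c + b*d ) / csq_dsq;
        double im = ( b*c - a*d ) / csq_dsq;    // the -1.0 broadcast in ZMM(2) supplies the sign flip
*/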
VMOVUPD(MEM(RCX, 8*16), ZMM(7)) + ADD(RCX, RDI) + + //iteration 1 ----------------------------------- + VBROADCASTSD(ZMM(0), MEM(RAX, (1+0*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (1+0*4)*16+8)) + SCALE_REG(ZMM(5), ZMM(0), ZMM(1), ZMM(14)) + SCALE_REG(ZMM(6), ZMM(0), ZMM(1), ZMM(15)) + SCALE_REG(ZMM(7), ZMM(0), ZMM(1), ZMM(16)) + + VSUBPD(ZMM(11), ZMM(11), ZMM(14)) + VSUBPD(ZMM(12), ZMM(12), ZMM(15)) + VSUBPD(ZMM(13), ZMM(13), ZMM(16)) + + VBROADCASTSD(ZMM(0), MEM(RAX, (1+1*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (1+1*4)*16+8)) + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + SCALE_REG(ZMM(11), ZMM(0), ZMM(1), ZMM(11)) + SCALE_REG(ZMM(12), ZMM(0), ZMM(1), ZMM(12)) + SCALE_REG(ZMM(13), ZMM(0), ZMM(1), ZMM(13)) + #else + VBROADCASTSD(ZMM(2), MEM(R11)) + VMULPD(ZMM(8), ZMM(0), ZMM(0)) + VFMADD231PD(ZMM(8), ZMM(1), ZMM(1)) + + DIVIDE_COMPLEX(ZMM(11), ZMM(0), ZMM(1), ZMM(8)) + DIVIDE_COMPLEX(ZMM(12), ZMM(0), ZMM(1), ZMM(8)) + DIVIDE_COMPLEX(ZMM(13), ZMM(0), ZMM(1), ZMM(8)) + #endif + VMOVUPD(MEM(RCX, 0*16), ZMM(11)) + VMOVUPD(MEM(RCX, 4*16), ZMM(12)) + VMOVUPD(MEM(RCX, 8*16), ZMM(13)) + ADD(RCX, RDI) + + //iteration 2 ----------------------------------- + VBROADCASTSD(ZMM(0), MEM(RAX, (2+0*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (2+0*4)*16+8)) + SCALE_REG(ZMM(5), ZMM(0), ZMM(1), ZMM(14)) + SCALE_REG(ZMM(6), ZMM(0), ZMM(1), ZMM(15)) + SCALE_REG(ZMM(7), ZMM(0), ZMM(1), ZMM(16)) + + VBROADCASTSD(ZMM(0), MEM(RAX, (2+1*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (2+1*4)*16+8)) + SCALE_REG(ZMM(11), ZMM(0), ZMM(1), ZMM(20)) + SCALE_REG(ZMM(12), ZMM(0), ZMM(1), ZMM(21)) + SCALE_REG(ZMM(13), ZMM(0), ZMM(1), ZMM(22)) + VADDPD(ZMM(14), ZMM(14), ZMM(20)) + VADDPD(ZMM(15), ZMM(15), ZMM(21)) + VADDPD(ZMM(16), ZMM(16), ZMM(22)) + + VSUBPD(ZMM(17), ZMM(17), ZMM(14)) + VSUBPD(ZMM(18), ZMM(18), ZMM(15)) + VSUBPD(ZMM(19), ZMM(19), ZMM(16)) + + VBROADCASTSD(ZMM(0), MEM(RAX, (2+2*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (2+2*4)*16+8)) + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + SCALE_REG(ZMM(17), ZMM(0), ZMM(1), ZMM(17)) + SCALE_REG(ZMM(18), ZMM(0), ZMM(1), ZMM(18)) + SCALE_REG(ZMM(19), ZMM(0), ZMM(1), ZMM(19)) + #else + VBROADCASTSD(ZMM(2), MEM(R11)) + VMULPD(ZMM(8), ZMM(0), ZMM(0)) + VFMADD231PD(ZMM(8), ZMM(1), ZMM(1)) + + DIVIDE_COMPLEX(ZMM(17), ZMM(0), ZMM(1), ZMM(8)) + DIVIDE_COMPLEX(ZMM(18), ZMM(0), ZMM(1), ZMM(8)) + DIVIDE_COMPLEX(ZMM(19), ZMM(0), ZMM(1), ZMM(8)) + #endif + VMOVUPD(MEM(RCX, 0*16), ZMM(17)) + VMOVUPD(MEM(RCX, 4*16), ZMM(18)) + VMOVUPD(MEM(RCX, 8*16), ZMM(19)) + ADD(RCX, RDI) + + //iteration 3 ----------------------------------- + VBROADCASTSD(ZMM(0), MEM(RAX, (3+0*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (3+0*4)*16+8)) + SCALE_REG(ZMM(5), ZMM(0), ZMM(1), ZMM(14)) + SCALE_REG(ZMM(6), ZMM(0), ZMM(1), ZMM(15)) + SCALE_REG(ZMM(7), ZMM(0), ZMM(1), ZMM(16)) + + VBROADCASTSD(ZMM(0), MEM(RAX, (3+1*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (3+1*4)*16+8)) + SCALE_REG(ZMM(11), ZMM(0), ZMM(1), ZMM(20)) + SCALE_REG(ZMM(12), ZMM(0), ZMM(1), ZMM(21)) + SCALE_REG(ZMM(13), ZMM(0), ZMM(1), ZMM(22)) + VADDPD(ZMM(14), ZMM(14), ZMM(20)) + VADDPD(ZMM(15), ZMM(15), ZMM(21)) + VADDPD(ZMM(16), ZMM(16), ZMM(22)) + + VBROADCASTSD(ZMM(0), MEM(RAX, (3+2*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (3+2*4)*16+8)) + SCALE_REG(ZMM(17), ZMM(0), ZMM(1), ZMM(20)) + SCALE_REG(ZMM(18), ZMM(0), ZMM(1), ZMM(21)) + SCALE_REG(ZMM(19), ZMM(0), ZMM(1), ZMM(22)) + VADDPD(ZMM(14), ZMM(14), ZMM(20)) + VADDPD(ZMM(15), ZMM(15), ZMM(21)) + VADDPD(ZMM(16), ZMM(16), ZMM(22)) + + VSUBPD(ZMM(23), ZMM(23), ZMM(14)) + VSUBPD(ZMM(24), ZMM(24), ZMM(15)) + 
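/*
    SCALE_REG(x, br, bi, out), used throughout this TRSM region, is a full
    dcomplex multiply built from VPERMILPD + two VMULPD + VFMADDSUB132PD.
    With x = xr + i*xi and the broadcast scalar s = br + i*bi (real part in
    `br`, imaginary part in `bi`), the intended result is

        out = x * s = (xr*br - xi*bi) + i*(xi*br + xr*bi)

    or, as scalar C:

        out_re = xr*br - xi*bi;
        out_im = xi*br + xr*bi;

    The FMADDSUB against the broadcast 1.0 in ZMM(31) performs the subtract
    on the even (real) lanes and the add on the odd (imaginary) lanes.
*/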
VSUBPD(ZMM(25), ZMM(25), ZMM(16)) + + VBROADCASTSD(ZMM(0), MEM(RAX, (3+3*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (3+3*4)*16+8)) + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + SCALE_REG(ZMM(23), ZMM(0), ZMM(1), ZMM(23)) + SCALE_REG(ZMM(24), ZMM(0), ZMM(1), ZMM(24)) + SCALE_REG(ZMM(25), ZMM(0), ZMM(1), ZMM(25)) + #else + VBROADCASTSD(ZMM(2), MEM(R11)) + VMULPD(ZMM(8), ZMM(0), ZMM(0)) + VFMADD231PD(ZMM(8), ZMM(1), ZMM(1)) + + DIVIDE_COMPLEX(ZMM(23), ZMM(0), ZMM(1), ZMM(8)) + DIVIDE_COMPLEX(ZMM(24), ZMM(0), ZMM(1), ZMM(8)) + DIVIDE_COMPLEX(ZMM(25), ZMM(0), ZMM(1), ZMM(8)) + #endif + VMOVUPD(MEM(RCX, 0*16), ZMM(23)) + VMOVUPD(MEM(RCX, 4*16), ZMM(24)) + VMOVUPD(MEM(RCX, 8*16), ZMM(25)) + +// ENDREGION - TRSM + + MOV(RCX, VAR(c11)) + CMP(R10, IMM(16)) //CS == 1 IMPLIES ROW STORED + JNZ(.ZCOLSTORED) + + LABEL(.ZROWSTORED) + VMOVUPD(MEM(RCX ), ZMM(5)) + VMOVUPD(MEM(RCX, R10, 4), ZMM(6)) + VMOVUPD(MEM(RCX, R10, 8), ZMM(7)) + ADD(RCX, R12) + + VMOVUPD(MEM(RCX ), ZMM(11)) + VMOVUPD(MEM(RCX, R10, 4), ZMM(12)) + VMOVUPD(MEM(RCX, R10, 8), ZMM(13)) + ADD(RCX, R12) + + VMOVUPD(MEM(RCX ), ZMM(17)) + VMOVUPD(MEM(RCX, R10, 4), ZMM(18)) + VMOVUPD(MEM(RCX, R10, 8), ZMM(19)) + ADD(RCX, R12) + + VMOVUPD(MEM(RCX ), ZMM(23)) + VMOVUPD(MEM(RCX, R10, 4), ZMM(24)) + VMOVUPD(MEM(RCX, R10, 8), ZMM(25)) + + JMP(.ZDONE) + + LABEL(.ZCOLSTORED) + LEA(R11, MEM(R10, R10, 2)) + MOV(RDX, RCX) + ADD(RCX, R12) + STORE_REG_GEN(5) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(6) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(7) + + MOV(RDX, RCX) + ADD(RCX, R12) + STORE_REG_GEN(11) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(12) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(13) + + MOV(RDX, RCX) + ADD(RCX, R12) + STORE_REG_GEN(17) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(18) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(19) + + MOV(RDX, RCX) + STORE_REG_GEN(23) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(24) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(25) + + LABEL(.ZDONE) + VZEROUPPER() + + END_ASM + ( + : // output operands (none) + : // input operands + [a10] "m" (a10), + [k] "m" (k), + [b01] "m" (b01), + [a11] "m" (a11), + [b11] "m" (b11), + [c11] "m" (c11), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [alpha] "m" (alpha), + [neg_one_addr] "m" (neg_one_addr), + [one_addr] "m" (one_addr) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", + "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", + "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", + "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", + "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", + "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", + "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", + "xmm14", "xmm15", "xmm16", "xmm17", "xmm18", "xmm19", "xmm20", + "xmm21", "xmm22", "xmm23", "xmm24", "xmm25", "xmm26", + "xmm27", "xmm28", "xmm29", "xmm30", "xmm31", + "memory" + ) +} diff --git a/kernels/zen4/3/bli_zgemmtrsm_u_4x12.c b/kernels/zen4/3/bli_zgemmtrsm_u_4x12.c new file mode 100644 index 0000000000..8e86e2040c --- /dev/null +++ b/kernels/zen4/3/bli_zgemmtrsm_u_4x12.c @@ -0,0 +1,715 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY + OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "bli_x86_asm_macros.h" + +#define A_L1_PREFETCH_DIST 6 +#define B_L1_PREFETCH_DIST 6 +#define TAIL_NITER 7 +// #define PREFETCH_A +#define PREFETCH_B +// #define PREFETCH_A_NEXT +#define PREFETCH_B_NEXT +#define PREFETCH_C // perfetch c in middle loop over 4 iterations of k + + +#ifdef PREFETCH_A + #define PREFETCH_A_L1(n) \ + PREFETCH(0, MEM(RAX, A_L1_PREFETCH_DIST*4*16 + 4*n*16)) +#else + #define PREFETCH_A_L1(n) +#endif + +#ifdef PREFETCH_B + #define PREFETCH_B_L1(n, k) \ + PREFETCH(0, MEM(RBX, B_L1_PREFETCH_DIST*12*16 + (12*n+(4*k))*16)) +#else + #define PREFETCH_B_L1(n, k) +#endif + + +/* + * A Registers: ZMM3, ZMM4, ZMM29, ZMM30 + * B Registers: ZMM0, ZMM1, ZMM2 + * C Registers: ZMM[8-28] + */ + +#define LOOP_ALIGN ALIGN32 + +#define SUBITER(n) \ +\ + PREFETCH_A_L1(n)\ + VBROADCASTSD(ZMM(3), MEM(RAX, (8*n+2)*8)) \ + VFMADD231PD(ZMM(5) , ZMM(0), ZMM(29)) \ + VFMADD231PD(ZMM(6) , ZMM(1), ZMM(29)) \ + VFMADD231PD(ZMM(7) , ZMM(2), ZMM(29)) \ + VBROADCASTSD(ZMM(4), MEM(RAX, (8*n+3)*8)) \ + VFMADD231PD(ZMM(8) , ZMM(0), ZMM(30)) \ + VFMADD231PD(ZMM(9) , ZMM(1), ZMM(30)) \ + VFMADD231PD(ZMM(10), ZMM(2), ZMM(30)) \ + \ + PREFETCH_B_L1(n, 0)\ + VBROADCASTSD(ZMM(29), MEM(RAX, (8*n+4)*8)) \ + VFMADD231PD(ZMM(11), ZMM(0), ZMM(3)) \ + VFMADD231PD(ZMM(12), ZMM(1), ZMM(3)) \ + VFMADD231PD(ZMM(13), ZMM(2), ZMM(3)) \ + VBROADCASTSD(ZMM(30), MEM(RAX, (8*n+5)*8)) \ + VFMADD231PD(ZMM(14), ZMM(0), ZMM(4)) \ + VFMADD231PD(ZMM(15), ZMM(1), ZMM(4)) \ + VFMADD231PD(ZMM(16), ZMM(2), ZMM(4)) \ + \ + PREFETCH_B_L1(n, 1)\ + VBROADCASTSD(ZMM(3), MEM(RAX, (8*n+6)*8)) \ + VFMADD231PD(ZMM(17), ZMM(0), ZMM(29)) \ + VFMADD231PD(ZMM(18), ZMM(1), ZMM(29)) \ + VFMADD231PD(ZMM(19), ZMM(2), ZMM(29)) \ + VBROADCASTSD(ZMM(4), MEM(RAX, (8*n+7)*8)) \ + VFMADD231PD(ZMM(20), ZMM(0), ZMM(30)) \ + VFMADD231PD(ZMM(21), ZMM(1), ZMM(30)) \ + VFMADD231PD(ZMM(22), ZMM(2), ZMM(30)) \ + \ + PREFETCH_B_L1(n, 2)\ + VBROADCASTSD(ZMM(29), MEM(RAX, (8*n+8)*8)) \ + VFMADD231PD(ZMM(23), ZMM(0), ZMM(3)) \ + VFMADD231PD(ZMM(24), ZMM(1), ZMM(3)) \ + VFMADD231PD(ZMM(25), ZMM(2), 
ZMM(3)) \ + VBROADCASTSD(ZMM(30), MEM(RAX, (8*n+9)*8)) \ + VFMADD231PD(ZMM(26), ZMM(0), ZMM(4)) \ + VFMADD231PD(ZMM(27), ZMM(1), ZMM(4)) \ + VFMADD231PD(ZMM(28), ZMM(2), ZMM(4)) \ + \ + VMOVAPD(ZMM(0), MEM(RBX, (12*n+0)*16)) \ + VMOVAPD(ZMM(1), MEM(RBX, (12*n+4)*16)) \ + VMOVAPD(ZMM(2), MEM(RBX, (12*n+8)*16)) + +#define SCALE_REG(a, b, c, out) \ + VPERMILPD(ZMM(3), a, IMM(0x55)) \ + VMULPD(out, a, b) \ + VMULPD(ZMM(3), ZMM(3), c) \ + VFMADDSUB132PD(out, ZMM(3), ZMM(31)) \ + +#define DIVIDE_COMPLEX(R1, c, d, csq_dsq) \ + VPERMILPD(ZMM(3), R1, IMM(0x55)) \ + VMULPD(R1, R1, c) \ + VMULPD(ZMM(3), ZMM(3), d) \ + VMULPD(ZMM(3), ZMM(3), ZMM(2)) \ + VFMADDSUB132PD(R1, ZMM(3), ZMM(31)) \ + VDIVPD(R1, R1, csq_dsq) \ + +#define STORE_REG_GEN(reg) \ + VEXTRACTF64X2(XMM(27), ZMM(reg), IMM(0x1)) \ + VEXTRACTF64X2(XMM(28), ZMM(reg), IMM(0x2)) \ + VEXTRACTF64X2(XMM(29), ZMM(reg), IMM(0x3)) \ + VMOVUPD(MEM(RDX) , XMM(reg)) \ + VMOVUPD(MEM(RDX, R10, 1), XMM(27)) \ + VMOVUPD(MEM(RDX, R10, 2), XMM(28)) \ + VMOVUPD(MEM(RDX, R11, 1), XMM(29)) \ + + +/**********************************************************/ +/* Kernel : bli_zgemmtrsm_l_zen4_asm_4x12 */ +/* It performs C = C * beta + alpha * A * B */ +/* It is row preferred kernel, A and B are packed */ +/* C could be Row/Col/Gen Stored Matrix */ +/* Registers are allocated as below */ +/* Broadcast A : ZMM(3, 4, 29, 30) */ +/* load B : ZMM(0, 1, 2) */ +/* Accumulation of B(real,imag)*Areal : */ +/* ZMM(5-7 , 11-13, 17-19, 23-25) */ +/* Accumulation of B(real,imag)*Aimag : */ +/* ZMM(8-10, 14-16, 20-22, 26-28) */ +/* Computation of A(real,imag)*B(real,imag): */ +/* ZMM(5-7 , 11-13, 17-19, 23-25) */ +/**********************************************************/ +void bli_zgemmtrsm_u_zen4_asm_4x12( + dim_t k_, + dcomplex* restrict alpha, + dcomplex* restrict a10, + dcomplex* restrict a11, + dcomplex* restrict b01, + dcomplex* restrict b11, + dcomplex* restrict c11, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + const int64_t k = k_; + /*rowstride * size of one dcomplex element*/ + const int64_t rs_c = rs_c_*16; + /*colstride * size of one dcomplex element*/ + const int64_t cs_c = cs_c_*16; + double one = 1; // used for FMADDSUB instruction + double neg_one = -1; // used for complex division + double *one_addr = &one; + double *neg_one_addr = &neg_one; + + BEGIN_ASM() + + VXORPD(XMM(5) , XMM(5) , XMM(5) ) + VXORPD(XMM(6) , XMM(6) , XMM(6) ) + VXORPD(XMM(7) , XMM(7) , XMM(7) ) + VXORPD(XMM(8) , XMM(8) , XMM(8) ) + VXORPD(XMM(9) , XMM(9) , XMM(9) ) + VXORPD(XMM(10), XMM(10), XMM(10)) + VXORPD(XMM(11), XMM(11), XMM(11)) + VXORPD(XMM(12), XMM(12), XMM(12)) + VXORPD(XMM(13), XMM(13), XMM(13)) + VXORPD(XMM(14), XMM(14), XMM(14)) + VXORPD(XMM(15), XMM(15), XMM(15)) + VXORPD(XMM(16), XMM(16), XMM(16)) + VXORPD(XMM(17), XMM(17), XMM(17)) + VXORPD(XMM(18), XMM(18), XMM(18)) + VXORPD(XMM(19), XMM(19), XMM(19)) + VXORPD(XMM(20), XMM(20), XMM(20)) + VXORPD(XMM(21), XMM(21), XMM(21)) + VXORPD(XMM(22), XMM(22), XMM(22)) + VXORPD(XMM(23), XMM(23), XMM(23)) + VXORPD(XMM(24), XMM(24), XMM(24)) + VXORPD(XMM(25), XMM(25), XMM(25)) + VXORPD(XMM(26), XMM(26), XMM(26)) + VXORPD(XMM(27), XMM(27), XMM(27)) + VXORPD(XMM(28), XMM(28), XMM(28)) + + MOV(RSI, VAR(k)) //loop index + MOV(RAX, VAR(a10)) //load address of a + MOV(RBX, VAR(b01)) //load address of b + MOV(RCX, VAR(b11)) //load address of c + MOV(R9, VAR(c11)) //load address of c + MOV(R11, VAR(neg_one_addr)) + + #ifdef PREFETCH_C + LEA(R9, 
MEM(R9, 63)) // c for prefetch, first cache line + LEA(R8, MEM(R9, 128)) // c for prefetch, second cache line + #endif + + + VMOVAPD(ZMM(0), MEM(RBX, 0*16)) //pre-load b + VMOVAPD(ZMM(1), MEM(RBX, 4*16)) //pre-load b + VMOVAPD(ZMM(2), MEM(RBX, 8*16)) //pre-load b + VBROADCASTSD(ZMM(29), MEM(RAX, 0)) + VBROADCASTSD(ZMM(30), MEM(RAX, 8)) + LEA(RBX, MEM(RBX, 12*16)) //adjust b for pre-load + + MOV(R12, VAR(rs_c)) + MOV(R10, VAR(cs_c)) + + #if defined PREFETCH_A_NEXT || defined PREFETCH_B_NEXT + MOV(RDI, RSI) + IMUL(RDI, IMM(16*4)) // rdi = k * 16*4 + #endif + + #ifdef PREFETCH_A_NEXT + LEA(R14, MEM(RAX, RDI, 1)) // r14(a_next) = A + (k*16*4) + #endif + + #ifdef PREFETCH_B_NEXT + IMUL(RDI, IMM(3)) // rdi = k * 16*12 + LEA(R15, MEM(RBX, RDI, 1)) // r15(b_next) = B + (k*16*12) + #endif + + MOV(RDI, RSI) + AND(RSI, IMM(3)) + SAR(RDI, IMM(2)) + /************************************************************/ + /* Operation: */ + /* SUBITER = (Ar, Ai)*(Br, Bi) = Ar*(Br, Bi) , Ai*(Br, Bi) */ + /* C_PREFETCH loop count: */ + /* LOOP1: k/4 - TAIL_NITER - 4 */ + /* LOOP2: 4 */ + /* LOOP4: TAIL_NITER */ + /* TAIL_LOOP: k%4 */ + /* */ + /* No prefetch loop count: */ + /* LOOP1: k/4 */ + /* TAIL_LOOP: k%4 */ + /************************************************************/ + #ifdef PREFETCH_C + /* prefetch c over 4 iterations of k*/ + SUB(RDI, IMM(4+TAIL_NITER)) + #endif + JLE(K_PREFETCH_C) + + LOOP_ALIGN + LABEL(LOOP1) + #ifdef PREFETCH_A_NEXT + PREFETCH(1, MEM(R14)) + #endif + SUBITER(0) + #ifdef PREFETCH_B_NEXT + PREFETCH(1, MEM(R15)) + #endif + SUBITER(1) + #ifdef PREFETCH_A_NEXT + PREFETCH(2, MEM(R14, 64)) + #endif + SUB(RDI, IMM(1)) + SUBITER(2) + #ifdef PREFETCH_B_NEXT + PREFETCH(2, MEM(R15, 64)) + #endif + SUBITER(3) + + LEA(RAX, MEM(RAX,4*4*16)) + LEA(RBX, MEM(RBX,4*12*16)) + #ifdef PREFETCH_A_NEXT + LEA(R14, MEM(R14,128)) + #endif + #ifdef PREFETCH_B_NEXT + LEA(R15, MEM(R15,64)) + #endif + + JNZ(LOOP1) + + LABEL(K_PREFETCH_C) + +#ifdef PREFETCH_C + ADD(RDI, IMM(4)) + JLE(K_TAIL_NITER) + + LOOP_ALIGN + LABEL(LOOP2) + SUBITER(0) + PREFETCH(0, MEM(R9)) + SUBITER(1) + PREFETCH(0, MEM(R9, 64)) + SUB(RDI, IMM(1)) + PREFETCH(0, MEM(R9,128)) + SUBITER(2) + SUBITER(3) + + LEA(RAX, MEM(RAX,4*4*16)) + LEA(RBX, MEM(RBX,4*12*16)) + LEA(R9, MEM(R9,R12,1)) + JNZ(LOOP2) + + LABEL(K_TAIL_NITER) + + ADD(RDI, IMM(0+TAIL_NITER)) + JLE(TAIL) + + LOOP_ALIGN + LABEL(LOOP4) + + SUBITER(0) + SUBITER(1) + SUB(RDI, IMM(1)) + SUBITER(2) + SUBITER(3) + + LEA(RAX, MEM(RAX,4*4*16)) + LEA(RBX, MEM(RBX,4*12*16)) + + JNZ(LOOP4) + +#endif //PREFETCH_C + + LABEL(TAIL) + + TEST(RSI, RSI) + JZ(POSTACCUM) + + LOOP_ALIGN + LABEL(TAIL_LOOP) + + SUB(RSI, IMM(1)) + SUBITER(0) + LEA(RAX, MEM(RAX,4*16)) + LEA(RBX, MEM(RBX,12*16)) + + JNZ(TAIL_LOOP) + + LABEL(POSTACCUM) + + /******************************************************/ + /* Permute imag component register. 
Shuffle even */ + /* and odd components */ + /* SRC: ZMM8 =(Ai0*Br0, Ai0*Bi0, Ai0*Br1, Ai0*Bi1, ..)*/ + /* DST: ZMM8 =(Ai0*Bi0, Ai0*Br0, Ai0*Bi1, Ai0*Br1, ..)*/ + /******************************************************/ + VPERMILPD(ZMM8 , ZMM8 , IMM(0x55)) + VPERMILPD(ZMM9 , ZMM9 , IMM(0x55)) + VPERMILPD(ZMM10, ZMM10, IMM(0x55)) + VPERMILPD(ZMM14, ZMM14, IMM(0x55)) + VPERMILPD(ZMM15, ZMM15, IMM(0x55)) + VPERMILPD(ZMM16, ZMM16, IMM(0x55)) + VPERMILPD(ZMM20, ZMM20, IMM(0x55)) + VPERMILPD(ZMM21, ZMM21, IMM(0x55)) + VPERMILPD(ZMM22, ZMM22, IMM(0x55)) + VPERMILPD(ZMM26, ZMM26, IMM(0x55)) + VPERMILPD(ZMM27, ZMM27, IMM(0x55)) + VPERMILPD(ZMM28, ZMM28, IMM(0x55)) + + /*******************************************************/ + /* SRC: ZMM5 = (Ar0*Br0, Ar0*Bi0, Ar0*Br1, Ar0*Bi1, ..)*/ + /* SRC: ZMM8 = (Ai0*Bi0, Ai0*Br0, Ai0*Bi1, Ai0*Br1, ..)*/ + /* DST: ZMM8 =(Ar0*Br0-Ai0*Bi0, Ai0*Br0+Ar0*Bi0, */ + /* Ar0*Br1-Ai0*Bi1, Ai0*Br1+Ar0*Bi1, ..) */ + /*******************************************************/ + MOV(R8, VAR(one_addr)) + VBROADCASTSD(ZMM(31), MEM(R8)) + VFMADDSUB132PD(ZMM(5) , ZMM(8) , ZMM(31)) + VFMADDSUB132PD(ZMM(6) , ZMM(9) , ZMM(31)) + VFMADDSUB132PD(ZMM(7) , ZMM(10), ZMM(31)) + + VFMADDSUB132PD(ZMM(11), ZMM(14), ZMM(31)) + VFMADDSUB132PD(ZMM(12), ZMM(15), ZMM(31)) + VFMADDSUB132PD(ZMM(13), ZMM(16), ZMM(31)) + + VFMADDSUB132PD(ZMM(17), ZMM(20), ZMM(31)) + VFMADDSUB132PD(ZMM(18), ZMM(21), ZMM(31)) + VFMADDSUB132PD(ZMM(19), ZMM(22), ZMM(31)) + + VFMADDSUB132PD(ZMM(23), ZMM(26), ZMM(31)) + VFMADDSUB132PD(ZMM(24), ZMM(27), ZMM(31)) + VFMADDSUB132PD(ZMM(25), ZMM(28), ZMM(31)) + + MOV(RAX, VAR(alpha)) + VBROADCASTSD(ZMM(0), MEM(RAX)) + VBROADCASTSD(ZMM(1), MEM(RAX, 8)) + MOV(RDX, RCX) + MOV(RDI, IMM(12*16)) + + VMOVUPD(ZMM(14), MEM(RDX, 0*16)) + VMOVUPD(ZMM(15), MEM(RDX, 4*16)) + VMOVUPD(ZMM(16), MEM(RDX, 8*16)) + ADD(RDX, RDI) + + /*****************************/ + /* gemm_output -= C * alpha */ + /*****************************/ + SCALE_REG(ZMM(14) , ZMM(0), ZMM(1), ZMM(14)) + VSUBPD(ZMM(5), ZMM(14), ZMM(5)) + VMOVUPD(ZMM(14), MEM(RDX, 0*16)) + + SCALE_REG(ZMM(15) , ZMM(0), ZMM(1), ZMM(15)) + VSUBPD(ZMM(6), ZMM(15), ZMM(6)) + VMOVUPD(ZMM(15), MEM(RDX, 4*16)) + + SCALE_REG(ZMM(16) , ZMM(0), ZMM(1), ZMM(16)) + VSUBPD(ZMM(7), ZMM(16), ZMM(7)) + VMOVUPD(ZMM(16), MEM(RDX, 8*16)) + ADD(RDX, RDI) + + + SCALE_REG(ZMM(14) , ZMM(0), ZMM(1), ZMM(14)) + VSUBPD(ZMM(11), ZMM(14), ZMM(11)) + VMOVUPD(ZMM(14), MEM(RDX, 0*16)) + + SCALE_REG(ZMM(15) , ZMM(0), ZMM(1), ZMM(15)) + VSUBPD(ZMM(12), ZMM(15), ZMM(12)) + VMOVUPD(ZMM(15), MEM(RDX, 4*16)) + + SCALE_REG(ZMM(16) , ZMM(0), ZMM(1), ZMM(16)) + VSUBPD(ZMM(13), ZMM(16), ZMM(13)) + VMOVUPD(ZMM(16), MEM(RDX, 8*16)) + ADD(RDX, RDI) + + + SCALE_REG(ZMM(14) , ZMM(0), ZMM(1), ZMM(14)) + VSUBPD(ZMM(17), ZMM(14), ZMM(17)) + VMOVUPD(ZMM(14), MEM(RDX, 0*16)) + + SCALE_REG(ZMM(15) , ZMM(0), ZMM(1), ZMM(15)) + VSUBPD(ZMM(18), ZMM(15), ZMM(18)) + VMOVUPD(ZMM(15), MEM(RDX, 4*16)) + + SCALE_REG(ZMM(16) , ZMM(0), ZMM(1), ZMM(16)) + VSUBPD(ZMM(19), ZMM(16), ZMM(19)) + VMOVUPD(ZMM(16), MEM(RDX, 8*16)) + + + SCALE_REG(ZMM(14) , ZMM(0), ZMM(1), ZMM(14)) + VSUBPD(ZMM(23), ZMM(14), ZMM(23)) + VMOVUPD(ZMM(14), MEM(RDX, 0*16)) + + SCALE_REG(ZMM(15) , ZMM(0), ZMM(1), ZMM(15)) + VSUBPD(ZMM(24), ZMM(15), ZMM(24)) + VMOVUPD(ZMM(15), MEM(RDX, 4*16)) + + SCALE_REG(ZMM(16) , ZMM(0), ZMM(1), ZMM(16)) + VSUBPD(ZMM(25), ZMM(16), ZMM(25)) + VMOVUPD(ZMM(16), MEM(RDX, 8*16)) + + + //REGION - TRSM + + MOV(RAX, VAR(a11)) + LEA(RCX, MEM(RCX, RDI, 2)) + ADD(RCX, RDI) + //iteration 0 
----------------------------------- + VBROADCASTSD(ZMM(0), MEM(RAX, (3+3*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (3+3*4)*16+8)) + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + SCALE_REG(ZMM(23), ZMM(0), ZMM(1), ZMM(23)) + SCALE_REG(ZMM(24), ZMM(0), ZMM(1), ZMM(24)) + SCALE_REG(ZMM(25), ZMM(0), ZMM(1), ZMM(25)) + #else + VBROADCASTSD(ZMM(2), MEM(R11)) + VMULPD(ZMM(8), ZMM(0), ZMM(0)) + VFMADD231PD(ZMM(8), ZMM(1), ZMM(1)) + + DIVIDE_COMPLEX(ZMM(23), ZMM(0), ZMM(1), ZMM(8)) + DIVIDE_COMPLEX(ZMM(24), ZMM(0), ZMM(1), ZMM(8)) + DIVIDE_COMPLEX(ZMM(25), ZMM(0), ZMM(1), ZMM(8)) + #endif + VMOVUPD(MEM(RCX, 0*16), ZMM(23)) + VMOVUPD(MEM(RCX, 4*16), ZMM(24)) + VMOVUPD(MEM(RCX, 8*16), ZMM(25)) + SUB(RCX, RDI) + + //iteration 1 ----------------------------------- + VBROADCASTSD(ZMM(0), MEM(RAX, (2+3*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (2+3*4)*16+8)) + SCALE_REG(ZMM(23), ZMM(0), ZMM(1), ZMM(14)) + SCALE_REG(ZMM(24), ZMM(0), ZMM(1), ZMM(15)) + SCALE_REG(ZMM(25), ZMM(0), ZMM(1), ZMM(16)) + + VSUBPD(ZMM(17), ZMM(17), ZMM(14)) + VSUBPD(ZMM(18), ZMM(18), ZMM(15)) + VSUBPD(ZMM(19), ZMM(19), ZMM(16)) + + VBROADCASTSD(ZMM(0), MEM(RAX, (2+2*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (2+2*4)*16+8)) + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + SCALE_REG(ZMM(17), ZMM(0), ZMM(1), ZMM(17)) + SCALE_REG(ZMM(18), ZMM(0), ZMM(1), ZMM(18)) + SCALE_REG(ZMM(19), ZMM(0), ZMM(1), ZMM(19)) + #else + VBROADCASTSD(ZMM(2), MEM(R11)) + VMULPD(ZMM(8), ZMM(0), ZMM(0)) + VFMADD231PD(ZMM(8), ZMM(1), ZMM(1)) + + DIVIDE_COMPLEX(ZMM(17), ZMM(0), ZMM(1), ZMM(8)) + DIVIDE_COMPLEX(ZMM(18), ZMM(0), ZMM(1), ZMM(8)) + DIVIDE_COMPLEX(ZMM(19), ZMM(0), ZMM(1), ZMM(8)) + #endif + VMOVUPD(MEM(RCX, 0*16), ZMM(17)) + VMOVUPD(MEM(RCX, 4*16), ZMM(18)) + VMOVUPD(MEM(RCX, 8*16), ZMM(19)) + SUB(RCX, RDI) + + //iteration 2 ----------------------------------- + VBROADCASTSD(ZMM(0), MEM(RAX, (1+3*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (1+3*4)*16+8)) + SCALE_REG(ZMM(23), ZMM(0), ZMM(1), ZMM(14)) + SCALE_REG(ZMM(24), ZMM(0), ZMM(1), ZMM(15)) + SCALE_REG(ZMM(25), ZMM(0), ZMM(1), ZMM(16)) + + VBROADCASTSD(ZMM(0), MEM(RAX, (1+2*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (1+2*4)*16+8)) + SCALE_REG(ZMM(17), ZMM(0), ZMM(1), ZMM(20)) + SCALE_REG(ZMM(18), ZMM(0), ZMM(1), ZMM(21)) + SCALE_REG(ZMM(19), ZMM(0), ZMM(1), ZMM(22)) + VADDPD(ZMM(14), ZMM(14), ZMM(20)) + VADDPD(ZMM(15), ZMM(15), ZMM(21)) + VADDPD(ZMM(16), ZMM(16), ZMM(22)) + + VSUBPD(ZMM(11), ZMM(11), ZMM(14)) + VSUBPD(ZMM(12), ZMM(12), ZMM(15)) + VSUBPD(ZMM(13), ZMM(13), ZMM(16)) + + VBROADCASTSD(ZMM(0), MEM(RAX, (1+1*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (1+1*4)*16+8)) + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + SCALE_REG(ZMM(11), ZMM(0), ZMM(1), ZMM(11)) + SCALE_REG(ZMM(12), ZMM(0), ZMM(1), ZMM(12)) + SCALE_REG(ZMM(13), ZMM(0), ZMM(1), ZMM(13)) + #else + VBROADCASTSD(ZMM(2), MEM(R11)) + VMULPD(ZMM(8), ZMM(0), ZMM(0)) + VFMADD231PD(ZMM(8), ZMM(1), ZMM(1)) + + DIVIDE_COMPLEX(ZMM(11), ZMM(0), ZMM(1), ZMM(8)) + DIVIDE_COMPLEX(ZMM(12), ZMM(0), ZMM(1), ZMM(8)) + DIVIDE_COMPLEX(ZMM(13), ZMM(0), ZMM(1), ZMM(8)) + #endif + VMOVUPD(MEM(RCX, 0*16), ZMM(11)) + VMOVUPD(MEM(RCX, 4*16), ZMM(12)) + VMOVUPD(MEM(RCX, 8*16), ZMM(13)) + SUB(RCX, RDI) + + //iteration 3 ----------------------------------- + VBROADCASTSD(ZMM(0), MEM(RAX, (0+3*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (0+3*4)*16+8)) + SCALE_REG(ZMM(23), ZMM(0), ZMM(1), ZMM(14)) + SCALE_REG(ZMM(24), ZMM(0), ZMM(1), ZMM(15)) + SCALE_REG(ZMM(25), ZMM(0), ZMM(1), ZMM(16)) + + VBROADCASTSD(ZMM(0), MEM(RAX, (0+2*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (0+2*4)*16+8)) + 
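/*
 * A minimal C sketch of the two diagonal-solve paths selected above,
 * assuming d = (dr, di) is the current diagonal element of a11 and
 * x = (xr, xi) one element of the right-hand side being solved.  With
 * BLIS_ENABLE_TRSM_PREINVERSION the packed diagonal already holds 1/d, so a
 * plain complex multiply (SCALE_REG) suffices; otherwise the kernel divides
 * by d through the conjugate, which is what the DIVIDE_COMPLEX sequence
 * computes.  Names here are illustrative, not the macro definitions.
 */
static void zdiv_by_diag
     (
       double xr, double xi,   /* element of b11 being solved */
       double dr, double di,   /* diagonal element of a11     */
       double out[2]           /* x / d                       */
     )
{
    double den = dr * dr + di * di;        /* |d|^2, held in ZMM8 above   */

    out[0] = ( xr * dr + xi * di ) / den;  /* Re( x * conj(d) ) / |d|^2   */
    out[1] = ( xi * dr - xr * di ) / den;  /* Im( x * conj(d) ) / |d|^2   */
}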
SCALE_REG(ZMM(17), ZMM(0), ZMM(1), ZMM(20)) + SCALE_REG(ZMM(18), ZMM(0), ZMM(1), ZMM(21)) + SCALE_REG(ZMM(19), ZMM(0), ZMM(1), ZMM(22)) + VADDPD(ZMM(14), ZMM(14), ZMM(20)) + VADDPD(ZMM(15), ZMM(15), ZMM(21)) + VADDPD(ZMM(16), ZMM(16), ZMM(22)) + + VBROADCASTSD(ZMM(0), MEM(RAX, (0+1*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (0+1*4)*16+8)) + SCALE_REG(ZMM(11), ZMM(0), ZMM(1), ZMM(20)) + SCALE_REG(ZMM(12), ZMM(0), ZMM(1), ZMM(21)) + SCALE_REG(ZMM(13), ZMM(0), ZMM(1), ZMM(22)) + VADDPD(ZMM(14), ZMM(14), ZMM(20)) + VADDPD(ZMM(15), ZMM(15), ZMM(21)) + VADDPD(ZMM(16), ZMM(16), ZMM(22)) + + VSUBPD(ZMM(5), ZMM(5), ZMM(14)) + VSUBPD(ZMM(6), ZMM(6), ZMM(15)) + VSUBPD(ZMM(7), ZMM(7), ZMM(16)) + + VBROADCASTSD(ZMM(0), MEM(RAX, (0+0*4)*16+0)) + VBROADCASTSD(ZMM(1), MEM(RAX, (0+0*4)*16+8)) + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + SCALE_REG(ZMM(5), ZMM(0), ZMM(1), ZMM(5)) + SCALE_REG(ZMM(6), ZMM(0), ZMM(1), ZMM(6)) + SCALE_REG(ZMM(7), ZMM(0), ZMM(1), ZMM(7)) + #else + VBROADCASTSD(ZMM(2), MEM(R11)) + VMULPD(ZMM(8), ZMM(0), ZMM(0)) + VFMADD231PD(ZMM(8), ZMM(1), ZMM(1)) + + VPERMILPD(ZMM(3), ZMM(5), IMM(0x55)) + VMULPD(ZMM(5), ZMM(5), ZMM(0)) + VMULPD(ZMM(3), ZMM(3), ZMM(1)) + VMULPD(ZMM(3), ZMM(3), ZMM(2)) + VFMADDSUB132PD(ZMM(5), ZMM(3), ZMM(31)) + VDIVPD(ZMM(5), ZMM(5), ZMM(8)) + + VPERMILPD(ZMM(3), ZMM(6), IMM(0x55)) + VMULPD(ZMM(6), ZMM(6), ZMM(0)) + VMULPD(ZMM(3), ZMM(3), ZMM(1)) + VMULPD(ZMM(3), ZMM(3), ZMM(2)) + VFMADDSUB132PD(ZMM(6), ZMM(3), ZMM(31)) + VDIVPD(ZMM(6), ZMM(6), ZMM(8)) + + VPERMILPD(ZMM(3), ZMM(7), IMM(0x55)) + VMULPD(ZMM(7), ZMM(7), ZMM(0)) + VMULPD(ZMM(3), ZMM(3), ZMM(1)) + VMULPD(ZMM(3), ZMM(3), ZMM(2)) + VFMADDSUB132PD(ZMM(7), ZMM(3), ZMM(31)) + VDIVPD(ZMM(7), ZMM(7), ZMM(8)) + #endif + VMOVUPD(MEM(RCX, 0*16), ZMM(5)) + VMOVUPD(MEM(RCX, 4*16), ZMM(6)) + VMOVUPD(MEM(RCX, 8*16), ZMM(7)) + +// ENDREGION - TRSM + + MOV(RCX, VAR(c11)) + CMP(R10, IMM(16)) //CS == 1 IMPLIES ROW STORED + JNZ(.ZCOLSTORED) + + LABEL(.ZROWSTORED) + VMOVUPD(MEM(RCX ), ZMM(5)) + VMOVUPD(MEM(RCX, R10, 4), ZMM(6)) + VMOVUPD(MEM(RCX, R10, 8), ZMM(7)) + ADD(RCX, R12) + + VMOVUPD(MEM(RCX ), ZMM(11)) + VMOVUPD(MEM(RCX, R10, 4), ZMM(12)) + VMOVUPD(MEM(RCX, R10, 8), ZMM(13)) + ADD(RCX, R12) + + VMOVUPD(MEM(RCX ), ZMM(17)) + VMOVUPD(MEM(RCX, R10, 4), ZMM(18)) + VMOVUPD(MEM(RCX, R10, 8), ZMM(19)) + ADD(RCX, R12) + + VMOVUPD(MEM(RCX ), ZMM(23)) + VMOVUPD(MEM(RCX, R10, 4), ZMM(24)) + VMOVUPD(MEM(RCX, R10, 8), ZMM(25)) + + JMP(.ZDONE) + + LABEL(.ZCOLSTORED) + LEA(R11, MEM(R10, R10, 2)) + MOV(RDX, RCX) + ADD(RCX, R12) + STORE_REG_GEN(5) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(6) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(7) + + MOV(RDX, RCX) + ADD(RCX, R12) + STORE_REG_GEN(11) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(12) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(13) + + MOV(RDX, RCX) + ADD(RCX, R12) + STORE_REG_GEN(17) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(18) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(19) + + MOV(RDX, RCX) + STORE_REG_GEN(23) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(24) LEA(RDX, MEM(RDX, R10, 4)) + STORE_REG_GEN(25) + + LABEL(.ZDONE) + VZEROUPPER() + + END_ASM + ( + : // output operands (none) + : // input operands + [a10] "m" (a10), + [k] "m" (k), + [b01] "m" (b01), + [a11] "m" (a11), + [b11] "m" (b11), + [c11] "m" (c11), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c), + [alpha] "m" (alpha), + [neg_one_addr] "m" (neg_one_addr), + [one_addr] "m" (one_addr) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "zmm0", 
"zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", + "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", + "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", + "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", + "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", + "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", + "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", + "xmm14", "xmm15", "xmm16", "xmm17", "xmm18", "xmm19", "xmm20", + "xmm21", "xmm22", "xmm23", "xmm24", "xmm25", "xmm26", + "xmm27", "xmm28", "xmm29", "xmm30", "xmm31", + "memory" + ) +} diff --git a/kernels/zen4/3/sup/CMakeLists.txt b/kernels/zen4/3/sup/CMakeLists.txt deleted file mode 100644 index 81e194ef64..0000000000 --- a/kernels/zen4/3/sup/CMakeLists.txt +++ /dev/null @@ -1,21 +0,0 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.## - -add_library(zen4_3sup - OBJECT - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_zen_s6x64.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_zen_s6x64.h - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_zen_s6x64m.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rd_zen_s6x64n.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_zen_s6x64.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_zen_s6x64.h - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_zen_s6x64m.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_rv_zen_s6x64n.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemmsup_rv_zen4_asm_24x8m.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_cv_zen4_z12x4m.c -) -target_compile_options(zen4_3sup PRIVATE /arch:AVX2 /arch:AVX512) -if(BUILD_SHARED_LIBS) - target_compile_definitions(zen4_3sup PUBLIC -DBLIS_IS_BUILDING_LIBRARY) -endif() - -add_subdirectory(d24x8) diff --git a/kernels/zen4/3/sup/bli_dgemmsup_rv_zen4_asm_24x8m.c b/kernels/zen4/3/sup/bli_dgemmsup_rv_zen4_asm_24x8m.c index 97ac0985dc..649aa416b5 100644 --- a/kernels/zen4/3/sup/bli_dgemmsup_rv_zen4_asm_24x8m.c +++ b/kernels/zen4/3/sup/bli_dgemmsup_rv_zen4_asm_24x8m.c @@ -6079,6 +6079,18 @@ void bli_dgemmsup_rv_zen4_asm_24x4m vxorpd(zmm12, zmm12, zmm12) vxorpd(zmm13, zmm13, zmm13) vxorpd(zmm27,zmm27, zmm27) + vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) + vxorpd(zmm18, zmm18, zmm18) + vxorpd(zmm19, zmm19, zmm19) + vxorpd(zmm20, zmm20, zmm20) + vxorpd(zmm21, zmm21, zmm21) + vxorpd(zmm22, zmm22, zmm22) + vxorpd(zmm23, zmm23, zmm23) + vxorpd(zmm24, zmm24, zmm24) + vxorpd(zmm25, zmm25, zmm25) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -6089,7 +6101,22 @@ void bli_dgemmsup_rv_zen4_asm_24x4m jle(.PREFETCHLOOP) // jump if i <= 0 label(.LOOP1) - + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29, zmm10, zmm11, zmm26, zmm12, zmm13, zmm27 + * to hold fma result. + * While even iterations uses zmm14, zmm15, zmm16, zmm17, zmm18 + * zmm19, zmm20, zmm21, zmm22, zmm23, zmm24, zmm25 to hold fma + * result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29, zmm10, zmm11, zmm26, zmm12, zmm13, zmm27. 
+ */ // ---------------------------------- iteration 1 vmovupd( mem(rax),zmm0 ) // load A @@ -6138,21 +6165,21 @@ void bli_dgemmsup_rv_zen4_asm_24x4m prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) - vfmadd231pd( zmm5,zmm31,zmm27 ) + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) // ---------------------------------- iteration 3 @@ -6198,21 +6225,21 @@ void bli_dgemmsup_rv_zen4_asm_24x4m prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) - vfmadd231pd( zmm5,zmm31,zmm27 ) + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) // ---------------------------------- iteration 5 @@ -6256,21 +6283,21 @@ void bli_dgemmsup_rv_zen4_asm_24x4m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) - vfmadd231pd( zmm5,zmm31,zmm27 ) + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) // ---------------------------------- iteration 7 @@ -6310,21 +6337,21 @@ void bli_dgemmsup_rv_zen4_asm_24x4m add( r10,r14 ) // 
a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) - vfmadd231pd( zmm5,zmm31,zmm27 ) + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP1) // iterate again if i != 0. @@ -6383,21 +6410,21 @@ void bli_dgemmsup_rv_zen4_asm_24x4m prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) - vfmadd231pd( zmm5,zmm31,zmm27 ) + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) // ---------------------------------- iteration 3 prefetchw0( mem(rdx, 128)) // prefetch C @@ -6442,21 +6469,21 @@ void bli_dgemmsup_rv_zen4_asm_24x4m prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) - vfmadd231pd( zmm5,zmm31,zmm27 ) + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -6498,21 +6525,21 @@ void bli_dgemmsup_rv_zen4_asm_24x4m add( r10,r14 ) // a_next += cs_a vbroadcastsd( 
mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) - vfmadd231pd( zmm5,zmm31,zmm27 ) + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -6550,28 +6577,28 @@ void bli_dgemmsup_rv_zen4_asm_24x4m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) - vfmadd231pd( zmm5,zmm31,zmm27 ) + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b sub(imm(1), rsi) // i -= 1 jnz(.LOOP2) // iterate again if i != 0. 
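/*
 * A minimal C sketch of how the unrolled k loop in these kernels is split
 * into phases, mirroring the .LOOP1 / .LOOP2 / .LOOP3 / .TAIL labels around
 * this point.  k_iter, k_left, nr and tail_niter stand in for the values the
 * kernel derives from its operands (k/8, k%8, NR and TAIL_NITER); the loop
 * bodies are elided.
 */
static void k_loop_phases( int k_iter, int k_left, int nr, int tail_niter )
{
    int i = k_iter - nr - tail_niter;

    for ( ; i > 0; i-- )
    {
        /* .LOOP1: unrolled-by-8 FMAs, B prefetched, no C traffic yet.   */
    }

    i += nr;
    for ( ; i > 0; i-- )
    {
        /* .LOOP2: same FMAs, plus a prefetch of one column of the C
           micro-tile per iteration.                                     */
    }

    i += tail_niter;
    for ( ; i > 0; i-- )
    {
        /* .LOOP3: FMAs only, giving the C prefetches time to complete.  */
    }

    for ( i = k_left; i > 0; i-- )
    {
        /* .TAIL: remaining k iterations, one rank-1 update each.        */
    }
}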
label(.TAILITER) add(imm(TAIL_NITER), rsi) // i += TAIL_NITER - jle(.TAIL) // jump if i <= 0 + jle(.TAIL) // jump if i <= 0 label(.LOOP3) @@ -6621,21 +6648,21 @@ void bli_dgemmsup_rv_zen4_asm_24x4m prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) - vfmadd231pd( zmm5,zmm31,zmm27 ) + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -6679,21 +6706,21 @@ void bli_dgemmsup_rv_zen4_asm_24x4m prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) - vfmadd231pd( zmm5,zmm31,zmm27 ) + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -6735,21 +6762,21 @@ void bli_dgemmsup_rv_zen4_asm_24x4m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) - vfmadd231pd( zmm5,zmm31,zmm27 ) + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) // 
---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -6787,25 +6814,37 @@ void bli_dgemmsup_rv_zen4_asm_24x4m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) - vfmadd231pd( zmm5,zmm31,zmm27 ) + vfmadd231pd( zmm3,zmm31,zmm23 ) + vfmadd231pd( zmm4,zmm31,zmm24 ) + vfmadd231pd( zmm5,zmm31,zmm25 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. + vaddpd(zmm14, zmm6, zmm6) + vaddpd(zmm15, zmm7, zmm7) + vaddpd(zmm16, zmm28, zmm28) + vaddpd(zmm17, zmm8, zmm8) + vaddpd(zmm18, zmm9, zmm9) + vaddpd(zmm19, zmm29, zmm29) + vaddpd(zmm20, zmm10, zmm10) + vaddpd(zmm21, zmm11, zmm11) + vaddpd(zmm22, zmm26, zmm26) + vaddpd(zmm23, zmm12, zmm12) + vaddpd(zmm24, zmm13, zmm13) + vaddpd(zmm25, zmm27, zmm27) label(.TAIL) mov(var(k_left), rsi) // i = k_left @@ -7177,6 +7216,15 @@ void bli_dgemmsup_rv_zen4_asm_24x3m vxorpd(zmm10, zmm10, zmm10) vxorpd(zmm11, zmm11, zmm11) vxorpd(zmm26, zmm26, zmm26) + vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) + vxorpd(zmm18, zmm18, zmm18) + vxorpd(zmm19, zmm19, zmm19) + vxorpd(zmm20, zmm20, zmm20) + vxorpd(zmm21, zmm21, zmm21) + vxorpd(zmm22, zmm22, zmm22) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -7186,6 +7234,22 @@ void bli_dgemmsup_rv_zen4_asm_24x3m sub(imm( 3+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29, zmm10, zmm11, zmm26 to hold fma result. + * While even iterations uses zmm14, zmm15, zmm16, zmm17, zmm18 + * zmm19, zmm20, zmm21, zmm22 to hold fma + * result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29, zmm10, zmm11, zmm26. 
+ */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -7232,17 +7296,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3m prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 3 @@ -7283,17 +7347,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 5 @@ -7333,17 +7397,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 7 @@ -7379,17 +7443,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) 
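/*
 * A minimal C sketch of one sub-iteration of the 24x3 micro-tile update
 * above: a 24-element column of A is held in three 8-wide vector registers,
 * each of the three B values for the current k is broadcast, and nine
 * accumulators (three per B value) receive FMAs.  Plain arrays stand in for
 * the zmm registers, and only one accumulator bank is shown.
 */
static void rank1_update_24x3
     (
       const double *a_col,       /* 24 elements of the current A column */
       const double *b_row,       /* 3 elements of the current B row     */
       double        c_acc[3][24] /* accumulators for the 24x3 tile      */
     )
{
    for ( int j = 0; j < 3; j++ )          /* one broadcast of B per column */
        for ( int i = 0; i < 24; i++ )     /* three 8-wide FMAs per column  */
            c_acc[j][i] += a_col[i] * b_row[j];
}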
lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP1) // iterate again if i != 0. @@ -7444,17 +7508,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3m prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) - add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) + add( r8,rbx ) // b += rs_b + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 3 prefetchw0( mem(rdx, 128)) // prefetch C @@ -7494,17 +7558,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -7542,17 +7606,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -7586,17 +7650,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( 
zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b sub(imm(1), rsi) // i -= 1 @@ -7649,17 +7713,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3m prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -7698,17 +7762,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -7746,17 +7810,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -7790,21 +7854,30 @@ void bli_dgemmsup_rv_zen4_asm_24x3m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) 
+ vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. + vaddpd(zmm14, zmm6, zmm6) + vaddpd(zmm15, zmm7, zmm7) + vaddpd(zmm16, zmm28, zmm28) + vaddpd(zmm17, zmm8, zmm8) + vaddpd(zmm18, zmm9, zmm9) + vaddpd(zmm19, zmm29, zmm29) + vaddpd(zmm20, zmm10, zmm10) + vaddpd(zmm21, zmm11, zmm11) + vaddpd(zmm22, zmm26, zmm26) label(.TAIL) mov(var(k_left), rsi) // i = k_left @@ -8157,6 +8230,12 @@ void bli_dgemmsup_rv_zen4_asm_24x2m vxorpd(zmm8, zmm8, zmm8) vxorpd(zmm9, zmm9, zmm9) vxorpd(zmm29, zmm29, zmm29) + vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) + vxorpd(zmm18, zmm18, zmm18) + vxorpd(zmm19, zmm19, zmm19) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -8166,6 +8245,21 @@ void bli_dgemmsup_rv_zen4_asm_24x2m sub(imm( 2+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29 to hold fma result. + * While even iterations uses zmm14, zmm15, zmm16, zmm17, zmm18 + * zmm19, zmm20, zmm21 to hold fma result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29. 
+ */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -8208,13 +8302,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2m prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 3 @@ -8250,13 +8344,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 5 @@ -8292,13 +8386,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 7 @@ -8330,13 +8424,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP1) // iterate again if i != 0. 
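/*
 * A minimal C sketch of the two-accumulator-bank scheme described in the
 * kernel comments above, shown for a single scalar accumulator pair:
 * alternate unrolled sub-iterations feed disjoint accumulators, presumably
 * so that back-to-back FMAs do not all depend on the same destination
 * registers, and the banks are summed once after the loop (the vaddpd block
 * that follows .LOOP3 in these kernels).
 */
static double dot_two_banks( const double *a, const double *b, int k )
{
    double acc0 = 0.0;                     /* bank 0: zmm6, zmm7, ...       */
    double acc1 = 0.0;                     /* bank 1: zmm14, zmm15, ...     */
    int    i;

    for ( i = 0; i + 1 < k; i += 2 )
    {
        acc0 += a[ i ]     * b[ i ];       /* odd  sub-iteration -> bank 0  */
        acc1 += a[ i + 1 ] * b[ i + 1 ];   /* even sub-iteration -> bank 1  */
    }
    if ( i < k ) acc0 += a[ i ] * b[ i ];  /* leftover iteration            */

    return acc0 + acc1;                    /* banks combined after the loop */
}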
@@ -8387,13 +8481,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2m prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 3 prefetchw0( mem(rdx, 128)) // prefetch C @@ -8428,13 +8522,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -8468,13 +8562,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -8504,13 +8598,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b sub(imm(1), rsi) // i -= 1 @@ -8559,13 +8653,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2m prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -8599,13 +8693,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) 
vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -8639,13 +8733,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -8675,17 +8769,23 @@ void bli_dgemmsup_rv_zen4_asm_24x2m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. + vaddpd(zmm14, zmm6, zmm6) + vaddpd(zmm15, zmm7, zmm7) + vaddpd(zmm16, zmm28, zmm28) + vaddpd(zmm17, zmm8, zmm8) + vaddpd(zmm18, zmm9, zmm9) + vaddpd(zmm19, zmm29, zmm29) label(.TAIL) mov(var(k_left), rsi) // i = k_left @@ -9022,7 +9122,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1m vxorpd(zmm6, zmm6, zmm6) vxorpd(zmm7, zmm7, zmm7) vxorpd(zmm28, zmm28, zmm28) - + vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm16, zmm16, zmm16) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, //one cacheline of B_next is prefetched where b_next = b + (unroll)*rs_b @@ -9031,6 +9133,20 @@ void bli_dgemmsup_rv_zen4_asm_24x1m sub(imm( 1+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7, zmm28, + * to hold fma result. + * While even iterations uses zmm14, zmm15, zmm16 to hold fma + * result. 
+ * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7, zmm28, + */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -9068,9 +9184,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) // ---------------------------------- iteration 3 @@ -9102,9 +9218,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) // ---------------------------------- iteration 5 @@ -9136,9 +9252,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) // ---------------------------------- iteration 7 @@ -9166,9 +9282,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP1) // iterate again if i != 0. 
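/*
 * A minimal C sketch of the B prefetch pattern noted in the "K is unrolled
 * by 8" comments above: each unrolled iteration works on eight rows of B and
 * issues prefetches for the rows 8*rs_b ahead, so the next unrolled
 * iteration finds its B elements in cache.  __builtin_prefetch stands in for
 * the kernel's prefetch() macro; the broadcast/FMA work itself is elided.
 */
static void k_unrolled_with_b_prefetch
     (
       const double *b,
       long          rs_b,    /* row stride of B, in elements       */
       long          k_iter   /* number of unrolled-by-8 iterations */
     )
{
    const double *b_next = b + 8 * rs_b;            /* r11 in the kernel */

    for ( long i = 0; i < k_iter; i++ )
    {
        for ( int u = 0; u < 8; u++ )
        {
            __builtin_prefetch( b_next + u * rs_b, 0, 3 );
            /* ... broadcasts of b[ u*rs_b + j ] and FMAs go here ...     */
        }
        b      += 8 * rs_b;                         /* b += 8*rs_b       */
        b_next += 8 * rs_b;                         /* b_next += 8*rs_b  */
    }
}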
@@ -9214,9 +9330,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) // ---------------------------------- iteration 3 prefetchw0( mem(rdx, 128)) // prefetch C @@ -9247,9 +9363,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -9279,9 +9395,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -9307,9 +9423,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b sub(imm(1), rsi) // i -= 1 @@ -9353,9 +9469,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -9385,9 +9501,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -9417,9 +9533,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -9445,13 +9561,16 @@ void bli_dgemmsup_rv_zen4_asm_24x1m add( r10,r14 ) // a_next += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. 
+ vaddpd(zmm14, zmm6, zmm6) + vaddpd(zmm15, zmm7, zmm7) + vaddpd(zmm16, zmm28, zmm28) label(.TAIL) mov(var(k_left), rsi) // i = k_left diff --git a/kernels/zen4/3/sup/bli_gemmsup_cv_zen4_z12x4m.c b/kernels/zen4/3/sup/bli_gemmsup_cv_zen4_z12x4m.c index a0db7fd504..4fc04901ca 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_cv_zen4_z12x4m.c +++ b/kernels/zen4/3/sup/bli_gemmsup_cv_zen4_z12x4m.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2023, Advanced Micro Devices, Inc.All rights reserved. + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -916,7 +916,7 @@ void bli_zgemmsup_cv_zen4_asm_12x4m const double *v = &value; // Assigning the type of alpha and beta scaling - // In order to facilitate handling special cases seperately + // In order to facilitate handling special cases separately char alpha_mul_type = BLIS_MUL_DEFAULT; char beta_mul_type = BLIS_MUL_DEFAULT; @@ -1400,7 +1400,7 @@ void bli_zgemmsup_cv_zen4_asm_12x3m const double *v = &value; // Assigning the type of alpha and beta scaling - // In order to facilitate handling special cases seperately + // In order to facilitate handling special cases separately char alpha_mul_type = BLIS_MUL_DEFAULT; char beta_mul_type = BLIS_MUL_DEFAULT; @@ -1707,7 +1707,7 @@ void bli_zgemmsup_cv_zen4_asm_12x3m "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", - "memory" + "k3", "memory" ) consider_edge_cases: @@ -1819,7 +1819,7 @@ void bli_zgemmsup_cv_zen4_asm_12x2m const double *v = &value; // Assigning the type of alpha and beta scaling - // In order to facilitate handling special cases seperately + // In order to facilitate handling special cases separately char alpha_mul_type = BLIS_MUL_DEFAULT; char beta_mul_type = BLIS_MUL_DEFAULT; @@ -2112,7 +2112,7 @@ void bli_zgemmsup_cv_zen4_asm_12x2m "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", - "memory" + "k3", "memory" ) consider_edge_cases: @@ -2224,7 +2224,7 @@ void bli_zgemmsup_cv_zen4_asm_12x1m */ // Assigning the type of alpha and beta scaling - // In order to facilitate handling special cases seperately + // In order to facilitate handling special cases separately char alpha_mul_type = BLIS_MUL_DEFAULT; char beta_mul_type = BLIS_MUL_DEFAULT; @@ -2501,7 +2501,7 @@ void bli_zgemmsup_cv_zen4_asm_12x1m "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", - "memory" + "k3", "memory" ) consider_edge_cases:; @@ -3056,7 +3056,7 @@ void bli_zgemmsup_cv_zen4_asm_8x3 "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", - "memory" + "k3", "memory" ) } @@ -3301,7 +3301,7 @@ void bli_zgemmsup_cv_zen4_asm_8x2 "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", - "memory" + "k3", "memory" ) } @@ -3538,7 +3538,7 @@ void bli_zgemmsup_cv_zen4_asm_8x1 "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", - "memory" + "k3", "memory" ) } @@ -3992,7 +3992,7 @@ void bli_zgemmsup_cv_zen4_asm_4x3 "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", - "memory" + "k3", "memory" ) } @@ -4216,7 +4216,7 @@ void 
bli_zgemmsup_cv_zen4_asm_4x2 "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", - "memory" + "k3", "memory" ) } @@ -4433,7 +4433,7 @@ void bli_zgemmsup_cv_zen4_asm_4x1 "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", - "memory" + "k3", "memory" ) } @@ -5151,11 +5151,11 @@ void bli_zgemmsup_cv_zen4_asm_2x3 [cs_c] "m" (cs_c) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "al", + "xmm9", "xmm10", "xmm11", "xmm12", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", - "ymm12", "ymm13", "ymm14", "ymm15", - "memory" + "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } @@ -5679,10 +5679,10 @@ void bli_zgemmsup_cv_zen4_asm_2x1 [cs_c] "m" (cs_c) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "al", + "xmm5", "xmm6", "xmm14", "xmm15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", - "ymm12", "ymm13", "ymm14", "ymm15", - "memory" + "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) } diff --git a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64m.c b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64m.c index 23f43052e8..2e55b698ca 100644 --- a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64m.c +++ b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64m.c @@ -53,9 +53,6 @@ - m0 and n0 are at most MR (6) and NR (64), respectively. Therefore, this (r)ow-preferential kernel is well-suited for contiguous (v)ector loads on B and single-element broadcasts from A. - - NOTE: These kernels currently do not have in-register transpose - implemented and hence they do not support column-oriented IO. */ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 ( @@ -95,7 +92,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 data, cntx ); cij += nr_cur * cs_c0; - bj += nr_cur * cs_b0; + bj += nr_cur * cs_b0; n_left -= nr_cur; } @@ -111,7 +108,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 data, cntx ); cij += nr_cur * cs_c0; - bj += nr_cur * cs_b0; + bj += nr_cur * cs_b0; n_left -= nr_cur; } @@ -127,7 +124,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 data, cntx ); cij += nr_cur * cs_c0; - bj += nr_cur * cs_b0; + bj += nr_cur * cs_b0; n_left -= nr_cur; } @@ -143,7 +140,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 data, cntx ); cij += nr_cur * cs_c0; - bj += nr_cur * cs_b0; + bj += nr_cur * cs_b0; n_left -= nr_cur; } @@ -195,21 +192,21 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 else { const dim_t mr = 6; - + // Since A is packed into row panels, // we must use a loop over gemv. dim_t m_iter = ( m0 + mr - 1 ) / mr; dim_t m_left = m0 % mr; - + float* restrict ai_ii = ai; float* restrict cij_ii = cij; - + for ( dim_t ii = 0; ii < m_iter; ii += 1 ) { dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) ? mr : m_left ); - - bli_sgemv_ex + + bli_sgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, alpha, ai_ii, rs_a0, cs_a0, bj, rs_b0, @@ -217,7 +214,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 ); cij_ii += mr_cur * rs_c0; ai_ii += ps_a0; - } + } } n_left -= nr_cur; } @@ -241,6 +238,10 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + // Query the panel stride of A and convert it to units of bytes. 
+ uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + float *abuf = a; float *bbuf = b; float *cbuf = c; @@ -474,18 +475,18 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 * 4x16 & 2x16 each. * These smaller 4x16 & 2x16 tiles are transposed to 16x4 & 16x2 tiles, * to get the transpose of 6x64 tile and are stored as 64x6 tile. - * - * |-----------------------------------| |------------------|--------| - * | | | | | | | | - * | | | | | | 16x4 | 16x2 | - * | 4x16 | 4x16 | 4x16 | 4x16 | | | | - * | | | | | |------------------|--------| - * | | | | | | | | - * |-----------------------------------| -> | 16x4 | 16x2 | - * | | | | | | | | - * | 2x16 | 2x16 | 2x16 | 2x16 | |------------------|--------| - * | | | | | | | | - * |-----------------------------------| | 16x4 | 16x2 | + * + * |-----------------------------------| |------------------|--------| + * | | | | | | | | + * | | | | | | 16x4 | 16x2 | + * | 4x16 | 4x16 | 4x16 | 4x16 | | | | + * | | | | | |------------------|--------| + * | | | | | | | | + * |-----------------------------------| -> | 16x4 | 16x2 | + * | | | | | | | | + * | 2x16 | 2x16 | 2x16 | 2x16 | |------------------|--------| + * | | | | | | | | + * |-----------------------------------| | 16x4 | 16x2 | * | | | * |------------------|--------| * | | | @@ -495,9 +496,9 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 */ /* Transposing 4x16 tiles to 16x4 tiles */ mov( var( cbuf ), rcx ) // load address of c - mov( var( cs_c ), rdi ) // load rs_c - lea( mem( , rdi, 4 ), rdi ) // rdi = rs_c *= sizeof(dt) => rs_c *= 4 - lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * rs_c + mov( var( cs_c ), rdi ) // load cs_c; rdi = cs_c + lea( mem( , rdi, 4 ), rdi ) // rdi = cs_c*sizeof(dt) => rdi = cs_c*4 + lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * cs_c TRANSPOSE_4X16( 8, 12, 16, 20 ) lea( mem( rcx, r12, 4 ), rcx ) @@ -510,7 +511,11 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 /* Transposing 2x16 tiles to 16x2 tiles */ mov( var( cbuf ), rcx ) // load address of c - lea( mem( rcx, r10, 4 ), rcx ) + mov( var( rs_c ), r12 ) // load rs_c; r12 = rs_c + lea( mem( , r12, 4 ), r12 ) // r12 = rs_c*sizeof(dt) => r12 = rs_c*4 + lea( mem( rcx, r12, 4 ), rcx ) // rcx += 4 * r12 => rcx = 4 * rs_c + + TRANSPOSE_2X16( 24, 28 ) lea( mem( rcx, rdi, 2 ), rcx ) TRANSPOSE_2X16( 25, 29 ) @@ -553,7 +558,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 mov( var( cbuf ), rcx ) // load address of c mov( var( cs_c ), rdi ) // load rs_c lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) - lea( mem( rdi, rdi, 2 ), r12 ) + lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * cs_c TRANSPOSE_4X16_BZ( 8, 12, 16, 20 ) lea( mem( rcx, r12, 4 ), rcx ) @@ -561,11 +566,14 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 lea( mem( rcx, r12, 4 ), rcx ) TRANSPOSE_4X16_BZ( 10, 14, 18, 22 ) lea( mem( rcx, r12, 4 ), rcx ) - TRANSPOSE_4X16_BZ( 11, 15, 19, 23 ) + TRANSPOSE_4X16_BZ( 11, 15, 19, 23 ) /* Transposing 2x16 tiles to 16x2 tiles */ mov( var( cbuf ), rcx ) // load address of c - lea( mem( rcx, r10, 4 ), rcx ) + mov( var( rs_c ), r12 ) // load rs_c; r12 = rs_c + lea( mem( , r12, 4 ), r12 ) // r12 = rs_c*sizeof(dt) => r12 = rs_c*4 + lea( mem( rcx, r12, 4 ), rcx ) // rcx += 4 * r12 => rcx = 4 * rs_c + TRANSPOSE_2X16_BZ( 24, 28 ) lea( mem( rcx, rdi, 2 ), rcx ) TRANSPOSE_2X16_BZ( 25, 29 ) @@ -579,13 +587,12 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 label( .SDONE ) - lea( mem( , r8, 2 ), rdx ) // rdx = rs_a * 2 - lea( mem( rdx, r8, 4 ), rdx ) // rdx = rs_a * 6 + mov( var( 
ps_a4 ), rdx ) // load panel stride of a; rdx = ps_a4 mov( var( abuf ), rax ) // load address of a - add( rdx, rax ) // a += rs_a * 6(MR) + add( rdx, rax ) // a += ps_a4 mov( rax, var( abuf ) ) // store updated a - mov( var( rs_c ), rdi ) + mov( var( rs_c ), rdi ) // load rs_c; rdi = rs_c lea( mem( , rdi, 4 ), rdi ) // rdi = rs_c *= sizeof(dt) => rs_c *= 4 lea( mem( , rdi, 2 ), rdx ) // rdx = rs_c * 2 lea( mem( rdx, rdi, 4 ), rdx ) // rdx = rdi * 4 => rdx = rs_c * 6 @@ -604,6 +611,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), [b] "m" (b), [rs_b] "m" (rs_b), [cs_b] "m" (cs_b), @@ -639,7 +647,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 const dim_t i_edge = m0 - ( dim_t )m_left; float* restrict cij = c + i_edge * rs_c; - float* restrict ai = a + i_edge * rs_a; + float* restrict ai = a + m_iter * ps_a; float* restrict bj = b; if ( 4 <= m_left ) @@ -658,7 +666,7 @@ void bli_sgemmsup_rv_zen_asm_6x64m_avx512 ai += mr_cur * rs_a; m_left -= mr_cur; } - + if ( 2 <= m_left ) { const dim_t mr_cur = 2; @@ -724,6 +732,10 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + float *abuf = a; float *bbuf = b; float *cbuf = c; @@ -914,7 +926,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 jne( .K_LEFT_LOOP ) // if rsi != 0, repeat k-loop - label(.SPOSTACCUM) + label(.SPOSTACCUM) // Scaling A * B with alpha. ALPHA_SCALE3( 7, 8, 9, 10 ) ALPHA_SCALE3( 7, 12, 13, 14 ) @@ -955,9 +967,9 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 * to get the transpose of 6x64 tile and are stored as 64x6 tile. */ mov( var( cbuf ), rcx ) // load address of c - mov( var( cs_c ), rdi ) // load rs_c - lea( mem( , rdi, 4 ), rdi ) // rdi = rs_c *= sizeof(dt) => rs_c *= 4 - lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * rs_c + mov( var( cs_c ), rdi ) // load cs_c; rdi = cs_c + lea( mem( , rdi, 4 ), rdi ) // rdi = cs_c*sizeof(dt) => rdi = cs_c*4 + lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * cs_c TRANSPOSE_4X16( 8, 12, 16, 20 ) lea( mem( rcx, r12, 4 ), rcx ) @@ -967,7 +979,10 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 lea( mem( rcx, r12, 4 ), rcx ) mov( var( cbuf ), rcx ) // load address of c - lea( mem( rcx, r10, 4 ), rcx ) + mov( var( rs_c ), r12 ) // load rs_c; r12 = rs_c + lea( mem( , r12, 4 ), r12 ) // r12 = rs_c*sizeof(dt) => r12 = rs_c*4 + lea( mem( rcx, r12, 4 ), rcx ) // rcx += 4 * r12 => rcx = 4 * rs_c + TRANSPOSE_2X16( 24, 28 ) lea( mem( rcx, rdi, 2 ), rcx ) TRANSPOSE_2X16( 25, 29 ) @@ -1005,9 +1020,9 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 * to get the transpose of 6x64 tile and are stored as 64x6 tile. 
*/ mov( var( cbuf ), rcx ) // load address of c - mov( var( cs_c ), rdi ) // load rs_c - lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) - lea( mem( rdi, rdi, 2 ), r12 ) + mov( var( cs_c ), rdi ) // load cs_c; rdi = cs_c + lea( mem( , rdi, 4 ), rdi ) // rdi = cs_c*sizeof(dt) => rdi = cs_c*4 + lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * cs_c /* Transposing 4x16 tiles to 16x4 tiles */ TRANSPOSE_4X16_BZ( 8, 12, 16, 20 ) @@ -1018,7 +1033,10 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 /* Transposing 2x16 tiles to 16x2 tiles */ mov( var( cbuf ), rcx ) // load address of c - lea( mem( rcx, r10, 4 ), rcx ) + mov( var( rs_c ), r12 ) // load rs_c; r12 = rs_c + lea( mem( , r12, 4 ), r12 ) // r12 = rs_c*sizeof(dt) => r12 = rs_c*4 + lea( mem( rcx, r12, 4 ), rcx ) // rcx += 4 * r12 => rcx = 4 * rs_c + TRANSPOSE_2X16_BZ( 24, 28 ) lea( mem( rcx, rdi, 2 ), rcx ) TRANSPOSE_2X16_BZ( 25, 29 ) @@ -1030,13 +1048,12 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 label( .SDONE ) - lea( mem( , r8, 2 ), rdx ) // rdx = rs_a * 2 - lea( mem( rdx, r8, 4 ), rdx ) // rdx = rs_a * 6 + mov( var( ps_a4 ), rdx ) // load panel stride of a mov( var( abuf ), rax ) // load address of a - add( rdx, rax ) // a += rs_a * 6(MR) + add( rdx, rax ) // a += ps_a4 mov( rax, var( abuf ) ) // store updated a - mov( var( rs_c ), rdi ) + mov( var( rs_c ), rdi ) // load rs_c; rdi = rs_c lea( mem( , rdi, 4 ), rdi ) // rdi = rs_c *= sizeof(dt) => rs_c *= 4 lea( mem( , rdi, 2 ), rdx ) // rdx = rs_c * 2 lea( mem( rdx, rdi, 4 ), rdx ) // rdx = rdi * 4 => rdx = rs_c * 6 @@ -1055,6 +1072,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), [b] "m" (b), [rs_b] "m" (rs_b), [cs_b] "m" (cs_b), @@ -1090,7 +1108,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 const dim_t i_edge = m0 - ( dim_t )m_left; float* restrict cij = c + i_edge*rs_c; - float* restrict ai = a + i_edge*rs_a; + float* restrict ai = a + m_iter * ps_a; float* restrict bj = b; if ( 4 <= m_left ) @@ -1109,7 +1127,7 @@ void bli_sgemmsup_rv_zen_asm_6x48m_avx512 ai += mr_cur * rs_a; m_left -= mr_cur; } - + if ( 2 <= m_left ) { const dim_t mr_cur = 2; @@ -1175,6 +1193,10 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + float *abuf = a; float *bbuf = b; float *cbuf = c; @@ -1327,10 +1349,10 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 label( .CONSID_K_LEFT ) - mov( var( k_left ), rsi ) // i = k_left; - test( rsi, rsi ) // check i via logical AND. - je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. + mov( var( k_left ), rsi ) // i = k_left; + test( rsi, rsi ) // check i via logical AND. + je( .SPOSTACCUM ) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. label( .K_LEFT_LOOP ) // Load 2 rows from B matrix. @@ -1397,9 +1419,9 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 * to get the transpose of 6x64 tile and are stored as 64x6 tile. 
*/ mov( var( cbuf ), rcx ) // load address of c - mov( var( cs_c ), rdi ) // load rs_c - lea( mem( , rdi, 4 ), rdi ) // rdi = rs_c *= sizeof(dt) => rs_c *= 4 - lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * rs_c + mov( var( cs_c ), rdi ) // load cs_c; rdi = cs_c + lea( mem( , rdi, 4 ), rdi ) // rdi = cs_c*sizeof(dt) => rdi = cs_c*4 + lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * cs_c /* Transposing 4x16 tiles to 16x4 tiles */ TRANSPOSE_4X16( 8, 12, 16, 20 ) @@ -1409,7 +1431,11 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 /* Transposing 2x16 tiles to 16x2 tiles */ mov( var( cbuf ), rcx ) // load address of c - lea( mem( rcx, r10, 4 ), rcx ) + mov( var( rs_c ), r12 ) // load rs_c; r12 = rs_c + lea( mem( , r12, 4 ), r12 ) // r12 = rs_c*sizeof(dt) => r12 = rs_c*4 + lea( mem( rcx, r12, 4 ), rcx ) // rcx += 4 * r12 => rcx = 4 * rs_c + + TRANSPOSE_2X16( 24, 28 ) lea( mem( rcx, rdi, 2 ), rcx ) TRANSPOSE_2X16( 25, 29 ) @@ -1445,9 +1471,9 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 */ /* Transposing 4x16 tiles to 16x4 tiles */ mov( var( cbuf ), rcx ) // load address of c - mov( var( cs_c ), rdi ) // load rs_c - lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) - lea( mem( rdi, rdi, 2 ), r12 ) + mov( var( cs_c ), rdi ) // load cs_c; rdi = cs_c + lea( mem( , rdi, 4 ), rdi ) // rdi = cs_c*sizeof(dt) => rdi = cs_c*4 + lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * cs_c TRANSPOSE_4X16_BZ( 8, 12, 16, 20 ) lea( mem( rcx, r12, 4 ), rcx ) @@ -1455,7 +1481,10 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 /* Transposing 2x16 tiles to 16x2 tiles */ mov( var( cbuf ), rcx ) // load address of c - lea( mem( rcx, r10, 4 ), rcx ) + mov( var( rs_c ), r12 ) // load rs_c; r12 = rs_c + lea( mem( , r12, 4 ), r12 ) // r12 = rs_c*sizeof(dt) => r12 = rs_c*4 + lea( mem( rcx, r12, 4 ), rcx ) // rcx += 4 * r12 => rcx = 4 * rs_c + TRANSPOSE_2X16_BZ( 24, 28 ) lea( mem( rcx, rdi, 2 ), rcx ) TRANSPOSE_2X16_BZ( 25, 29 ) @@ -1465,13 +1494,12 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 label( .SDONE ) - lea( mem( , r8, 2 ), rdx ) // rdx = rs_a * 2 - lea( mem( rdx, r8, 4 ), rdx ) // rdx = rs_a * 6 + mov( var( ps_a4 ), rdx ) // load panel stride of a mov( var( abuf ), rax ) // load address of a - add( rdx, rax ) // a += rs_a * 6(MR) + add( rdx, rax ) // a += ps_a4 mov( rax, var( abuf ) ) // store updated a - mov( var( rs_c ), rdi ) + mov( var( rs_c ), rdi ) // load rs_c; rdi = rs_c lea( mem( , rdi, 4 ), rdi ) // rdi = rs_c *= sizeof(dt) => rs_c *= 4 lea( mem( , rdi, 2 ), rdx ) // rdx = rs_c * 2 lea( mem( rdx, rdi, 4 ), rdx ) // rdx = rdi * 4 => rdx = rs_c * 6 @@ -1490,6 +1518,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), [b] "m" (b), [rs_b] "m" (rs_b), [cs_b] "m" (cs_b), @@ -1525,7 +1554,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 const dim_t i_edge = m0 - ( dim_t )m_left; float* restrict cij = c + i_edge*rs_c; - float* restrict ai = a + i_edge*rs_a; + float* restrict ai = a + m_iter * ps_a; float* restrict bj = b; if ( 4 <= m_left ) @@ -1544,7 +1573,7 @@ void bli_sgemmsup_rv_zen_asm_6x32m_avx512 ai += mr_cur * rs_a; m_left -= mr_cur; } - + if ( 2 <= m_left ) { const dim_t mr_cur = 2; @@ -1610,6 +1639,10 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + // Query the panel stride of A and convert it to units of bytes. 
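
The comment above, together with the ps_a/ps_a4 declarations that follow it in each m-kernel, replaces the old rs_a * 6(MR) address arithmetic: A is packed into panels, so the next micro-panel starts one panel stride away, and the m_left edge case must index A by whole panels rather than by rows. A minimal C sketch of that pointer arithmetic is given below; bli_auxinfo_ps_a() is the real BLIS accessor used in these hunks, while toy_m_loop() and its arguments are hypothetical names used only for illustration.

#include <stdint.h>
#include "blis.h"

/* Illustrative sketch only -- not the kernel driver. It shows how the packed
   panel stride of A (ps_a, in elements) is consumed: the asm advances abuf by
   ps_a4 bytes at .SDONE, and the C edge-case code addresses the remainder
   panel at a + m_iter * ps_a instead of a + i_edge * rs_a. */
static void toy_m_loop( auxinfo_t* data, float* a, float* c,
                        dim_t m0, dim_t mr, inc_t rs_c )
{
    uint64_t ps_a  = bli_auxinfo_ps_a( data );    // panel stride of A, in elements
    uint64_t ps_a4 = ps_a * sizeof( float );      // same stride in bytes (used by the asm)

    dim_t  m_iter = m0 / mr;
    dim_t  m_left = m0 % mr;
    float* ai     = a;
    float* cij    = c;

    for ( dim_t i = 0; i < m_iter; ++i )
    {
        /* ... one MR x NR micro-tile is computed here ... */
        ai   = ( float* )( ( char* )ai + ps_a4 ); // a += ps_a4 (as in .SDONE)
        cij += mr * rs_c;                         // c += rs_c * MR
    }

    if ( m_left > 0 )
    {
        float* ai_edge = a + m_iter * ps_a;       // remainder panel of packed A
        ( void )ai_edge;                          // handled by the smaller edge kernels
    }
}
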
+ uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a4 = ps_a * sizeof( float ); + float *abuf = a; float *bbuf = b; float *cbuf = c; @@ -1830,16 +1863,19 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 */ /* Transposing 4x16 tiles to 16x4 tiles */ mov( var( cbuf ), rcx ) // load address of c - mov( var( cs_c ), rdi ) // load rs_c - lea( mem( , rdi, 4 ), rdi ) // rdi = rs_c *= sizeof(dt) => rs_c *= 4 - lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * rs_c + mov( var( cs_c ), rdi ) // load cs_c; rdi = cs_c + lea( mem( , rdi, 4 ), rdi ) // rdi = cs_c*sizeof(dt) => rdi = cs_c*4 + lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * cs_c TRANSPOSE_4X16( 8, 12, 16, 20 ) lea( mem( rcx, r12, 4 ), rcx ) /* Transposing 2x16 tiles to 16x2 tiles */ mov( var( cbuf ), rcx ) // load address of c - lea( mem( rcx, r10, 4 ), rcx ) + mov( var( rs_c ), r12 ) // load rs_c; r12 = rs_c + lea( mem( , r12, 4 ), r12 ) // r12 = rs_c*sizeof(dt) => r12 = rs_c*4 + lea( mem( rcx, r12, 4 ), rcx ) // rcx += 4 * r12 => rcx = 4 * rs_c + TRANSPOSE_2X16( 24, 28 ) jmp( .SDONE ) // jump to the end @@ -1873,15 +1909,18 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 */ /* Transposing 4x16 tiles to 16x4 tiles */ mov( var( cbuf ), rcx ) // load address of c - mov( var( cs_c ), rdi ) // load rs_c - lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) - lea( mem( rdi, rdi, 2 ), r12 ) + mov( var( cs_c ), rdi ) // load cs_c; rdi = cs_c + lea( mem( , rdi, 4 ), rdi ) // rdi = cs_c*sizeof(dt) => rdi = cs_c*4 + lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * cs_c TRANSPOSE_4X16_BZ( 8, 12, 16, 20 ) /* Transposing 2x16 tiles to 16x2 tiles */ mov( var( cbuf ), rcx ) // load address of c - lea( mem( rcx, r10, 4 ), rcx ) + mov( var( rs_c ), r12 ) // load rs_c; r12 = rs_c + lea( mem( , r12, 4 ), r12 ) // r12 = rs_c*sizeof(dt) => r12 = rs_c*4 + lea( mem( rcx, r12, 4 ), rcx ) // rcx += 4 * r12 => rcx = 4 * rs_c + TRANSPOSE_2X16_BZ( 24, 28 ) jmp( .SDONE ) // jump to the end @@ -1889,13 +1928,12 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 label( .SDONE ) - lea( mem( , r8, 2 ), rdx ) // rdx = rs_a * 2 - lea( mem( rdx, r8, 4 ), rdx ) // rdx = rs_a * 6 + mov( var( ps_a4 ), rdx ) // load panel stride of a mov( var( abuf ), rax ) // load address of a - add( rdx, rax ) // a += rs_a * 6(MR) + add( rdx, rax ) // a += ps_a4 mov( rax, var( abuf ) ) // store updated a - mov( var( rs_c ), rdi ) + mov( var( rs_c ), rdi ) // load rs_c; rdi = rs_c lea( mem( , rdi, 4 ), rdi ) // rdi = rs_c *= sizeof(dt) => rs_c *= 4 lea( mem( , rdi, 2 ), rdx ) // rdx = rs_c * 2 lea( mem( rdx, rdi, 4 ), rdx ) // rdx = rdi * 4 => rdx = rs_c * 6 @@ -1914,6 +1952,7 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 [a] "m" (a), [rs_a] "m" (rs_a), [cs_a] "m" (cs_a), + [ps_a4] "m" (ps_a4), [b] "m" (b), [rs_b] "m" (rs_b), [cs_b] "m" (cs_b), @@ -1949,7 +1988,7 @@ void bli_sgemmsup_rv_zen_asm_6x16m_avx512 const dim_t i_edge = m0 - ( dim_t )m_left; float* restrict cij = c + i_edge*rs_c; - float* restrict ai = a + i_edge*rs_a; + float* restrict ai = a + m_iter*ps_a; float* restrict bj = b; if ( 4 <= m_left ) @@ -2275,7 +2314,7 @@ void bli_sgemmsup_rv_zen_asm_4x64m_avx512 lea( mem( rcx, r12, 4 ), rcx ) TRANSPOSE_4X16_BZ( 10, 14, 18, 22 ) lea( mem( rcx, r12, 4 ), rcx ) - TRANSPOSE_4X16_BZ( 11, 15, 19, 23 ) + TRANSPOSE_4X16_BZ( 11, 15, 19, 23 ) jmp( .SDONE ) // jump to the end diff --git a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64n.c b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64n.c index e4ce3d1490..08204eef20 100644 --- 
a/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64n.c +++ b/kernels/zen4/3/sup/bli_gemmsup_rv_zen_s6x64n.c @@ -41,21 +41,18 @@ /* rrr: - -------- ------ -------- - -------- ------ -------- - -------- += ------ ... -------- - -------- ------ -------- - -------- ------ : - -------- ------ : + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : Assumptions: - B is row-stored; - A is row-stored; - m0 and n0 are at most MR (6) and NR (64), respectively. Therefore, this (r)ow-preferential kernel is well-suited for contiguous (v)ector loads on B and single-element broadcasts from A. - - NOTE: These kernels currently do not have in-register transpose - implemented and hence they do not support column-oriented IO. */ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 ( @@ -139,6 +136,10 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + float *abuf = a; float *bbuf = b; float *cbuf = c; @@ -360,9 +361,9 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 */ /* Transposing 4x16 tiles to 16x4 tiles */ mov( var( cbuf ), rcx ) // load address of c - mov( var( cs_c ), rdi ) // load rs_c - lea( mem( , rdi, 4 ), rdi ) // rdi = rs_c *= sizeof(dt) => rs_c *= 4 - lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * rs_c + mov( var( cs_c ), rdi ) // load cs_c; rdi = cs_c + lea( mem( , rdi, 4 ), rdi ) // rdi = cs_c*sizeof(dt) => rdi = cs_c*4 + lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * cs_c TRANSPOSE_4X16( 8, 12, 16, 20 ) lea( mem( rcx, r12, 4 ), rcx ) @@ -415,9 +416,9 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 */ /* Transposing 4x16 tiles to 16x4 tiles */ mov( var( cbuf ), rcx ) // load address of c - mov( var( cs_c ), rdi ) // load rs_c - lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) - lea( mem( rdi, rdi, 2 ), r12 ) + mov( var( cs_c ), rdi ) // load cs_c; rdi = cs_c + lea( mem( , rdi, 4 ), rdi ) // rdi = cs_c*sizeof(dt) => rdi = cs_c*4 + lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * cs_c TRANSPOSE_4X16_BZ( 8, 12, 16, 20 ) lea( mem( rcx, r12, 4 ), rcx ) @@ -425,7 +426,7 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 lea( mem( rcx, r12, 4 ), rcx ) TRANSPOSE_4X16_BZ( 10, 14, 18, 22 ) lea( mem( rcx, r12, 4 ), rcx ) - TRANSPOSE_4X16_BZ( 11, 15, 19, 23 ) + TRANSPOSE_4X16_BZ( 11, 15, 19, 23 ) /* Transposing 2x16 tiles to 16x2 tiles */ mov( var( cbuf ), rcx ) // load address of c @@ -438,26 +439,24 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 lea( mem( rcx, rdi, 2 ), rcx ) TRANSPOSE_2X16_BZ( 27, 31 ) - jmp( .SDONE ) // jump to the end + jmp( .SDONE ) // jump to the end label( .SDONE ) - mov( var( cs_b ), rdx ) - lea( mem( , rdx, 4 ), rdx ) - lea( mem( , rdx, 8 ), rdx ) // rdx = cs_b * 8 - lea( mem( , rdx, 8 ), rdx ) // rdx += cs_b * 8 => rdx = cs_b * 16 - mov( var( bbuf ), rbx ) - add( rdx, rbx ) + mov( var( ps_b4 ), rdx ) // load ps_b4; rdx = ps_b4 + mov( var( bbuf ), rbx ) // load b + add( rdx, rbx ) // b += ps_b4 mov( rbx, var( bbuf ) ) - mov( var( cs_c ), rdx ) - lea( mem( , rdx, 4 ), rdx ) - lea( mem( , rdx, 8 ), rdx ) // rdx = cs_c * 8 - lea( mem( , rdx, 8 ), rdx ) // rdx = rdx * 8 = cs_c * 8 * 8 => rdx = cs_c * 64 - mov( var( cbuf ), rcx ) // load address of c - add( rdx, rcx ) // c += rs_c * MR - mov( rcx, var( cbuf ) ) // store updated c + mov( var( cs_c ), rdx ) // load cs_c; rdx = 
cs_c + lea( mem( , rdx, 4 ), rdx ) // rdx = cs_c*sizeof(dt) => rdx = cs_c*4 + lea( mem( , rdx, 8 ), rdx ) // rdx = cs_c * 8 + lea( mem( , rdx, 8 ), rdx ) // rdx = rdx * 8 = cs_c * 8 * 8 + // => rdx = cs_c * 64 + mov( var( cbuf ), rcx ) // load address of c + add( rdx, rcx ) // c += rs_c * MR + mov( rcx, var( cbuf ) ) // store updated c dec( r11 ) jne( .N_LOOP_ITER ) @@ -473,6 +472,7 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 [b] "m" (b), [rs_b] "m" (rs_b), [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), [alpha] "m" (alpha), [beta] "m" (beta), [c] "m" (c), @@ -523,7 +523,7 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -539,7 +539,7 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -555,7 +555,7 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -571,7 +571,7 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -637,7 +637,7 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) ? mr : m_left ); - bli_sgemv_ex + bli_sgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, alpha, ai_ii, rs_a0, cs_a0, bj, rs_b0, @@ -645,7 +645,7 @@ void bli_sgemmsup_rv_zen_asm_6x64n_avx512 ); cij_ii += mr_cur*rs_c0; ai_ii += ps_a0; - } + } } n_left -= nr_cur; } @@ -681,6 +681,10 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + float *abuf = a; float *bbuf = b; float *cbuf = c; @@ -741,7 +745,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 VFMA4( 4, 20, 21, 22, 23 ) vbroadcastss( mem( rax, r8, 4 ), zmm5 ) VFMA4( 5, 24, 25, 26, 27 ) - + add( r9, rbx ) add( r10, rax ) @@ -763,7 +767,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 VFMA4( 4, 20, 21, 22, 23 ) vbroadcastss( mem( rax, r8, 4 ), zmm5 ) VFMA4( 5, 24, 25, 26, 27 ) - + add( r9, rbx ) add( r10, rax ) @@ -785,7 +789,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 VFMA4( 4, 20, 21, 22, 23 ) vbroadcastss( mem( rax, r8, 4 ), zmm5 ) VFMA4( 5, 24, 25, 26, 27 ) - + add( r9, rbx ) add( r10, rax ) @@ -807,7 +811,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 VFMA4( 4, 20, 21, 22, 23 ) vbroadcastss( mem( rax, r8, 4 ), zmm5 ) VFMA4( 5, 24, 25, 26, 27 ) - + add( r9, rbx ) add( r10, rax ) @@ -842,7 +846,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 VFMA4( 4, 20, 21, 22, 23 ) vbroadcastss( mem( rax, r8, 4 ), zmm5 ) VFMA4( 5, 24, 25, 26, 27 ) - + add( r9, rbx ) add( r10, rax ) dec( rsi ) @@ -958,7 +962,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 lea( mem( rcx, r12, 4 ), rcx ) TRANSPOSE_4X16_BZ( 10, 14, 18, 22 ) lea( mem( rcx, r12, 4 ), rcx ) - TRANSPOSE_4X16_BZ( 11, 15, 19, 23 ) + TRANSPOSE_4X16_BZ( 11, 15, 19, 23 ) /* Transposing 1x16 tiles to 16x1 tiles */ mov( var( cbuf ), rcx ) // load address of c @@ -979,12 +983,9 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 label( .SDONE ) - mov( var( cs_b ), rdx ) - lea( mem( , rdx, 4 ), rdx ) - lea( mem( , rdx, 8 ), rdx ) // rdx = cs_b * 8 - lea( mem( , rdx, 8 ), rdx ) // rdx += cs_b * 8 => rdx = cs_b * 16 - mov( var( bbuf ), rbx ) - add( rdx, rbx ) + mov( var( ps_b4 ), rdx ) // load ps_b4 + mov( var( bbuf ), rbx ) // load 
b + add( rdx, rbx ) // b += ps_b4 mov( rbx, var( bbuf ) ) mov( var( cs_c ), rdx ) @@ -1009,6 +1010,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 [b] "m" (b), [rs_b] "m" (rs_b), [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), [alpha] "m" (alpha), [beta] "m" (beta), [c] "m" (c), @@ -1059,7 +1061,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -1075,7 +1077,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -1091,7 +1093,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -1107,7 +1109,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -1173,7 +1175,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) ? mr : m_left ); - bli_sgemv_ex + bli_sgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, alpha, ai_ii, rs_a0, cs_a0, bj, rs_b0, @@ -1181,7 +1183,7 @@ void bli_sgemmsup_rv_zen_asm_5x64n_avx512 ); cij_ii += mr_cur*rs_c0; ai_ii += ps_a0; - } + } } n_left -= nr_cur; } @@ -1217,6 +1219,10 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + float *abuf = a; float *bbuf = b; float *cbuf = c; @@ -1466,28 +1472,26 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 lea( mem( rcx, r12, 4 ), rcx ) TRANSPOSE_4X16_BZ( 10, 14, 18, 22 ) lea( mem( rcx, r12, 4 ), rcx ) - TRANSPOSE_4X16_BZ( 11, 15, 19, 23 ) + TRANSPOSE_4X16_BZ( 11, 15, 19, 23 ) - jmp( .SDONE ) // jump to the end + jmp( .SDONE ) // jump to the end label( .SDONE ) - mov( var( cs_b ), rdx ) - lea( mem( , rdx, 4 ), rdx ) - lea( mem( , rdx, 8 ), rdx ) // rdx = cs_b * 8 - lea( mem( , rdx, 8 ), rdx ) // rdx += cs_b * 8 => rdx = cs_b * 16 - mov( var( bbuf ), rbx ) - add( rdx, rbx ) + mov( var( ps_b4 ), rdx ) // load ps_b4; rdx = ps_b4 + mov( var( bbuf ), rbx ) // load b + add( rdx, rbx ) // b += ps_b4 mov( rbx, var( bbuf ) ) - mov( var( cs_c ), rdx ) - lea( mem( , rdx, 4 ), rdx ) - lea( mem( , rdx, 8 ), rdx ) // rdx = cs_c * 8 - lea( mem( , rdx, 8 ), rdx ) // rdx = rdx * 8 = cs_c * 8 * 8 => rdx = cs_c * 64 - mov( var( cbuf ), rcx ) // load address of c - add( rdx, rcx ) // c += rs_c * MR - mov( rcx, var( cbuf ) ) // store updated c + mov( var( cs_c ), rdx ) // load cs_c; rdx = cs_c + lea( mem( , rdx, 4 ), rdx ) // rdx = cs_c*sizeof(dt) => rdx = cs_c*4 + lea( mem( , rdx, 8 ), rdx ) // rdx = cs_c * 8 + lea( mem( , rdx, 8 ), rdx ) // rdx = rdx * 8 = cs_c * 8 * 8 + // => rdx = cs_c * 64 + mov( var( cbuf ), rcx ) // load address of c + add( rdx, rcx ) // c += rs_c * MR + mov( rcx, var( cbuf ) ) // store updated c dec( r11 ) jne( .N_LOOP_ITER ) @@ -1503,6 +1507,7 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 [b] "m" (b), [rs_b] "m" (rs_b), [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), [alpha] "m" (alpha), [beta] "m" (beta), [c] "m" (c), @@ -1553,7 +1558,7 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -1569,7 +1574,7 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } 
@@ -1585,7 +1590,7 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -1601,7 +1606,7 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -1667,7 +1672,7 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) ? mr : m_left ); - bli_sgemv_ex + bli_sgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, alpha, ai_ii, rs_a0, cs_a0, bj, rs_b0, @@ -1675,7 +1680,7 @@ void bli_sgemmsup_rv_zen_asm_4x64n_avx512 ); cij_ii += mr_cur*rs_c0; ai_ii += ps_a0; - } + } } n_left -= nr_cur; } @@ -1711,6 +1716,10 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + float *abuf = a; float *bbuf = b; float *cbuf = c; @@ -1937,9 +1946,9 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 /* Transposing 2x16 tiles to 16x2 tiles */ mov( var( cbuf ), rcx ) // load address of c - mov( var( cs_c ), rdi ) // load rs_c - lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) - lea( mem( rdi, rdi, 2 ), r12 ) + mov( var( cs_c ), rdi ) // load cs_c; rdi = cs_c + lea( mem( , rdi, 4 ), rdi ) // rdi = cs_c*sizeof(dt) => rdi = cs_c*4 + lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * cs_c TRANSPOSE_2X16_BZ( 8, 12 ) lea( mem( rcx, rdi, 2 ), rcx ) @@ -1950,39 +1959,37 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 TRANSPOSE_2X16_BZ( 11, 15 ) /* Transposing 1x16 tiles to 16x1 tiles */ - mov( var( cbuf ), rcx ) - mov( var( rs_c ), rdi ) - lea( mem( , rdi, 4 ), rdi ) - lea( mem( rcx, rdi, 2 ), rcx ) - mov( var( cs_c ), rdi ) // load rs_c - lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) - lea( mem( rdi, rdi, 2 ), r12 ) + mov( var( cbuf ), rcx ) // load address of c + mov( var( rs_c ), rdi ) // load rs_c; rdi = rs_c + lea( mem( , rdi, 4 ), rdi ) // rdi = rs_c*sizeof(dt) => rdi = rs_c*4 + lea( mem( rcx, rdi, 2 ), rcx ) // c += rdi * 2 + mov( var( cs_c ), rdi ) // load cs_c; rdi = cs_c + lea( mem( , rdi, 4 ), rdi ) // rdi = cs_c*sizeof(dt) => rdi = cs_c*4 + lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * cs_c UPDATE_C_1X16_BZ( 16 ) UPDATE_C_1X16_BZ( 17 ) UPDATE_C_1X16_BZ( 18 ) UPDATE_C_1X16_BZ( 19 ) - jmp( .SDONE ) // jump to the end + jmp( .SDONE ) // jump to the end label( .SDONE ) - mov( var( cs_b ), rdx ) - lea( mem( , rdx, 4 ), rdx ) - lea( mem( , rdx, 8 ), rdx ) // rdx = cs_b * 8 - lea( mem( , rdx, 8 ), rdx ) // rdx += cs_b * 8 => rdx = cs_b * 16 - mov( var( bbuf ), rbx ) - add( rdx, rbx ) + mov( var( ps_b4 ), rdx ) // load ps_b4 + mov( var( bbuf ), rbx ) // load b + add( rdx, rbx ) // b += ps_b4 mov( rbx, var( bbuf ) ) - mov( var( cs_c ), rdx ) - lea( mem( , rdx, 4 ), rdx ) - lea( mem( , rdx, 8 ), rdx ) // rdx = cs_c * 8 - lea( mem( , rdx, 8 ), rdx ) // rdx = rdx * 8 = cs_c * 8 * 8 => rdx = cs_c * 64 - mov( var( cbuf ), rcx ) // load address of c - add( rdx, rcx ) // c += rs_c * MR - mov( rcx, var( cbuf ) ) // store updated c + mov( var( cs_c ), rdx ) // load cs_c; rdx = cs_c + lea( mem( , rdx, 4 ), rdx ) // rdx = cs_c*sizeof(dt) => rdx = cs_c*4 + lea( mem( , rdx, 8 ), rdx ) // rdx = cs_c * 8 + lea( mem( , rdx, 8 ), rdx ) // rdx = rdx * 8 = cs_c * 8 * 8 + // => rdx = cs_c * 64 + mov( var( cbuf ), rcx ) // load address of c + add( rdx, rcx ) // c += rs_c * MR + mov( rcx, var( cbuf ) ) // 
store updated c dec( r11 ) jne( .N_LOOP_ITER ) @@ -1998,6 +2005,7 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 [b] "m" (b), [rs_b] "m" (rs_b), [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), [alpha] "m" (alpha), [beta] "m" (beta), [c] "m" (c), @@ -2048,7 +2056,7 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -2064,7 +2072,7 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -2080,7 +2088,7 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -2096,7 +2104,7 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -2162,7 +2170,7 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) ? mr : m_left ); - bli_sgemv_ex + bli_sgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, alpha, ai_ii, rs_a0, cs_a0, bj, rs_b0, @@ -2170,7 +2178,7 @@ void bli_sgemmsup_rv_zen_asm_3x64n_avx512 ); cij_ii += mr_cur*rs_c0; ai_ii += ps_a0; - } + } } n_left -= nr_cur; } @@ -2206,6 +2214,10 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + // Query the panel stride of B and convert it to units of bytes. + uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + float *abuf = a; float *bbuf = b; float *cbuf = c; @@ -2422,12 +2434,9 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 label( .SDONE ) - mov( var( cs_b ), rdx ) - lea( mem( , rdx, 4 ), rdx ) - lea( mem( , rdx, 8 ), rdx ) // rdx = cs_b * 8 - lea( mem( , rdx, 8 ), rdx ) // rdx += cs_b * 8 => rdx = cs_b * 16 - mov( var( bbuf ), rbx ) - add( rdx, rbx ) + mov( var( ps_b4 ), rdx ) // load ps_b4 + mov( var( bbuf ), rbx ) // load b + add( rdx, rbx ) // b += ps_b4 mov( rbx, var( bbuf ) ) mov( var( cs_c ), rdx ) @@ -2452,6 +2461,7 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 [b] "m" (b), [rs_b] "m" (rs_b), [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), [alpha] "m" (alpha), [beta] "m" (beta), [c] "m" (c), @@ -2502,7 +2512,7 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -2518,7 +2528,7 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -2534,7 +2544,7 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -2550,7 +2560,7 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -2616,7 +2626,7 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) ? mr : m_left ); - bli_sgemv_ex + bli_sgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, alpha, ai_ii, rs_a0, cs_a0, bj, rs_b0, @@ -2624,7 +2634,7 @@ void bli_sgemmsup_rv_zen_asm_2x64n_avx512 ); cij_ii += mr_cur*rs_c0; ai_ii += ps_a0; - } + } } n_left -= nr_cur; } @@ -2660,6 +2670,10 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + // Query the panel stride of B and convert it to units of bytes. 
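
The n-dimension kernels (the 6x64n family) make the matching change for B: instead of rebuilding the byte offset from cs_b with a chain of lea shifts, .SDONE now loads the pre-computed ps_b4, and the C pointer advances by NR = 64 columns per iteration of .N_LOOP_ITER. A minimal sketch of those two updates, expressed in C, follows; bli_auxinfo_ps_b() is the real accessor used in these hunks, while toy_n_step() is a hypothetical name used only for illustration.

#include <stdint.h>
#include "blis.h"

/* Illustrative sketch only: the pointer updates performed at .SDONE of the
   6x64n kernels, written in C. NR is 64 for this kernel family. */
static void toy_n_step( auxinfo_t* data, float** bbuf, float** cbuf, inc_t cs_c )
{
    uint64_t ps_b  = bli_auxinfo_ps_b( data );      // panel stride of B, in elements
    uint64_t ps_b4 = ps_b * sizeof( float );        // in bytes, as passed to the asm

    *bbuf = ( float* )( ( char* )*bbuf + ps_b4 );   // b += ps_b4
    *cbuf = *cbuf + 64 * cs_c;                      // c += cs_c * NR (NR = 64)
}
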
+ uint64_t ps_b = bli_auxinfo_ps_b( data ); + uint64_t ps_b4 = ps_b * sizeof( float ); + float *abuf = a; float *bbuf = b; float *cbuf = c; @@ -2806,7 +2820,7 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 label( .SROWSTORED ) UPDATE_C4( 4, 8, 9, 10, 11 ) - + jmp( .SDONE ) // jump to the end @@ -2814,9 +2828,9 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 /* Transposing 1x16 tiles to 16x1 tiles */ mov( var( cbuf ), rcx ) // load address of c - mov( var( cs_c ), rdi ) // load rs_c - lea( mem( , rdi, 4 ), rdi ) // rdi = rs_c *= sizeof(dt) => rs_c *= 4 - lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * rs_c + mov( var( cs_c ), rdi ) // load cs_c; rdi = cs_c + lea( mem( , rdi, 4 ), rdi ) // rdi = cs_c*sizeof(dt) => rdi = cs_c*4 + lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * cs_c UPDATE_C_1X16( 8 ) UPDATE_C_1X16( 9 ) @@ -2843,9 +2857,9 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 /* Transposing 2x16 tiles to 16x2 tiles */ mov( var( cbuf ), rcx ) // load address of c - mov( var( cs_c ), rdi ) // load rs_c - lea( mem( , rdi, 4 ), rdi ) // rs_c *= sizeof(float) - lea( mem( rdi, rdi, 2 ), r12 ) + mov( var( cs_c ), rdi ) // load cs_c; rdi = cs_c + lea( mem( , rdi, 4 ), rdi ) // rdi = cs_c*sizeof(dt) => rdi = cs_c*4 + lea( mem( rdi, rdi, 2 ), r12 ) // rdi += rdi * 2 => rdi = 3 * cs_c UPDATE_C_1X16_BZ( 8 ) UPDATE_C_1X16_BZ( 9 ) @@ -2857,21 +2871,19 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 label( .SDONE ) - mov( var( cs_b ), rdx ) - lea( mem( , rdx, 4 ), rdx ) - lea( mem( , rdx, 8 ), rdx ) // rdx = cs_b * 8 - lea( mem( , rdx, 8 ), rdx ) // rdx += cs_b * 8 => rdx = cs_b * 16 - mov( var( bbuf ), rbx ) - add( rdx, rbx ) + mov( var( ps_b4 ), rdx ) // load ps_b4 + mov( var( bbuf ), rbx ) // load b + add( rdx, rbx ) // b += ps_b4 mov( rbx, var( bbuf ) ) - mov( var( cs_c ), rdx ) - lea( mem( , rdx, 4 ), rdx ) - lea( mem( , rdx, 8 ), rdx ) // rdx = cs_c * 8 - lea( mem( , rdx, 8 ), rdx ) // rdx = rdx * 8 = cs_c * 8 * 8 => rdx = cs_c * 64 - mov( var( cbuf ), rcx ) // load address of c - add( rdx, rcx ) // c += rs_c * MR - mov( rcx, var( cbuf ) ) // store updated c + mov( var( cs_c ), rdx ) // load cs_c; rdx = cs_c + lea( mem( , rdx, 4 ), rdx ) // rdx = cs_c*sizeof(dt) => rdx = cs_c*4 + lea( mem( , rdx, 8 ), rdx ) // rdx = cs_c * 8 + lea( mem( , rdx, 8 ), rdx ) // rdx = rdx * 8 = cs_c * 8 * 8 + // => rdx = cs_c * 64 + mov( var( cbuf ), rcx ) // load address of c + add( rdx, rcx ) // c += rs_c * MR + mov( rcx, var( cbuf ) ) // store updated c dec( r11 ) jne( .N_LOOP_ITER ) @@ -2887,6 +2899,7 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 [b] "m" (b), [rs_b] "m" (rs_b), [cs_b] "m" (cs_b), + [ps_b4] "m" (ps_b4), [alpha] "m" (alpha), [beta] "m" (beta), [c] "m" (c), @@ -2937,7 +2950,7 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -2953,7 +2966,7 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -2969,7 +2982,7 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -2985,7 +2998,7 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 data,cntx ); cij += nr_cur*cs_c0; - bj += nr_cur*cs_b0; + bj += nr_cur*cs_b0; n_left -= nr_cur; } @@ -3051,7 +3064,7 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 dim_t mr_cur = ( bli_is_not_edge_f( ii, m_iter, m_left ) ? 
mr : m_left ); - bli_sgemv_ex + bli_sgemv_ex ( BLIS_NO_TRANSPOSE, conjb, mr_cur, k0, alpha, ai_ii, rs_a0, cs_a0, bj, rs_b0, @@ -3059,7 +3072,7 @@ void bli_sgemmsup_rv_zen_asm_1x64n_avx512 ); cij_ii += mr_cur*rs_c0; ai_ii += ps_a0; - } + } } n_left -= nr_cur; } diff --git a/kernels/zen4/3/sup/d24x8/CMakeLists.txt b/kernels/zen4/3/sup/d24x8/CMakeLists.txt deleted file mode 100644 index 004a07c085..0000000000 --- a/kernels/zen4/3/sup/d24x8/CMakeLists.txt +++ /dev/null @@ -1,18 +0,0 @@ -##Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved.## - -add_library(zen4_3supd24x8 - OBJECT -${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemmsup_rv_zen4_asm_Mx1.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemmsup_rv_zen4_asm_Mx2.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemmsup_rv_zen4_asm_Mx3.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemmsup_rv_zen4_asm_Mx4.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemmsup_rv_zen4_asm_Mx5.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemmsup_rv_zen4_asm_Mx6.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemmsup_rv_zen4_asm_Mx7.c -${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemmsup_rv_zen4_asm_Mx8.c - ) - -target_compile_options(zen4_3supd24x8 PRIVATE /arch:AVX2 /arch:AVX512) -if(BUILD_SHARED_LIBS) - target_compile_definitions(zen4_3supd24x8 PUBLIC -DBLIS_IS_BUILDING_LIBRARY) -endif() diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx1.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx1.c index d8806362e8..690404628e 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx1.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx1.c @@ -472,6 +472,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1 vxorpd(zmm6, zmm6, zmm6) vxorpd(zmm7, zmm7, zmm7) vxorpd(zmm28, zmm28, zmm28) + vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) + vxorpd(zmm10, zmm10, zmm10) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -481,6 +484,20 @@ void bli_dgemmsup_rv_zen4_asm_24x1 sub(imm( 1+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7, zmm28, + * to hold fma result. + * While even iterations uses zmm8, zmm9, zmm10 to hold fma + * result. 
+ * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7, zmm28, + */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -508,9 +525,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) + vfmadd231pd( zmm5,zmm30,zmm10 ) // ---------------------------------- iteration 3 @@ -532,9 +549,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) + vfmadd231pd( zmm5,zmm30,zmm10 ) // ---------------------------------- iteration 5 @@ -556,9 +573,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) + vfmadd231pd( zmm5,zmm30,zmm10 ) // ---------------------------------- iteration 7 @@ -576,9 +593,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1 vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) + vfmadd231pd( zmm5,zmm30,zmm10 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP1) // iterate again if i != 0. 
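
The comment block above states the intent behind the register renaming in these hunks: with only one or two output columns there are too few independent accumulators to hide FMA latency, so even and odd K iterations write to separate register banks, which are summed once after the K loop by the vaddpd instructions added before .TAIL. Below is a minimal AVX-512 intrinsics sketch of the same idea; it is not the BLIS kernel, and toy_split_accum() and its arguments are hypothetical names used only to illustrate the two-bank accumulation.

#include <immintrin.h>
#include <stdint.h>

/* Illustrative only: accumulate c[0:8] += sum_k a[:,k] * b[k] using two
   accumulator banks so that consecutive FMA updates do not depend on the
   previous iteration's result. The banks are summed once after the loop,
   mirroring the vaddpd instructions added before .TAIL. */
static void toy_split_accum( const double *a, int64_t cs_a,
                             const double *b, int64_t rs_b,
                             int64_t k, double *c )
{
    __m512d acc_odd  = _mm512_setzero_pd();   // plays the role of zmm6
    __m512d acc_even = _mm512_setzero_pd();   // plays the role of zmm8

    int64_t i = 0;
    for ( ; i + 1 < k; i += 2 )
    {
        __m512d a0 = _mm512_loadu_pd( a + ( i     ) * cs_a );
        __m512d a1 = _mm512_loadu_pd( a + ( i + 1 ) * cs_a );
        __m512d b0 = _mm512_set1_pd ( b[ ( i     ) * rs_b ] );
        __m512d b1 = _mm512_set1_pd ( b[ ( i + 1 ) * rs_b ] );

        acc_odd  = _mm512_fmadd_pd( a0, b0, acc_odd  );  // bank 1 (odd iterations)
        acc_even = _mm512_fmadd_pd( a1, b1, acc_even );  // bank 2 (even iterations)
    }
    for ( ; i < k; ++i )                                  // leftover iteration, if k is odd
    {
        __m512d a0 = _mm512_loadu_pd( a + i * cs_a );
        __m512d b0 = _mm512_set1_pd ( b[ i * rs_b ] );
        acc_odd = _mm512_fmadd_pd( a0, b0, acc_odd );
    }

    /* Reduce the two banks and update C once. */
    __m512d acc = _mm512_add_pd( acc_odd, acc_even );
    _mm512_storeu_pd( c, _mm512_add_pd( _mm512_loadu_pd( c ), acc ) );
}

Splitting the accumulator this way roughly halves the loop-carried FMA dependency chain, which is the rationale the kernel comments describe for these narrow-N edge kernels.
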
@@ -614,9 +631,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) + vfmadd231pd( zmm5,zmm30,zmm10 ) // ---------------------------------- iteration 3 prefetchw0( mem(rdx, 128)) // prefetch C @@ -637,9 +654,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) + vfmadd231pd( zmm5,zmm30,zmm10 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -659,9 +676,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) + vfmadd231pd( zmm5,zmm30,zmm10 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -677,9 +694,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) + vfmadd231pd( zmm5,zmm30,zmm10 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b sub(imm(1), rsi) // i -= 1 @@ -713,9 +730,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) + vfmadd231pd( zmm5,zmm30,zmm10 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -735,9 +752,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) + vfmadd231pd( zmm5,zmm30,zmm10 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -757,9 +774,9 @@ void bli_dgemmsup_rv_zen4_asm_24x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) + vfmadd231pd( zmm5,zmm30,zmm10 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -775,13 +792,16 @@ void bli_dgemmsup_rv_zen4_asm_24x1 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) + vfmadd231pd( zmm5,zmm30,zmm10 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. 
+ vaddpd(zmm8, zmm6, zmm6) + vaddpd(zmm9, zmm7, zmm7) + vaddpd(zmm10, zmm28, zmm28) label(.TAIL) mov(var(k_left), rsi) // i = k_left @@ -1168,6 +1188,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1 // zero out all accumulation registers vxorpd(zmm6, zmm6, zmm6) vxorpd(zmm7, zmm7, zmm7) + vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -1177,6 +1199,19 @@ void bli_dgemmsup_rv_zen4_asm_16x1 sub(imm( 1+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7 + * to hold fma result. + * While even iterations uses zmm8, zmm9 to hold fma result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7. + */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -1200,8 +1235,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) // ---------------------------------- iteration 3 @@ -1220,8 +1255,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) // ---------------------------------- iteration 5 @@ -1240,8 +1275,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) // ---------------------------------- iteration 7 @@ -1257,8 +1292,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1 vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP1) // iterate again if i != 0. 
@@ -1290,8 +1325,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -1308,8 +1343,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -1326,8 +1361,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -1341,8 +1376,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b sub(imm(1), rsi) // i -= 1 @@ -1372,8 +1407,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -1390,8 +1425,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -1408,8 +1443,8 @@ void bli_dgemmsup_rv_zen4_asm_16x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -1423,12 +1458,14 @@ void bli_dgemmsup_rv_zen4_asm_16x1 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm8 ) + vfmadd231pd( zmm4,zmm30,zmm9 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. 
+ vaddpd(zmm8, zmm6, zmm6) + vaddpd(zmm9, zmm7, zmm7) label(.TAIL) mov(var(k_left), rsi) // i = k_left @@ -1783,6 +1820,7 @@ void bli_dgemmsup_rv_zen4_asm_8x1 // zero out all accumulation registers vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -1792,6 +1830,19 @@ void bli_dgemmsup_rv_zen4_asm_8x1 sub(imm( 1+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6 + * to hold fma result. + * While even iterations uses zmm7 to hold fma result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6. + */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -1811,7 +1862,7 @@ void bli_dgemmsup_rv_zen4_asm_8x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) // ---------------------------------- iteration 3 @@ -1827,7 +1878,7 @@ void bli_dgemmsup_rv_zen4_asm_8x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) // ---------------------------------- iteration 5 @@ -1843,7 +1894,7 @@ void bli_dgemmsup_rv_zen4_asm_8x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) // ---------------------------------- iteration 7 @@ -1857,7 +1908,7 @@ void bli_dgemmsup_rv_zen4_asm_8x1 vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP1) // iterate again if i != 0. 
@@ -1884,7 +1935,7 @@ void bli_dgemmsup_rv_zen4_asm_8x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -1898,7 +1949,7 @@ void bli_dgemmsup_rv_zen4_asm_8x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -1912,7 +1963,7 @@ void bli_dgemmsup_rv_zen4_asm_8x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -1924,7 +1975,7 @@ void bli_dgemmsup_rv_zen4_asm_8x1 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b sub(imm(1), rsi) // i -= 1 @@ -1950,7 +2001,7 @@ void bli_dgemmsup_rv_zen4_asm_8x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -1964,7 +2015,7 @@ void bli_dgemmsup_rv_zen4_asm_8x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -1978,7 +2029,7 @@ void bli_dgemmsup_rv_zen4_asm_8x1 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -1990,11 +2041,12 @@ void bli_dgemmsup_rv_zen4_asm_8x1 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. 
+ vaddpd(zmm7, zmm6, zmm6) label(.TAIL) mov(var(k_left), rsi) // i = k_left diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx2.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx2.c index d8b5c73ad8..67a58c1b82 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx2.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx2.c @@ -476,6 +476,12 @@ void bli_dgemmsup_rv_zen4_asm_24x2 vxorpd(zmm8, zmm8, zmm8) vxorpd(zmm9, zmm9, zmm9) vxorpd(zmm29, zmm29, zmm29) + vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) + vxorpd(zmm18, zmm18, zmm18) + vxorpd(zmm19, zmm19, zmm19) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -485,6 +491,21 @@ void bli_dgemmsup_rv_zen4_asm_24x2 sub(imm( 2+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29 to hold fma result. + * While even iterations uses zmm14, zmm15, zmm16, zmm17, zmm18 + * zmm19 to hold fma result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29. + */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -517,13 +538,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 3 @@ -549,13 +570,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 5 @@ -581,13 +602,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 7 @@ -609,13 +630,13 @@ void 
bli_dgemmsup_rv_zen4_asm_24x2 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP1) // iterate again if i != 0. @@ -656,13 +677,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 3 prefetchw0( mem(rdx, 128)) // prefetch C @@ -687,13 +708,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -717,13 +738,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -743,13 +764,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b sub(imm(1), rsi) // i -= 1 @@ -788,13 +809,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( 
zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -818,13 +839,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -848,13 +869,13 @@ void bli_dgemmsup_rv_zen4_asm_24x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -874,17 +895,23 @@ void bli_dgemmsup_rv_zen4_asm_24x2 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. + vaddpd(zmm14, zmm6, zmm6) + vaddpd(zmm15, zmm7, zmm7) + vaddpd(zmm16, zmm28, zmm28) + vaddpd(zmm17, zmm8, zmm8) + vaddpd(zmm18, zmm9, zmm9) + vaddpd(zmm19, zmm29, zmm29) label(.TAIL) mov(var(k_left), rsi) // i = k_left @@ -1292,6 +1319,10 @@ void bli_dgemmsup_rv_zen4_asm_16x2 vxorpd(zmm7, zmm7, zmm7) vxorpd(zmm8, zmm8, zmm8) vxorpd(zmm9, zmm9, zmm9) + vxorpd(zmm10, zmm10, zmm10) + vxorpd(zmm11, zmm11, zmm11) + vxorpd(zmm12, zmm12, zmm12) + vxorpd(zmm13, zmm13, zmm13) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -1301,6 +1332,21 @@ void bli_dgemmsup_rv_zen4_asm_16x2 sub(imm( 2+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. 
+ * Here odd iterations uses vector register zmm6, zmm7, + * zmm8, zmm9 to hold fma result. + * While even iterations uses zmm10, zmm11, zmm12, zmm13 + * to hold fma result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7,zmm8, + * zmm9. + */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -1328,11 +1374,11 @@ void bli_dgemmsup_rv_zen4_asm_16x2 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) // ---------------------------------- iteration 3 @@ -1354,11 +1400,11 @@ void bli_dgemmsup_rv_zen4_asm_16x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) // ---------------------------------- iteration 5 @@ -1380,11 +1426,11 @@ void bli_dgemmsup_rv_zen4_asm_16x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) // ---------------------------------- iteration 7 @@ -1403,11 +1449,11 @@ void bli_dgemmsup_rv_zen4_asm_16x2 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP1) // iterate again if i != 0. 
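The K loop shown above keeps the existing unroll-by-8 structure: one body spells out eight K iterations so that a software prefetch of the upcoming B panel is issued once per body rather than once per FMA, and any remainder falls through to the k_left code at label(.TAIL). A rough scalar sketch of that shape, assuming the GCC/Clang __builtin_prefetch builtin and a placeholder prefetch distance:

    /* Unroll-by-8 with one B prefetch per unrolled body (illustrative only). */
    void dot_unrolled8( double* out, const double* a, const double* b, int k )
    {
        double acc = 0.0;
        int i = 0;
        for ( ; i + 8 <= k; i += 8 )
        {
            __builtin_prefetch( &b[ i + 64 ], 0, 3 );  /* "prefetch B"; 64 is a placeholder distance */
            for ( int u = 0; u < 8; ++u )              /* the asm writes these 8 iterations out explicitly */
                acc += a[ i + u ] * b[ i + u ];
        }
        for ( ; i < k; ++i )                           /* k_left tail, as at label(.TAIL) */
            acc += a[ i ] * b[ i ];
        *out = acc;
    }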
@@ -1443,11 +1489,11 @@ void bli_dgemmsup_rv_zen4_asm_16x2 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -1467,11 +1513,11 @@ void bli_dgemmsup_rv_zen4_asm_16x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -1491,11 +1537,11 @@ void bli_dgemmsup_rv_zen4_asm_16x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -1512,11 +1558,11 @@ void bli_dgemmsup_rv_zen4_asm_16x2 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b sub(imm(1), rsi) // i -= 1 @@ -1550,11 +1596,11 @@ void bli_dgemmsup_rv_zen4_asm_16x2 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -1574,11 +1620,11 @@ void bli_dgemmsup_rv_zen4_asm_16x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -1598,11 +1644,11 @@ void bli_dgemmsup_rv_zen4_asm_16x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( 
zmm4,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -1619,15 +1665,19 @@ void bli_dgemmsup_rv_zen4_asm_16x2 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm4,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm4,zmm31,zmm13 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. + vaddpd(zmm10, zmm6, zmm6) + vaddpd(zmm11, zmm7, zmm7) + vaddpd(zmm12, zmm8, zmm8) + vaddpd(zmm13, zmm9, zmm9) label(.TAIL) mov(var(k_left), rsi) // i = k_left @@ -1995,7 +2045,9 @@ void bli_dgemmsup_rv_zen4_asm_8x2 // zero out all accumulation registers vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -2005,6 +2057,19 @@ void bli_dgemmsup_rv_zen4_asm_8x2 sub(imm( 2+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm8 + * to hold fma result. + * While even iterations uses zmm7, zmm9 to hold fma result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm8. + */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -2027,9 +2092,9 @@ void bli_dgemmsup_rv_zen4_asm_8x2 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) // ---------------------------------- iteration 3 @@ -2047,9 +2112,9 @@ void bli_dgemmsup_rv_zen4_asm_8x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) // ---------------------------------- iteration 5 @@ -2067,9 +2132,9 @@ void bli_dgemmsup_rv_zen4_asm_8x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) // ---------------------------------- iteration 7 @@ -2085,9 +2150,9 @@ void bli_dgemmsup_rv_zen4_asm_8x2 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP1) // iterate again if i != 0. 
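For the two-column (Mx2) edge cases the banking is applied per column of C: each K iteration broadcasts one element from each of the two B columns and issues one FMA per column, with alternating iterations targeting disjoint accumulator pairs. A rough AVX-512 intrinsics sketch of an 8x2 tile built that way, illustrative only (alpha/beta and the real strides are omitted, and the odd/even split is written as a branch rather than an unrolled body):

    #include <immintrin.h>

    /* c0/c1 receive one 8x1 column of A*B each; a holds 8 rows of A per k,
     * b0/b1 are the k-th elements of the two B columns (col-stored B). */
    void tile_8x2_two_banks( double* c0, double* c1,
                             const double* a, const double* b0, const double* b1, int k )
    {
        __m512d acc0_odd = _mm512_setzero_pd(), acc0_even = _mm512_setzero_pd(); /* column 0: zmm6 / zmm7 roles */
        __m512d acc1_odd = _mm512_setzero_pd(), acc1_even = _mm512_setzero_pd(); /* column 1: zmm8 / zmm9 roles */
        for ( int i = 0; i < k; ++i )
        {
            __m512d av  = _mm512_loadu_pd( a + 8 * i ); /* like vmovupd( mem(rax), zmm3 )            */
            __m512d bv0 = _mm512_set1_pd( b0[ i ] );    /* like vbroadcastsd( mem(rbx), zmm30 )      */
            __m512d bv1 = _mm512_set1_pd( b1[ i ] );    /* like vbroadcastsd( mem(rbx,r9,1), zmm31 ) */
            if ( i & 1 )
            {
                acc0_even = _mm512_fmadd_pd( av, bv0, acc0_even );
                acc1_even = _mm512_fmadd_pd( av, bv1, acc1_even );
            }
            else
            {
                acc0_odd = _mm512_fmadd_pd( av, bv0, acc0_odd );
                acc1_odd = _mm512_fmadd_pd( av, bv1, acc1_odd );
            }
        }
        _mm512_storeu_pd( c0, _mm512_add_pd( acc0_odd, acc0_even ) ); /* the final vaddpd of each bank pair */
        _mm512_storeu_pd( c1, _mm512_add_pd( acc1_odd, acc1_even ) );
    }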
@@ -2117,9 +2182,9 @@ void bli_dgemmsup_rv_zen4_asm_8x2 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2135,9 +2200,9 @@ void bli_dgemmsup_rv_zen4_asm_8x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2153,9 +2218,9 @@ void bli_dgemmsup_rv_zen4_asm_8x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2169,9 +2234,9 @@ void bli_dgemmsup_rv_zen4_asm_8x2 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b sub(imm(1), rsi) // i -= 1 @@ -2200,9 +2265,9 @@ void bli_dgemmsup_rv_zen4_asm_8x2 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2218,9 +2283,9 @@ void bli_dgemmsup_rv_zen4_asm_8x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2236,9 +2301,9 @@ void bli_dgemmsup_rv_zen4_asm_8x2 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2252,13 +2317,15 @@ void bli_dgemmsup_rv_zen4_asm_8x2 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. 
+ vaddpd(zmm7, zmm6, zmm6) + vaddpd(zmm9, zmm8, zmm8) label(.TAIL) mov(var(k_left), rsi) // i = k_left diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx3.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx3.c index a739183e98..ee6c3c573d 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx3.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx3.c @@ -480,6 +480,15 @@ void bli_dgemmsup_rv_zen4_asm_24x3 vxorpd(zmm10, zmm10, zmm10) vxorpd(zmm11, zmm11, zmm11) vxorpd(zmm26, zmm26, zmm26) + vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) + vxorpd(zmm18, zmm18, zmm18) + vxorpd(zmm19, zmm19, zmm19) + vxorpd(zmm20, zmm20, zmm20) + vxorpd(zmm21, zmm21, zmm21) + vxorpd(zmm22, zmm22, zmm22) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -489,6 +498,21 @@ void bli_dgemmsup_rv_zen4_asm_24x3 sub(imm( 3+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29, zmm10, zmm11, zmm26 to hold fma result. + * While even iterations uses zmm14, zmm15, zmm16, zmm17, zmm18 + * zmm19, zmm20, zmm21, zmm22 to hold fma + * result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7, zmm28, + * zmm8, zmm9, zmm29, zmm10, zmm11, zmm26. + */ label(.LOOP1) // ---------------------------------- iteration 1 @@ -525,17 +549,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 3 @@ -566,17 +590,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 5 @@ -606,17 
+630,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 7 @@ -642,17 +666,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP1) // iterate again if i != 0. @@ -697,17 +721,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 3 prefetchw0( mem(rdx, 128)) // prefetch C @@ -737,17 +761,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- 
iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -775,17 +799,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -809,17 +833,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b sub(imm(1), rsi) // i -= 1 @@ -862,17 +886,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -901,17 +925,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) 
+ vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -939,17 +963,17 @@ void bli_dgemmsup_rv_zen4_asm_24x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -973,21 +997,30 @@ void bli_dgemmsup_rv_zen4_asm_24x3 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) - vfmadd231pd( zmm5,zmm30,zmm28 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm5,zmm30,zmm16 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) - vfmadd231pd( zmm5,zmm31,zmm29 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) + vfmadd231pd( zmm4,zmm31,zmm18 ) + vfmadd231pd( zmm5,zmm31,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) - vfmadd231pd( zmm5,zmm30,zmm26 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) + vfmadd231pd( zmm5,zmm30,zmm22 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. + vaddpd(zmm14, zmm6, zmm6) + vaddpd(zmm15, zmm7, zmm7) + vaddpd(zmm16, zmm28, zmm28) + vaddpd(zmm17, zmm8, zmm8) + vaddpd(zmm18, zmm9, zmm9) + vaddpd(zmm19, zmm29, zmm29) + vaddpd(zmm20, zmm10, zmm10) + vaddpd(zmm21, zmm11, zmm11) + vaddpd(zmm22, zmm26, zmm26) label(.TAIL) mov(var(k_left), rsi) // i = k_left @@ -1410,6 +1443,12 @@ void bli_dgemmsup_rv_zen4_asm_16x3 vxorpd(zmm9, zmm9, zmm9) vxorpd(zmm10, zmm10, zmm10) vxorpd(zmm11, zmm11, zmm11) + vxorpd(zmm12, zmm12, zmm12) + vxorpd(zmm13, zmm13, zmm13) + vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -1419,6 +1458,22 @@ void bli_dgemmsup_rv_zen4_asm_16x3 sub(imm( 3+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7, zmm8, + * zmm9, zmm10, zmm11 to hold fma result. + * While even iterations uses zmm12, zmm13, zmm14, zmm15, zmm16 + * zmm17 to hold fma + * result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7, zmm8, + * zmm9, zmm10, zmm11. 
+ */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -1449,14 +1504,14 @@ void bli_dgemmsup_rv_zen4_asm_16x3 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm12 ) + vfmadd231pd( zmm4,zmm30,zmm13 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm14 ) + vfmadd231pd( zmm4,zmm31,zmm15 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) // ---------------------------------- iteration 3 @@ -1482,14 +1537,14 @@ void bli_dgemmsup_rv_zen4_asm_16x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm12 ) + vfmadd231pd( zmm4,zmm30,zmm13 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm14 ) + vfmadd231pd( zmm4,zmm31,zmm15 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) // ---------------------------------- iteration 5 @@ -1514,14 +1569,14 @@ void bli_dgemmsup_rv_zen4_asm_16x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm12 ) + vfmadd231pd( zmm4,zmm30,zmm13 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm14 ) + vfmadd231pd( zmm4,zmm31,zmm15 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) // ---------------------------------- iteration 7 @@ -1543,14 +1598,14 @@ void bli_dgemmsup_rv_zen4_asm_16x3 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm12 ) + vfmadd231pd( zmm4,zmm30,zmm13 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm14 ) + vfmadd231pd( zmm4,zmm31,zmm15 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP1) // iterate again if i != 0. 
@@ -1589,14 +1644,14 @@ void bli_dgemmsup_rv_zen4_asm_16x3 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm12 ) + vfmadd231pd( zmm4,zmm30,zmm13 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm14 ) + vfmadd231pd( zmm4,zmm31,zmm15 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -1620,14 +1675,14 @@ void bli_dgemmsup_rv_zen4_asm_16x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm12 ) + vfmadd231pd( zmm4,zmm30,zmm13 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm14 ) + vfmadd231pd( zmm4,zmm31,zmm15 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -1650,14 +1705,14 @@ void bli_dgemmsup_rv_zen4_asm_16x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm12 ) + vfmadd231pd( zmm4,zmm30,zmm13 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm14 ) + vfmadd231pd( zmm4,zmm31,zmm15 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -1677,14 +1732,14 @@ void bli_dgemmsup_rv_zen4_asm_16x3 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm12 ) + vfmadd231pd( zmm4,zmm30,zmm13 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm14 ) + vfmadd231pd( zmm4,zmm31,zmm15 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b sub(imm(1), rsi) // i -= 1 @@ -1721,14 +1776,14 @@ void bli_dgemmsup_rv_zen4_asm_16x3 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm12 ) + vfmadd231pd( zmm4,zmm30,zmm13 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm14 ) + vfmadd231pd( zmm4,zmm31,zmm15 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) // 
---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -1752,14 +1807,14 @@ void bli_dgemmsup_rv_zen4_asm_16x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm12 ) + vfmadd231pd( zmm4,zmm30,zmm13 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm14 ) + vfmadd231pd( zmm4,zmm31,zmm15 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -1782,14 +1837,14 @@ void bli_dgemmsup_rv_zen4_asm_16x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm12 ) + vfmadd231pd( zmm4,zmm30,zmm13 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm14 ) + vfmadd231pd( zmm4,zmm31,zmm15 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -1809,18 +1864,24 @@ void bli_dgemmsup_rv_zen4_asm_16x3 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm12 ) + vfmadd231pd( zmm4,zmm30,zmm13 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm14 ) + vfmadd231pd( zmm4,zmm31,zmm15 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. + vaddpd(zmm12, zmm6, zmm6) + vaddpd(zmm13, zmm7, zmm7) + vaddpd(zmm14, zmm8, zmm8) + vaddpd(zmm15, zmm9, zmm9) + vaddpd(zmm16, zmm10, zmm10) + vaddpd(zmm17, zmm11, zmm11) label(.TAIL) mov(var(k_left), rsi) // i = k_left @@ -2197,8 +2258,11 @@ void bli_dgemmsup_rv_zen4_asm_8x3 // zero out all accumulation registers vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) vxorpd(zmm10, zmm10, zmm10) + vxorpd(zmm11, zmm11, zmm11) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -2208,6 +2272,21 @@ void bli_dgemmsup_rv_zen4_asm_8x3 sub(imm( 3+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm8, + * zmm10 to hold fma result. + * While even iterations uses zmm7, zmm9, zmm11 to hold fma + * result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm8, + * zmm10. 
+ */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -2232,11 +2311,11 @@ void bli_dgemmsup_rv_zen4_asm_8x3 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) // ---------------------------------- iteration 3 @@ -2257,11 +2336,11 @@ void bli_dgemmsup_rv_zen4_asm_8x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) // ---------------------------------- iteration 5 @@ -2281,11 +2360,11 @@ void bli_dgemmsup_rv_zen4_asm_8x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) // ---------------------------------- iteration 7 @@ -2303,11 +2382,11 @@ void bli_dgemmsup_rv_zen4_asm_8x3 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP1) // iterate again if i != 0. 
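In the hunks that follow, the 8-row kernels load A through a zeroing mask (the MASK_KZ(2) operand on vmovupd, commented as "Load A with mask and zero hint"), so lanes beyond the valid rows arrive as 0.0 and drop out of the FMAs. A small intrinsics sketch of that masked-load pattern, assuming the mask has already been derived from the remaining row count (the helper and its names are illustrative, not BLIS API):

    #include <immintrin.h>

    /* Zero-masked load of up to 8 doubles: valid lanes come from memory,
     * masked-off lanes are forced to 0.0 (the "zero hint"). */
    static inline __m512d load_a_masked( const double* a, int rows_left )
    {
        __mmask8 m = (__mmask8)( ( 1u << rows_left ) - 1u ); /* rows_left = 5 -> 0b00011111 */
        return _mm512_maskz_loadu_pd( m, a );                /* like vmovupd( mem(rax), zmm3 MASK_KZ(2) ) */
    }

Because the dead lanes are zero, the FMA sequence itself can run unmasked; only the eventual update of C has to respect the row mask again.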
@@ -2339,11 +2418,11 @@ void bli_dgemmsup_rv_zen4_asm_8x3 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2362,11 +2441,11 @@ void bli_dgemmsup_rv_zen4_asm_8x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2384,11 +2463,11 @@ void bli_dgemmsup_rv_zen4_asm_8x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2404,11 +2483,11 @@ void bli_dgemmsup_rv_zen4_asm_8x3 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b sub(imm(1), rsi) // i -= 1 @@ -2439,11 +2518,11 @@ void bli_dgemmsup_rv_zen4_asm_8x3 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2462,11 +2541,11 @@ void bli_dgemmsup_rv_zen4_asm_8x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2484,11 +2563,11 @@ void bli_dgemmsup_rv_zen4_asm_8x3 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( 
mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2504,15 +2583,18 @@ void bli_dgemmsup_rv_zen4_asm_8x3 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. + vaddpd(zmm7, zmm6, zmm6) + vaddpd(zmm9, zmm8, zmm8) + vaddpd(zmm11, zmm10, zmm10) label(.TAIL) mov(var(k_left), rsi) // i = k_left diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx4.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx4.c index e5d70ae5fd..f8a3968f7b 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx4.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx4.c @@ -1529,6 +1529,15 @@ void bli_dgemmsup_rv_zen4_asm_16x4 vxorpd(zmm11, zmm11, zmm11) vxorpd(zmm12, zmm12, zmm12) vxorpd(zmm13, zmm13, zmm13) + vxorpd(zmm14,zmm14, zmm14) + vxorpd(zmm15,zmm15, zmm15) + vxorpd(zmm16,zmm16, zmm16) + vxorpd(zmm17,zmm17, zmm17) + vxorpd(zmm18,zmm18, zmm18) + vxorpd(zmm19,zmm19, zmm19) + vxorpd(zmm20,zmm20, zmm20) + vxorpd(zmm21,zmm21, zmm21) + // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -1538,6 +1547,22 @@ void bli_dgemmsup_rv_zen4_asm_16x4 sub(imm( 4+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7, + * zmm8, zmm9, zmm10, zmm11, zmm12, zmm13 + * to hold fma result. + * While even iterations uses zmm14, zmm15, zmm16, zmm17, zmm18 + * zmm19, zmm20, zmm21 to hold fma result. 
+ * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7, + * zmm8, zmm9, zmm10, zmm11, zmm12, zmm13 + */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -1571,17 +1596,17 @@ void bli_dgemmsup_rv_zen4_asm_16x4 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) // ---------------------------------- iteration 3 @@ -1611,17 +1636,17 @@ void bli_dgemmsup_rv_zen4_asm_16x4 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) // ---------------------------------- iteration 5 @@ -1649,17 +1674,17 @@ void bli_dgemmsup_rv_zen4_asm_16x4 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) // ---------------------------------- iteration 7 @@ -1684,17 +1709,17 @@ void bli_dgemmsup_rv_zen4_asm_16x4 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( 
zmm4,zmm31,zmm21 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP1) // iterate again if i != 0. @@ -1736,17 +1761,17 @@ void bli_dgemmsup_rv_zen4_asm_16x4 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -1774,17 +1799,17 @@ void bli_dgemmsup_rv_zen4_asm_16x4 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -1810,17 +1835,17 @@ void bli_dgemmsup_rv_zen4_asm_16x4 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -1843,17 +1868,17 @@ void bli_dgemmsup_rv_zen4_asm_16x4 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( 
zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b sub(imm(1), rsi) // i -= 1 @@ -1893,17 +1918,17 @@ void bli_dgemmsup_rv_zen4_asm_16x4 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -1931,17 +1956,17 @@ void bli_dgemmsup_rv_zen4_asm_16x4 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -1967,17 +1992,17 @@ void bli_dgemmsup_rv_zen4_asm_16x4 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -2000,21 +2025,29 @@ void bli_dgemmsup_rv_zen4_asm_16x4 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm4,zmm30,zmm15 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm4,zmm31,zmm17 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - 
vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. + vaddpd(zmm14, zmm6, zmm6) + vaddpd(zmm15, zmm7, zmm7) + vaddpd(zmm16, zmm8, zmm8) + vaddpd(zmm17, zmm9, zmm9) + vaddpd(zmm18, zmm10, zmm10) + vaddpd(zmm19, zmm11, zmm11) + vaddpd(zmm20, zmm12, zmm12) + vaddpd(zmm21, zmm13, zmm13) label(.TAIL) mov(var(k_left), rsi) // i = k_left @@ -2406,9 +2439,13 @@ void bli_dgemmsup_rv_zen4_asm_8x4 // zero out all accumulation registers vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) vxorpd(zmm10, zmm10, zmm10) + vxorpd(zmm11, zmm11, zmm11) vxorpd(zmm12, zmm12, zmm12) + vxorpd(zmm13, zmm13, zmm13) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -2418,6 +2455,21 @@ void bli_dgemmsup_rv_zen4_asm_8x4 sub(imm( 4+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register banks + * to hold fma results. + * Once the K loop is completed, these two vector register banks + * are added together and the final result is available in one + * register bank. + * Here odd iterations use vector registers zmm6, + * zmm8, zmm10, zmm12 to hold fma results. + * While even iterations use zmm7, zmm9, zmm11, zmm13 + * to hold fma results. + * At the end of the K loop, these two banks are added together and + * the final result is available in vector registers zmm6, + * zmm8, zmm10, zmm12. + */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -2444,13 +2496,13 @@ void bli_dgemmsup_rv_zen4_asm_8x4 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) // ---------------------------------- iteration 3 @@ -2474,13 +2526,13 @@ void bli_dgemmsup_rv_zen4_asm_8x4 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) // ---------------------------------- iteration 5 @@ -2502,13 +2554,13 @@ void bli_dgemmsup_rv_zen4_asm_8x4 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) // ---------------------------------- iteration 7 @@ -2528,13 +2580,13 @@ void bli_dgemmsup_rv_zen4_asm_8x4 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) -
vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP1) // iterate again if i != 0. @@ -2568,13 +2620,13 @@ void bli_dgemmsup_rv_zen4_asm_8x4 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2596,13 +2648,13 @@ void bli_dgemmsup_rv_zen4_asm_8x4 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2622,13 +2674,13 @@ void bli_dgemmsup_rv_zen4_asm_8x4 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2646,13 +2698,13 @@ void bli_dgemmsup_rv_zen4_asm_8x4 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b sub(imm(1), rsi) // i -= 1 @@ -2685,13 +2737,13 @@ void bli_dgemmsup_rv_zen4_asm_8x4 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) add( 
r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2713,13 +2765,13 @@ void bli_dgemmsup_rv_zen4_asm_8x4 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2739,13 +2791,13 @@ void bli_dgemmsup_rv_zen4_asm_8x4 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2763,17 +2815,21 @@ void bli_dgemmsup_rv_zen4_asm_8x4 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. + vaddpd(zmm7, zmm6, zmm6) + vaddpd(zmm9, zmm8, zmm8) + vaddpd(zmm11, zmm10, zmm10) + vaddpd(zmm13, zmm12, zmm12) label(.TAIL) mov(var(k_left), rsi) // i = k_left diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx5.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx5.c index a41cbc4905..d014358c84 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx5.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx5.c @@ -1707,6 +1707,17 @@ void bli_dgemmsup_rv_zen4_asm_16x5 vxorpd(zmm13, zmm13, zmm13) vxorpd(zmm14, zmm14, zmm14) vxorpd(zmm15, zmm15, zmm15) + vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) + vxorpd(zmm18, zmm18, zmm18) + vxorpd(zmm19, zmm19, zmm19) + vxorpd(zmm20, zmm20, zmm20) + vxorpd(zmm21, zmm21, zmm21) + vxorpd(zmm22, zmm22, zmm22) + vxorpd(zmm23, zmm23, zmm23) + vxorpd(zmm24, zmm24, zmm24) + vxorpd(zmm25, zmm25, zmm25) + // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -1716,6 +1727,22 @@ void bli_dgemmsup_rv_zen4_asm_16x5 sub(imm( 5+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. 
+ * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7, + * zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15 + * to hold fma result. + * While even iterations uses zmm16, zmm17, zmm18, zmm19, zmm20 + * zmm21, zmm22, zmm23, zmm24, zmm25 to hold fma result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7, + * zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15. + */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -1753,21 +1780,21 @@ void bli_dgemmsup_rv_zen4_asm_16x5 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm18 ) + vfmadd231pd( zmm4,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm22 ) + vfmadd231pd( zmm4,zmm31,zmm23 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm24 ) + vfmadd231pd( zmm4,zmm30,zmm25 ) // ---------------------------------- iteration 3 @@ -1801,21 +1828,21 @@ void bli_dgemmsup_rv_zen4_asm_16x5 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm18 ) + vfmadd231pd( zmm4,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm22 ) + vfmadd231pd( zmm4,zmm31,zmm23 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm24 ) + vfmadd231pd( zmm4,zmm30,zmm25 ) // ---------------------------------- iteration 5 @@ -1848,21 +1875,21 @@ void bli_dgemmsup_rv_zen4_asm_16x5 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm18 ) + vfmadd231pd( zmm4,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( 
zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm22 ) + vfmadd231pd( zmm4,zmm31,zmm23 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm24 ) + vfmadd231pd( zmm4,zmm30,zmm25 ) // ---------------------------------- iteration 7 @@ -1891,21 +1918,21 @@ void bli_dgemmsup_rv_zen4_asm_16x5 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm18 ) + vfmadd231pd( zmm4,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm22 ) + vfmadd231pd( zmm4,zmm31,zmm23 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm24 ) + vfmadd231pd( zmm4,zmm30,zmm25 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // second pointer to b_next += 8*rs_b dec(rsi) // i -= 1 @@ -1952,21 +1979,21 @@ void bli_dgemmsup_rv_zen4_asm_16x5 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm18 ) + vfmadd231pd( zmm4,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm22 ) + vfmadd231pd( zmm4,zmm31,zmm23 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm24 ) + vfmadd231pd( zmm4,zmm30,zmm25 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -1998,21 +2025,21 @@ void bli_dgemmsup_rv_zen4_asm_16x5 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm18 ) + vfmadd231pd( zmm4,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm22 ) + vfmadd231pd( zmm4,zmm31,zmm23 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm24 ) + vfmadd231pd( zmm4,zmm30,zmm25 ) // 
---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -2043,21 +2070,21 @@ void bli_dgemmsup_rv_zen4_asm_16x5 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm18 ) + vfmadd231pd( zmm4,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm22 ) + vfmadd231pd( zmm4,zmm31,zmm23 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm24 ) + vfmadd231pd( zmm4,zmm30,zmm25 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -2084,21 +2111,21 @@ void bli_dgemmsup_rv_zen4_asm_16x5 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm18 ) + vfmadd231pd( zmm4,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm22 ) + vfmadd231pd( zmm4,zmm31,zmm23 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm24 ) + vfmadd231pd( zmm4,zmm30,zmm25 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // second pointer of b_next += 8*rs_b @@ -2143,21 +2170,21 @@ void bli_dgemmsup_rv_zen4_asm_16x5 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm18 ) + vfmadd231pd( zmm4,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm22 ) + vfmadd231pd( zmm4,zmm31,zmm23 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm24 ) + vfmadd231pd( zmm4,zmm30,zmm25 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -2189,21 +2216,21 @@ void bli_dgemmsup_rv_zen4_asm_16x5 prefetch( 0,mem(r11,r13,1) ) // prefetch B 
vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm18 ) + vfmadd231pd( zmm4,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm22 ) + vfmadd231pd( zmm4,zmm31,zmm23 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm24 ) + vfmadd231pd( zmm4,zmm30,zmm25 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -2234,21 +2261,21 @@ void bli_dgemmsup_rv_zen4_asm_16x5 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm18 ) + vfmadd231pd( zmm4,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm22 ) + vfmadd231pd( zmm4,zmm31,zmm23 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm24 ) + vfmadd231pd( zmm4,zmm30,zmm25 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -2275,26 +2302,36 @@ void bli_dgemmsup_rv_zen4_asm_16x5 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm16 ) + vfmadd231pd( zmm4,zmm30,zmm17 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm18 ) + vfmadd231pd( zmm4,zmm31,zmm19 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm20 ) + vfmadd231pd( zmm4,zmm30,zmm21 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm22 ) + vfmadd231pd( zmm4,zmm31,zmm23 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm24 ) + vfmadd231pd( zmm4,zmm30,zmm25 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // Second pointer of b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. 
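The register-bank comments added throughout these edge kernels all describe the same trick: FMA results alternate between two banks of accumulators so that back-to-back iterations never chain on the same destination register, and the banks are folded together once the unrolled K loop ends (the vaddpd sequence that follows, e.g. zmm16-zmm25 folded into zmm6-zmm15 for the 16x5 kernel). A minimal scalar sketch of the idea in plain C, with illustrative names only rather than the kernels' inline-assembly macros:

    #include <stddef.h>

    /* Two-accumulator-bank accumulation, sketched as a scalar dot product.
     * Odd iterations add into bank0, even iterations into bank1, mirroring
     * how the kernels split FMAs across two sets of zmm registers; the final
     * addition plays the role of the vaddpd fold performed after .LOOP3. */
    static double dot_two_banks( size_t k, const double *a, const double *b )
    {
        double bank0 = 0.0;  /* "odd"  bank (zmm6, zmm8, ... in the kernels)  */
        double bank1 = 0.0;  /* "even" bank (zmm7/zmm16, ... in the kernels)  */

        for ( size_t i = 0; i + 1 < k; i += 2 )
        {
            bank0 += a[ i ]     * b[ i ];      /* odd  iteration */
            bank1 += a[ i + 1 ] * b[ i + 1 ];  /* even iteration */
        }
        if ( k & 1 )                           /* leftover iteration */
            bank0 += a[ k - 1 ] * b[ k - 1 ];

        return bank0 + bank1;                  /* fold the banks together */
    }

Keeping two independent accumulation chains is the usual way to hide FMA latency inside a K loop; the cost is a few extra register clears and adds outside the loop, which is exactly what the added vxorpd and vaddpd lines in this patch provide.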
+ vaddpd(zmm16, zmm6, zmm6) + vaddpd(zmm17, zmm7, zmm7) + vaddpd(zmm18, zmm8, zmm8) + vaddpd(zmm19, zmm9, zmm9) + vaddpd(zmm20, zmm10, zmm10) + vaddpd(zmm21, zmm11, zmm11) + vaddpd(zmm22, zmm12, zmm12) + vaddpd(zmm23, zmm13, zmm13) + vaddpd(zmm24, zmm14, zmm14) + vaddpd(zmm25, zmm15, zmm15) label(.TAIL) mov(var(k_left), rsi) // i = k_left @@ -2715,10 +2752,15 @@ void bli_dgemmsup_rv_zen4_asm_8x5 // zero out all accumulation registers vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) vxorpd(zmm10, zmm10, zmm10) + vxorpd(zmm11, zmm11, zmm11) vxorpd(zmm12, zmm12, zmm12) + vxorpd(zmm13, zmm13, zmm13) vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -2728,6 +2770,21 @@ void bli_dgemmsup_rv_zen4_asm_8x5 sub(imm( 5+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, + * zmm8, zmm10, zmm12, zmm14 to hold fma result. + * While even iterations uses zmm7, zmm9, zmm11, zmm13, zmm15 + * to hold fma result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, + * zmm8, zmm10, zmm12, zmm14 + */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -2757,16 +2814,16 @@ void bli_dgemmsup_rv_zen4_asm_8x5 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) // ---------------------------------- iteration 3 @@ -2793,16 +2850,16 @@ void bli_dgemmsup_rv_zen4_asm_8x5 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) // ---------------------------------- iteration 5 @@ -2828,16 +2885,16 @@ void bli_dgemmsup_rv_zen4_asm_8x5 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += 
rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) // ---------------------------------- iteration 7 @@ -2860,16 +2917,16 @@ void bli_dgemmsup_rv_zen4_asm_8x5 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // second pointer to b_next += 8*rs_b dec(rsi) // i -= 1 @@ -2907,16 +2964,16 @@ void bli_dgemmsup_rv_zen4_asm_8x5 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2941,16 +2998,16 @@ void bli_dgemmsup_rv_zen4_asm_8x5 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -2974,16 +3031,16 @@ void bli_dgemmsup_rv_zen4_asm_8x5 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3004,16 +3061,16 @@ void bli_dgemmsup_rv_zen4_asm_8x5 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) 
vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // second pointer of b_next += 8*rs_b @@ -3050,16 +3107,16 @@ void bli_dgemmsup_rv_zen4_asm_8x5 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3084,16 +3141,16 @@ void bli_dgemmsup_rv_zen4_asm_8x5 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3117,16 +3174,16 @@ void bli_dgemmsup_rv_zen4_asm_8x5 add( r10,rax ) // a += cs_a vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3147,21 +3204,26 @@ void bli_dgemmsup_rv_zen4_asm_8x5 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 
) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // Second pointer of b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. + vaddpd(zmm7, zmm6, zmm6) + vaddpd(zmm9, zmm8, zmm8) + vaddpd(zmm11, zmm10, zmm10) + vaddpd(zmm13, zmm12, zmm12) + vaddpd(zmm15, zmm14, zmm14) label(.TAIL) mov(var(k_left), rsi) // i = k_left diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx6.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx6.c index fe638c320f..db9ba7cae2 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx6.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx6.c @@ -1828,6 +1828,18 @@ void bli_dgemmsup_rv_zen4_asm_16x6 vxorpd(zmm15, zmm15, zmm15) vxorpd(zmm16, zmm16, zmm16) vxorpd(zmm17, zmm17, zmm17) + vxorpd(zmm18, zmm18, zmm18) + vxorpd(zmm19, zmm19, zmm19) + vxorpd(zmm20, zmm20, zmm20) + vxorpd(zmm21, zmm21, zmm21) + vxorpd(zmm22, zmm22, zmm22) + vxorpd(zmm23, zmm23, zmm23) + vxorpd(zmm24, zmm24, zmm24) + vxorpd(zmm25, zmm25, zmm25) + vxorpd(zmm26, zmm26, zmm26) + vxorpd(zmm27, zmm27, zmm27) + vxorpd(zmm28, zmm28, zmm28) + vxorpd(zmm29, zmm29, zmm29) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -1837,6 +1849,23 @@ void bli_dgemmsup_rv_zen4_asm_16x6 sub(imm( 6+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, zmm7, + * zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15 + * zmm16, zmm17 to hold fma result. + * While even iterations uses zmm18, zmm19, zmm20 + * zmm21, zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28, zmm29 + * to hold fma result. 
+ * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, zmm7, + * zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15, zmm16, zmm17 + */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -1877,24 +1906,24 @@ void bli_dgemmsup_rv_zen4_asm_16x6 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm22 ) + vfmadd231pd( zmm4,zmm30,zmm23 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm24 ) + vfmadd231pd( zmm4,zmm31,zmm25 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm26 ) + vfmadd231pd( zmm4,zmm30,zmm27 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) - vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm3,zmm31,zmm28 ) + vfmadd231pd( zmm4,zmm31,zmm29 ) // ---------------------------------- iteration 3 @@ -1931,24 +1960,24 @@ void bli_dgemmsup_rv_zen4_asm_16x6 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm22 ) + vfmadd231pd( zmm4,zmm30,zmm23 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm24 ) + vfmadd231pd( zmm4,zmm31,zmm25 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm26 ) + vfmadd231pd( zmm4,zmm30,zmm27 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) - vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm3,zmm31,zmm28 ) + vfmadd231pd( zmm4,zmm31,zmm29 ) // ---------------------------------- iteration 5 @@ -1985,24 +2014,24 @@ void bli_dgemmsup_rv_zen4_asm_16x6 prefetch( 0,mem(r15,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm22 ) + vfmadd231pd( zmm4,zmm30,zmm23 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( 
zmm3,zmm31,zmm24 ) + vfmadd231pd( zmm4,zmm31,zmm25 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm26 ) + vfmadd231pd( zmm4,zmm30,zmm27 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) - vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm3,zmm31,zmm28 ) + vfmadd231pd( zmm4,zmm31,zmm29 ) // ---------------------------------- iteration 7 @@ -2034,24 +2063,24 @@ void bli_dgemmsup_rv_zen4_asm_16x6 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm22 ) + vfmadd231pd( zmm4,zmm30,zmm23 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm24 ) + vfmadd231pd( zmm4,zmm31,zmm25 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm26 ) + vfmadd231pd( zmm4,zmm30,zmm27 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) - vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm3,zmm31,zmm28 ) + vfmadd231pd( zmm4,zmm31,zmm29 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // second pointer to b_next += 8*rs_b dec(rsi) // i -= 1 @@ -2101,24 +2130,24 @@ void bli_dgemmsup_rv_zen4_asm_16x6 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm22 ) + vfmadd231pd( zmm4,zmm30,zmm23 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm24 ) + vfmadd231pd( zmm4,zmm31,zmm25 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm26 ) + vfmadd231pd( zmm4,zmm30,zmm27 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) - vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm3,zmm31,zmm28 ) + vfmadd231pd( zmm4,zmm31,zmm29 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -2153,24 +2182,24 @@ void bli_dgemmsup_rv_zen4_asm_16x6 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - 
vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm22 ) + vfmadd231pd( zmm4,zmm30,zmm23 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm24 ) + vfmadd231pd( zmm4,zmm31,zmm25 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm26 ) + vfmadd231pd( zmm4,zmm30,zmm27 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) - vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm3,zmm31,zmm28 ) + vfmadd231pd( zmm4,zmm31,zmm29 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -2205,24 +2234,24 @@ void bli_dgemmsup_rv_zen4_asm_16x6 prefetch( 0,mem(r15,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm22 ) + vfmadd231pd( zmm4,zmm30,zmm23 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm24 ) + vfmadd231pd( zmm4,zmm31,zmm25 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm26 ) + vfmadd231pd( zmm4,zmm30,zmm27 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) - vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm3,zmm31,zmm28 ) + vfmadd231pd( zmm4,zmm31,zmm29 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -2252,24 +2281,24 @@ void bli_dgemmsup_rv_zen4_asm_16x6 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm22 ) + vfmadd231pd( zmm4,zmm30,zmm23 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm24 ) + vfmadd231pd( zmm4,zmm31,zmm25 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm26 ) + vfmadd231pd( zmm4,zmm30,zmm27 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) - vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm3,zmm31,zmm28 ) + vfmadd231pd( zmm4,zmm31,zmm29 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // second pointer of b_next += 8*rs_b @@ -2317,24 +2346,24 @@ void bli_dgemmsup_rv_zen4_asm_16x6 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 
) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm22 ) + vfmadd231pd( zmm4,zmm30,zmm23 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm24 ) + vfmadd231pd( zmm4,zmm31,zmm25 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm26 ) + vfmadd231pd( zmm4,zmm30,zmm27 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) - vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm3,zmm31,zmm28 ) + vfmadd231pd( zmm4,zmm31,zmm29 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 ) // load A @@ -2369,24 +2398,24 @@ void bli_dgemmsup_rv_zen4_asm_16x6 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm22 ) + vfmadd231pd( zmm4,zmm30,zmm23 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm24 ) + vfmadd231pd( zmm4,zmm31,zmm25 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm26 ) + vfmadd231pd( zmm4,zmm30,zmm27 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) - vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm3,zmm31,zmm28 ) + vfmadd231pd( zmm4,zmm31,zmm29 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 ) // load A @@ -2421,24 +2450,24 @@ void bli_dgemmsup_rv_zen4_asm_16x6 prefetch( 0,mem(r15,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm22 ) + vfmadd231pd( zmm4,zmm30,zmm23 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm24 ) + vfmadd231pd( zmm4,zmm31,zmm25 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm26 ) + vfmadd231pd( zmm4,zmm30,zmm27 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) - vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm3,zmm31,zmm28 ) + vfmadd231pd( zmm4,zmm31,zmm29 ) // 
---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 ) // load A @@ -2468,29 +2497,42 @@ void bli_dgemmsup_rv_zen4_asm_16x6 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) - vfmadd231pd( zmm4,zmm30,zmm7 ) + vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm4,zmm30,zmm19 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) - vfmadd231pd( zmm4,zmm31,zmm9 ) + vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm4,zmm31,zmm21 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) - vfmadd231pd( zmm4,zmm30,zmm11 ) + vfmadd231pd( zmm3,zmm30,zmm22 ) + vfmadd231pd( zmm4,zmm30,zmm23 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) - vfmadd231pd( zmm4,zmm31,zmm13 ) + vfmadd231pd( zmm3,zmm31,zmm24 ) + vfmadd231pd( zmm4,zmm31,zmm25 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) - vfmadd231pd( zmm4,zmm30,zmm15 ) + vfmadd231pd( zmm3,zmm30,zmm26 ) + vfmadd231pd( zmm4,zmm30,zmm27 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) - vfmadd231pd( zmm4,zmm31,zmm17 ) + vfmadd231pd( zmm3,zmm31,zmm28 ) + vfmadd231pd( zmm4,zmm31,zmm29 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // Second pointer of b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. + vaddpd(zmm18, zmm6, zmm6) + vaddpd(zmm19, zmm7, zmm7) + vaddpd(zmm20, zmm8, zmm8) + vaddpd(zmm21, zmm9, zmm9) + vaddpd(zmm22, zmm10, zmm10) + vaddpd(zmm23, zmm11, zmm11) + vaddpd(zmm24, zmm12, zmm12) + vaddpd(zmm25, zmm13, zmm13) + vaddpd(zmm26, zmm14, zmm14) + vaddpd(zmm27, zmm15, zmm15) + vaddpd(zmm28, zmm16, zmm16) + vaddpd(zmm29, zmm17, zmm17) + label(.TAIL) mov(var(k_left), rsi) // i = k_left @@ -2924,11 +2966,17 @@ void bli_dgemmsup_rv_zen4_asm_8x6 // zero out all accumulation registers vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) vxorpd(zmm10, zmm10, zmm10) + vxorpd(zmm11, zmm11, zmm11) vxorpd(zmm12, zmm12, zmm12) + vxorpd(zmm13, zmm13, zmm13) vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -2938,6 +2986,21 @@ void bli_dgemmsup_rv_zen4_asm_8x6 sub(imm( 6+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, + * zmm8, zmm10, zmm12, zmm14, zmm16 to hold fma result. + * While even iterations uses zmm7, zmm9, zmm11 + * zmm13, zmm15, zmm17 to hold fma result. + * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, + * zmm8, zmm10, zmm12, zmm14, zmm16. 
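The comment above (repeated for the 8x7 and 8x8 edge kernels further down) describes the key idea of this change: consecutive k iterations accumulate their FMA results into two independent register banks, so an FMA never waits on the immediately preceding iteration's FMA into the same accumulator, and the banks are merged by the vaddpd sequence emitted just before label(.TAIL). A minimal stand-alone C sketch of the same idea, with illustrative names only (it is not part of the patch):

#include <stddef.h>

/* Scalar model of the two-bank accumulation: even k iterations update acc0
   (the zmm6/zmm8/... bank), odd iterations update acc1 (the zmm7/zmm9/...
   bank), and the two banks are summed once after the loop, mirroring the
   vaddpd merge before label(.TAIL). */
static double dot_two_banks( const double* a, const double* b, size_t k )
{
    double acc0 = 0.0;
    double acc1 = 0.0;
    size_t  i   = 0;

    for ( ; i + 1 < k; i += 2 )
    {
        acc0 += a[ i     ] * b[ i     ]; /* independent of the acc1 chain */
        acc1 += a[ i + 1 ] * b[ i + 1 ]; /* independent of the acc0 chain */
    }
    if ( i < k ) acc0 += a[ i ] * b[ i ]; /* leftover iteration */

    return acc0 + acc1; /* merge the two accumulator banks */
}

Splitting the accumulators trades a few extra vector registers and one vaddpd per accumulator for a shorter dependency chain inside the k loop, which is also why the zero-out section now clears the second bank (zmm7, zmm9, ...) as well.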
+ */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -2969,18 +3032,18 @@ void bli_dgemmsup_rv_zen4_asm_8x6 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) // ---------------------------------- iteration 3 @@ -3009,18 +3072,18 @@ void bli_dgemmsup_rv_zen4_asm_8x6 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) // ---------------------------------- iteration 5 @@ -3049,18 +3112,18 @@ void bli_dgemmsup_rv_zen4_asm_8x6 prefetch( 0,mem(r15,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) // ---------------------------------- iteration 7 @@ -3085,18 +3148,18 @@ void bli_dgemmsup_rv_zen4_asm_8x6 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // second pointer to b_next += 8*rs_b dec(rsi) // i -= 1 @@ -3136,18 +3199,18 @@ void bli_dgemmsup_rv_zen4_asm_8x6 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( 
mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3174,18 +3237,18 @@ void bli_dgemmsup_rv_zen4_asm_8x6 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3212,18 +3275,18 @@ void bli_dgemmsup_rv_zen4_asm_8x6 prefetch( 0,mem(r15,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3246,18 +3309,18 @@ void bli_dgemmsup_rv_zen4_asm_8x6 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // second pointer of b_next 
+= 8*rs_b @@ -3296,18 +3359,18 @@ void bli_dgemmsup_rv_zen4_asm_8x6 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3334,18 +3397,18 @@ void bli_dgemmsup_rv_zen4_asm_8x6 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3372,18 +3435,18 @@ void bli_dgemmsup_rv_zen4_asm_8x6 prefetch( 0,mem(r15,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3406,23 +3469,29 @@ void bli_dgemmsup_rv_zen4_asm_8x6 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) 
lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // Second pointer of b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. + vaddpd(zmm7, zmm6, zmm6) + vaddpd(zmm9, zmm8, zmm8) + vaddpd(zmm11, zmm10, zmm10) + vaddpd(zmm13, zmm12, zmm12) + vaddpd(zmm15, zmm14, zmm14) + vaddpd(zmm17, zmm16, zmm16) label(.TAIL) mov(var(k_left), rsi) // i = k_left diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx7.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx7.c index 610871ab2e..9e4194c118 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx7.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx7.c @@ -3128,12 +3128,19 @@ void bli_dgemmsup_rv_zen4_asm_8x7 // zero out all accumulation registers vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) vxorpd(zmm10, zmm10, zmm10) + vxorpd(zmm11, zmm11, zmm11) vxorpd(zmm12, zmm12, zmm12) + vxorpd(zmm13, zmm13, zmm13) vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) vxorpd(zmm18, zmm18, zmm18) + vxorpd(zmm19, zmm19, zmm19) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -3143,6 +3150,21 @@ void bli_dgemmsup_rv_zen4_asm_8x7 sub(imm( 7+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, + * zmm8, zmm10, zmm12, zmm14, zmm16, zmm18 to hold fma result. + * While even iterations uses zmm7, zmm9, zmm11 + * zmm13, zmm15, zmm17, zmm19 to hold fma result. 
+ * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, + * zmm8, zmm10, zmm12, zmm14, zmm16, zmm18 + */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -3176,20 +3198,20 @@ void bli_dgemmsup_rv_zen4_asm_8x7 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) // ---------------------------------- iteration 3 @@ -3220,20 +3242,20 @@ void bli_dgemmsup_rv_zen4_asm_8x7 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) // ---------------------------------- iteration 5 @@ -3264,20 +3286,20 @@ void bli_dgemmsup_rv_zen4_asm_8x7 prefetch( 0,mem(r15,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) // ---------------------------------- iteration 7 @@ -3305,20 +3327,20 @@ void bli_dgemmsup_rv_zen4_asm_8x7 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( 
zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // second pointer to b_next += 8*rs_b dec(rsi) // i -= 1 @@ -3360,20 +3382,20 @@ void bli_dgemmsup_rv_zen4_asm_8x7 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3402,20 +3424,20 @@ void bli_dgemmsup_rv_zen4_asm_8x7 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3444,20 +3466,20 @@ void bli_dgemmsup_rv_zen4_asm_8x7 prefetch( 0,mem(r15,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3483,20 +3505,20 @@ void 
bli_dgemmsup_rv_zen4_asm_8x7 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // second pointer of b_next += 8*rs_b @@ -3537,20 +3559,20 @@ void bli_dgemmsup_rv_zen4_asm_8x7 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3579,20 +3601,20 @@ void bli_dgemmsup_rv_zen4_asm_8x7 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3621,20 +3643,20 @@ void bli_dgemmsup_rv_zen4_asm_8x7 prefetch( 0,mem(r15,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - 
vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3660,25 +3682,32 @@ void bli_dgemmsup_rv_zen4_asm_8x7 // ---------------------------------- iteration 8 vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // Second pointer of b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. + vaddpd(zmm7, zmm6, zmm6) + vaddpd(zmm9, zmm8, zmm8) + vaddpd(zmm11, zmm10, zmm10) + vaddpd(zmm13, zmm12, zmm12) + vaddpd(zmm15, zmm14, zmm14) + vaddpd(zmm17, zmm16, zmm16) + vaddpd(zmm19, zmm18, zmm18) label(.TAIL) mov(var(k_left), rsi) // i = k_left diff --git a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx8.c b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx8.c index 8cf46b43c5..065cbd5bb6 100644 --- a/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx8.c +++ b/kernels/zen4/3/sup/d24x8/bli_dgemmsup_rv_zen4_asm_Mx8.c @@ -3284,13 +3284,21 @@ void bli_dgemmsup_rv_zen4_asm_8x8 // zero out all accumulation registers vxorpd(zmm6, zmm6, zmm6) + vxorpd(zmm7, zmm7, zmm7) vxorpd(zmm8, zmm8, zmm8) + vxorpd(zmm9, zmm9, zmm9) vxorpd(zmm10, zmm10, zmm10) + vxorpd(zmm11, zmm11, zmm11) vxorpd(zmm12, zmm12, zmm12) + vxorpd(zmm13, zmm13, zmm13) vxorpd(zmm14, zmm14, zmm14) + vxorpd(zmm15, zmm15, zmm15) vxorpd(zmm16, zmm16, zmm16) + vxorpd(zmm17, zmm17, zmm17) vxorpd(zmm18, zmm18, zmm18) + vxorpd(zmm19, zmm19, zmm19) vxorpd(zmm20, zmm20, zmm20) + vxorpd(zmm21, zmm21, zmm21) // K is unrolled by 8 to facilitate prefetch of B // Assuming B to be col-stored, for each iteration of K, @@ -3300,6 +3308,21 @@ void bli_dgemmsup_rv_zen4_asm_8x8 sub(imm( 8+TAIL_NITER), rsi) // i -= NR + TAIL_NITER jle(.PREFETCHLOOP) // jump if i <= 0 + /** + * This edge kernel uses two separate vector register bank + * to hold fma result. + * Once the K loop is completed these two vector register banks + * are added together and final result is available in one + * register bank. + * Here odd iterations uses vector register zmm6, + * zmm8, zmm10, zmm12, zmm14, zmm16, zmm18, zmm20 to hold fma result. + * While even iterations uses zmm7, zmm9, zmm11 + * zmm13, zmm15, zmm17, zmm19, zmm21 to hold fma result. 
+ * At the end of K loop, these two banks are added together and + * final result is available in vector register zmm6, + * zmm8, zmm10, zmm12, zmm14, zmm16, zmm18, zmm20 + */ + label(.LOOP1) // ---------------------------------- iteration 1 @@ -3335,22 +3358,22 @@ void bli_dgemmsup_rv_zen4_asm_8x8 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) vbroadcastsd( mem(r12,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm3,zmm31,zmm21 ) // ---------------------------------- iteration 3 @@ -3383,22 +3406,22 @@ void bli_dgemmsup_rv_zen4_asm_8x8 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) vbroadcastsd( mem(r12,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm3,zmm31,zmm21 ) // ---------------------------------- iteration 5 @@ -3431,22 +3454,22 @@ void bli_dgemmsup_rv_zen4_asm_8x8 prefetch( 0,mem(r15,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) vbroadcastsd( mem(r12,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm3,zmm31,zmm21 ) // ---------------------------------- iteration 7 @@ -3477,22 +3500,22 @@ void bli_dgemmsup_rv_zen4_asm_8x8 prefetch( 0,mem(r15,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( 
zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) vbroadcastsd( mem(r12,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm3,zmm31,zmm21 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // second pointer to b_next += 8*rs_b dec(rsi) // i -= 1 @@ -3536,22 +3559,22 @@ void bli_dgemmsup_rv_zen4_asm_8x8 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) vbroadcastsd( mem(r12,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm3,zmm31,zmm21 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3582,22 +3605,22 @@ void bli_dgemmsup_rv_zen4_asm_8x8 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) vbroadcastsd( mem(r12,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm3,zmm31,zmm21 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3628,22 +3651,22 @@ void bli_dgemmsup_rv_zen4_asm_8x8 prefetch( 0,mem(r15,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - 
vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) vbroadcastsd( mem(r12,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm3,zmm31,zmm21 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3672,22 +3695,22 @@ void bli_dgemmsup_rv_zen4_asm_8x8 prefetch( 0,mem(r15,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) vbroadcastsd( mem(r12,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm3,zmm31,zmm21 ) lea(mem(rdx, rdi, 1), rdx) // C += cs_c lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // second pointer of b_next += 8*rs_b @@ -3730,22 +3753,22 @@ void bli_dgemmsup_rv_zen4_asm_8x8 prefetch( 0,mem(r11,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) vbroadcastsd( mem(r12,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm3,zmm31,zmm21 ) // ---------------------------------- iteration 3 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3776,22 +3799,22 @@ void bli_dgemmsup_rv_zen4_asm_8x8 prefetch( 0,mem(r11,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + 
vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) vbroadcastsd( mem(r12,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm3,zmm31,zmm21 ) // ---------------------------------- iteration 5 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3822,22 +3845,22 @@ void bli_dgemmsup_rv_zen4_asm_8x8 prefetch( 0,mem(r15,r9,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) vbroadcastsd( mem(r12,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm3,zmm31,zmm21 ) // ---------------------------------- iteration 7 vmovupd( mem(rax),zmm3 MASK_KZ(2) ) // load A // Load A with mask and zero hint @@ -3866,27 +3889,35 @@ void bli_dgemmsup_rv_zen4_asm_8x8 prefetch( 0,mem(r15,r13,1) ) // prefetch B vbroadcastsd( mem(rbx),zmm30 ) vbroadcastsd( mem(rbx,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm6 ) + vfmadd231pd( zmm3,zmm30,zmm7 ) vbroadcastsd( mem(rbx,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm8 ) + vfmadd231pd( zmm3,zmm31,zmm9 ) vbroadcastsd( mem(rbx,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm10 ) + vfmadd231pd( zmm3,zmm30,zmm11 ) vbroadcastsd( mem(r12),zmm30 ) add( r8,rbx ) // b += rs_b - vfmadd231pd( zmm3,zmm31,zmm12 ) + vfmadd231pd( zmm3,zmm31,zmm13 ) vbroadcastsd( mem(r12,r9,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm14 ) + vfmadd231pd( zmm3,zmm30,zmm15 ) vbroadcastsd( mem(r12,r9,2),zmm30 ) - vfmadd231pd( zmm3,zmm31,zmm16 ) + vfmadd231pd( zmm3,zmm31,zmm17 ) vbroadcastsd( mem(r12,r13,1),zmm31 ) - vfmadd231pd( zmm3,zmm30,zmm18 ) + vfmadd231pd( zmm3,zmm30,zmm19 ) add( r8,r12 ) // second pointer of b += rs_b - vfmadd231pd( zmm3,zmm31,zmm20 ) + vfmadd231pd( zmm3,zmm31,zmm21 ) lea(mem(r11,r8,8), r11) // b_next += 8*rs_b lea(mem(r15,r8,8), r15) // Second pointer of b_next += 8*rs_b dec(rsi) // i -= 1 jnz(.LOOP3) // iterate again if i != 0. + vaddpd(zmm7, zmm6, zmm6) + vaddpd(zmm9, zmm8, zmm8) + vaddpd(zmm11, zmm10, zmm10) + vaddpd(zmm13, zmm12, zmm12) + vaddpd(zmm15, zmm14, zmm14) + vaddpd(zmm17, zmm16, zmm16) + vaddpd(zmm19, zmm18, zmm18) + vaddpd(zmm21, zmm20, zmm20) label(.TAIL) mov(var(k_left), rsi) // i = k_left diff --git a/kernels/zen4/CMakeLists.txt b/kernels/zen4/CMakeLists.txt deleted file mode 100644 index 7878918053..0000000000 --- a/kernels/zen4/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -##Copyright (C) 2022-2023, Advanced Micro Devices, Inc. 
All rights reserved.## -remove_definitions(/arch:AVX2) - -add_subdirectory(1) -add_subdirectory(1m) -add_subdirectory(3) -add_subdirectory(aocl_smart) \ No newline at end of file diff --git a/kernels/zen4/aocl_smart/CMakeLists.txt b/kernels/zen4/aocl_smart/CMakeLists.txt deleted file mode 100644 index ef10975d24..0000000000 --- a/kernels/zen4/aocl_smart/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -##Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.## - -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_aocl_smart.c - ) diff --git a/kernels/zen4/bli_kernels_zen4.h b/kernels/zen4/bli_kernels_zen4.h index 701e2ecb49..82872ac942 100644 --- a/kernels/zen4/bli_kernels_zen4.h +++ b/kernels/zen4/bli_kernels_zen4.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -39,8 +39,9 @@ AMAXV_KER_PROT( float, s, amaxv_zen_int_avx512 ) AMAXV_KER_PROT( double, d, amaxv_zen_int_avx512 ) // scalv (AVX512 intrinsics) -SCALV_KER_PROT( float, s, scalv_zen_int_avx512 ) -SCALV_KER_PROT( double, d, scalv_zen_int_avx512 ) +SCALV_KER_PROT( float, s, scalv_zen_int_avx512 ) +SCALV_KER_PROT( double, d, scalv_zen_int_avx512 ) +SCALV_KER_PROT( dcomplex, z, dscalv_zen_int_avx512) // ZDSCAL kernel // dotv (intrinsics) DOTV_KER_PROT( float, s, dotv_zen_int_avx512 ) @@ -54,6 +55,8 @@ GEMMTRSM_UKR_PROT( double, d, gemmtrsm_l_zen_asm_16x14) GEMMTRSM_UKR_PROT( double, d, gemmtrsm_u_zen_asm_16x14) GEMMTRSM_UKR_PROT( double, d, gemmtrsm_l_zen4_asm_8x24) GEMMTRSM_UKR_PROT( double, d, gemmtrsm_u_zen4_asm_8x24) +GEMMTRSM_UKR_PROT( dcomplex, z, gemmtrsm_l_zen4_asm_4x12) +GEMMTRSM_UKR_PROT( dcomplex, z, gemmtrsm_u_zen4_asm_4x12) //packing kernels PACKM_KER_PROT( double, d, packm_zen4_asm_16xk ) @@ -68,6 +71,8 @@ PACKM_KER_PROT( dcomplex, z, packm_zen4_asm_4xk ) GEMM_UKR_PROT( double, d, gemm_zen4_asm_32x6 ) GEMM_UKR_PROT( double, d, gemm_zen4_asm_8x24 ) GEMM_UKR_PROT( dcomplex, z, gemm_zen4_asm_12x4 ) +GEMM_UKR_PROT( dcomplex, z, gemm_zen4_asm_4x12 ) + //sgemm rv sup GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x64m_avx512 ) @@ -199,6 +204,18 @@ GEMMSUP_KER_PROT( dcomplex, z, gemmsup_cv_zen4_asm_2x3 ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_cv_zen4_asm_2x2 ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_cv_zen4_asm_2x1 ) +err_t bli_dgemm_24x8_avx512_k1_nn + ( + dim_t m, + dim_t n, + dim_t k, + double* alpha, + double* a, const inc_t lda, + double* b, const inc_t ldb, + double* beta, + double* c, const inc_t ldc + ); + // threshold functions bool bli_cntx_gemmsup_thresh_is_met_zen4 ( @@ -207,3 +224,6 @@ bool bli_cntx_gemmsup_thresh_is_met_zen4 obj_t* c, cntx_t* cntx ); + +// function for resetting zmm registers after L3 apis +void bli_zero_zmm(); diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c index 592af7f042..d5fa298c2d 100644 --- a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. 
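The prototype for bli_dgemm_24x8_avx512_k1_nn added to bli_kernels_zen4.h above targets the k == 1, no-transpose case. Purely as a hedged illustration of the calling convention implied by that prototype (BLIS normally reaches this kernel through its own dispatch; the wrapper name, sizes and strides below are made up, and the usual C := beta*C + alpha*A*B gemm semantics are assumed):

#include "blis.h"

/* Illustrative wrapper only: invokes the specialized AVX512 k == 1 path
   declared above with user-supplied strides. */
void call_dgemm_k1_example( dim_t m, dim_t n,
                            double* a, inc_t lda,
                            double* b, inc_t ldb,
                            double* c, inc_t ldc )
{
    double alpha = 1.0;
    double beta  = 1.0;

    err_t err = bli_dgemm_24x8_avx512_k1_nn( m, n, 1,
                                             &alpha, a, lda,
                                             b, ldb,
                                             &beta, c, ldc );
    ( void )err; /* error handling omitted in this sketch */
}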
+ Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_f32_kern_macros.h b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_f32_kern_macros.h index f3875647eb..484c2930eb 100644 --- a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_f32_kern_macros.h +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_f32_kern_macros.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,7 +40,8 @@ // Disable BF16 kernel in cases where compilers support other avx 512 // features except BF16 ISA. -#if defined( BLIS_GCC ) && ( __GNUC__ < 10 ) +#if ( defined( BLIS_GCC ) && ( ( __GNUC__ < 11 ) || \ + ( ( __GNUC__ == 11 ) && ( __GNUC_MINOR__ < 2 ) ) ) ) #define LPGEMM_BF16_NOT_SUPPORTED #endif diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c index e3e3bc2869..26f45c5101 100644 --- a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c index 01b59d38cf..f0d58752e4 100644 --- a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c index c95c0090ae..36bc91d78f 100644 --- a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
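The widened compiler check in lpgemm_f32_kern_macros.h above now treats GCC releases older than 11.2 (rather than only those older than 10) as lacking usable AVX512-BF16 support, matching the existing comment about compilers that provide other AVX512 features but not the BF16 ISA. A hypothetical consumer of that guard, with an illustrative function name that is not part of the patch:

#include "lpgemm_f32_kern_macros.h"

/* Hypothetical consumer of the guard: compile the AVX512-BF16 body only on
   new-enough toolchains and keep a stub otherwise, so the library still
   builds and links with older GCC. */
void lpgemm_bf16_kernel_example( void )
{
#ifndef LPGEMM_BF16_NOT_SUPPORTED
    /* AVX512-BF16 intrinsics such as _mm512_dpbf16_ps() can be used here. */
#else
    /* Fallback path for toolchains older than GCC 11.2. */
#endif
}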
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packa_bf16_amd256vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packa_bf16_amd256vnni.c new file mode 100644 index 0000000000..b928338f30 --- /dev/null +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packa_bf16_amd256vnni.c @@ -0,0 +1,1493 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include +#include +#include "blis.h" + +#ifdef BLIS_ADDON_LPGEMM + + +#define UNPACKLO_EPI16 \ + b_reg[0] = _mm256_unpacklo_epi16(a_reg[0], a_reg[1]); \ + b_reg[1] = _mm256_unpacklo_epi16(a_reg[2], a_reg[3]); \ + b_reg[2] = _mm256_unpacklo_epi16(a_reg[4], a_reg[5]); \ + b_reg[3] = _mm256_unpacklo_epi16(a_reg[6], a_reg[7]); \ + b_reg[4] = _mm256_unpacklo_epi16(a_reg[8], a_reg[9]); \ + b_reg[5] = _mm256_unpacklo_epi16(a_reg[10], a_reg[11]); \ + b_reg[6] = _mm256_unpacklo_epi16(a_reg[12], a_reg[13]); \ + b_reg[7] = _mm256_unpacklo_epi16(a_reg[14], a_reg[15]); + +#define UNPACKHI_EPI16 \ + b_reg[8] = _mm256_unpackhi_epi16(a_reg[0], a_reg[1]); \ + b_reg[9] = _mm256_unpackhi_epi16(a_reg[2], a_reg[3]); \ + b_reg[10] = _mm256_unpackhi_epi16(a_reg[4], a_reg[5]); \ + b_reg[11] = _mm256_unpackhi_epi16(a_reg[6], a_reg[7]); \ + b_reg[12] = _mm256_unpackhi_epi16(a_reg[8], a_reg[9]); \ + b_reg[13] = _mm256_unpackhi_epi16(a_reg[10], a_reg[11]); \ + b_reg[14] = _mm256_unpackhi_epi16(a_reg[12], a_reg[13]); \ + b_reg[15] = _mm256_unpackhi_epi16(a_reg[14], a_reg[15]); + +#define UNPACKLO_EPI32 \ + a_reg[0] = _mm256_unpacklo_epi32(b_reg[0], b_reg[1]); \ + a_reg[1] = _mm256_unpacklo_epi32(b_reg[2], b_reg[3]); \ + a_reg[2] = _mm256_unpacklo_epi32(b_reg[4], b_reg[5]); \ + a_reg[3] = _mm256_unpacklo_epi32(b_reg[6], b_reg[7]); \ +\ + a_reg[8] = _mm256_unpacklo_epi32(b_reg[8], b_reg[9]); \ + a_reg[9] = _mm256_unpacklo_epi32(b_reg[10], b_reg[11]); \ + a_reg[10] = _mm256_unpacklo_epi32(b_reg[12], b_reg[13]); \ + a_reg[11] = _mm256_unpacklo_epi32(b_reg[14], b_reg[15]); + +#define UNPACKHI_EPI32 \ + a_reg[4] = _mm256_unpackhi_epi32(b_reg[0], b_reg[1]); \ + a_reg[5] = _mm256_unpackhi_epi32(b_reg[2], b_reg[3]); \ + a_reg[6] = _mm256_unpackhi_epi32(b_reg[4], b_reg[5]); \ + a_reg[7] = _mm256_unpackhi_epi32(b_reg[6], b_reg[7]); \ +\ + a_reg[12] = _mm256_unpackhi_epi32(b_reg[8], b_reg[9]); \ + a_reg[13] = _mm256_unpackhi_epi32(b_reg[10], b_reg[11]); \ + a_reg[14] = _mm256_unpackhi_epi32(b_reg[12], b_reg[13]); \ + a_reg[15] = _mm256_unpackhi_epi32(b_reg[14], b_reg[15]); + +#define UNPACKLO_EPI64 \ + b_reg[0] = _mm256_unpacklo_epi64(a_reg[0], a_reg[1]); \ + b_reg[1] = _mm256_unpacklo_epi64(a_reg[2], a_reg[3]); \ + b_reg[2] = _mm256_unpacklo_epi64(a_reg[4], a_reg[5]); \ + b_reg[3] = _mm256_unpacklo_epi64(a_reg[6], a_reg[7]); \ +\ + b_reg[8] = _mm256_unpacklo_epi64(a_reg[8], a_reg[9]); \ + b_reg[9] = _mm256_unpacklo_epi64(a_reg[10], a_reg[11]); \ + b_reg[10] = _mm256_unpacklo_epi64(a_reg[12], a_reg[13]); \ + b_reg[11] = _mm256_unpacklo_epi64(a_reg[14], a_reg[15]); + +#define UNPACKHI_EPI64 \ + b_reg[4] = _mm256_unpackhi_epi64(a_reg[0], a_reg[1]); \ + b_reg[5] = _mm256_unpackhi_epi64(a_reg[2], a_reg[3]); \ + b_reg[6] = _mm256_unpackhi_epi64(a_reg[4], a_reg[5]); \ + b_reg[7] = _mm256_unpackhi_epi64(a_reg[6], a_reg[7]); \ +\ + b_reg[12] = _mm256_unpackhi_epi64(a_reg[8], a_reg[9]); \ + b_reg[13] = _mm256_unpackhi_epi64(a_reg[10], a_reg[11]); \ + b_reg[14] = _mm256_unpackhi_epi64(a_reg[12], a_reg[13]); \ + b_reg[15] = _mm256_unpackhi_epi64(a_reg[14], a_reg[15]); + +#define SHUFFLE_64x2 \ + a_reg[0] = _mm256_shuffle_i64x2(b_reg[0], b_reg[1], 0x0); \ + a_reg[1] = _mm256_shuffle_i64x2(b_reg[0], b_reg[1], 0x3); \ + a_reg[2] = _mm256_shuffle_i64x2(b_reg[2], b_reg[3], 0x0); \ + a_reg[3] = _mm256_shuffle_i64x2(b_reg[2], b_reg[3], 0x3); \ +\ + a_reg[4] = _mm256_shuffle_i64x2(b_reg[4], b_reg[5], 0x0); \ + a_reg[5] = _mm256_shuffle_i64x2(b_reg[4], b_reg[5], 0x3); \ + a_reg[6] = _mm256_shuffle_i64x2(b_reg[6], b_reg[7], 0x0); \ + 
a_reg[7] = _mm256_shuffle_i64x2(b_reg[6], b_reg[7], 0x3); \ +\ + a_reg[8] = _mm256_shuffle_i64x2(b_reg[8], b_reg[9], 0x0); \ + a_reg[9] = _mm256_shuffle_i64x2(b_reg[8], b_reg[9], 0x3); \ + a_reg[10] = _mm256_shuffle_i64x2(b_reg[10], b_reg[11], 0x0); \ + a_reg[11] = _mm256_shuffle_i64x2(b_reg[10], b_reg[11], 0x3); \ +\ + a_reg[12] = _mm256_shuffle_i64x2(b_reg[12], b_reg[13], 0x0); \ + a_reg[13] = _mm256_shuffle_i64x2(b_reg[12], b_reg[13], 0x3); \ + a_reg[14] = _mm256_shuffle_i64x2(b_reg[14], b_reg[15], 0x0); \ + a_reg[15] = _mm256_shuffle_i64x2(b_reg[14], b_reg[15], 0x3); + +#define MASKED_STORE_EPI64(mask) \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+0) * KC + kr ), mask, a_reg[0]); \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+1) * KC + kr ), mask, a_reg[4]); \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+2) * KC + kr ), mask, a_reg[2]); \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+3) * KC + kr ), mask, a_reg[6]); \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+4) * KC + kr ), mask, a_reg[8]); \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+5) * KC + kr ), mask, a_reg[12]); \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+6) * KC + kr ), mask, a_reg[10]); \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+7) * KC + kr ), mask, a_reg[14]); \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+8) * KC + kr ), mask, a_reg[1]); \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+9) * KC + kr ), mask, a_reg[5]); \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+10) * KC + kr ), mask, a_reg[3]); \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+11) * KC + kr ), mask, a_reg[7]); \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+12) * KC + kr ), mask, a_reg[9]); \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+13) * KC + kr ), mask, a_reg[13]); \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+14) * KC + kr ), mask, a_reg[11]); \ + _mm256_mask_storeu_epi64((pack_a_buffer + (ic+15) * KC + kr ), mask, a_reg[15]); + +#define MASKED_STORE_EPI32(mask) \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+0) * KC + kr ), mask, a_reg[0]); \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+1) * KC + kr ), mask, a_reg[4]); \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+2) * KC + kr ), mask, a_reg[2]); \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+3) * KC + kr ), mask, a_reg[6]); \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+4) * KC + kr ), mask, a_reg[8]); \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+5) * KC + kr ), mask, a_reg[12]); \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+6) * KC + kr ), mask, a_reg[10]); \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+7) * KC + kr ), mask, a_reg[14]); \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+8) * KC + kr ), mask, a_reg[1]); \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+9) * KC + kr ), mask, a_reg[5]); \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+10) * KC + kr ), mask, a_reg[3]); \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+11) * KC + kr ), mask, a_reg[7]); \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+12) * KC + kr ), mask, a_reg[9]); \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+13) * KC + kr ), mask, a_reg[13]); \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+14) * KC + kr ), mask, a_reg[11]); \ + _mm256_mask_storeu_epi32((pack_a_buffer + (ic+15) * KC + kr ), mask, a_reg[15]); + +#define MASKED_STORE_EPI16(mask) \ + _mm256_mask_storeu_epi16((pack_a_buffer + (ic+0) * KC + kr ), mask, a_reg[0]); \ + _mm256_mask_storeu_epi16((pack_a_buffer + (ic+1) * KC + kr ), mask, a_reg[4]); \ + 
_mm256_mask_storeu_epi16((pack_a_buffer + (ic+2) * KC + kr ), mask, a_reg[2]); \ + _mm256_mask_storeu_epi16((pack_a_buffer + (ic+3) * KC + kr ), mask, a_reg[6]); \ + _mm256_mask_storeu_epi16((pack_a_buffer + (ic+4) * KC + kr ), mask, a_reg[8]); \ + _mm256_mask_storeu_epi16((pack_a_buffer + (ic+5) * KC + kr ), mask, a_reg[12]); \ + _mm256_mask_storeu_epi16((pack_a_buffer + (ic+6) * KC + kr ), mask, a_reg[10]); \ + _mm256_mask_storeu_epi16((pack_a_buffer + (ic+7) * KC + kr ), mask, a_reg[14]); \ + _mm256_mask_storeu_epi16((pack_a_buffer + (ic+8) * KC + kr ), mask, a_reg[1]); \ + _mm256_mask_storeu_epi16((pack_a_buffer + (ic+9) * KC + kr ), mask, a_reg[5]); \ + _mm256_mask_storeu_epi16((pack_a_buffer + (ic+10) * KC + kr ), mask, a_reg[3]); \ + _mm256_mask_storeu_epi16((pack_a_buffer + (ic+11) * KC + kr ), mask, a_reg[7]); \ + _mm256_mask_storeu_epi16((pack_a_buffer + (ic+12) * KC + kr ), mask, a_reg[9]); \ + _mm256_mask_storeu_epi16((pack_a_buffer + (ic+13) * KC + kr ), mask, a_reg[13]); \ + _mm256_mask_storeu_epi16((pack_a_buffer + (ic+14) * KC + kr ), mask, a_reg[11]); \ + _mm256_mask_storeu_epi16((pack_a_buffer + (ic+15) * KC + kr ), mask, a_reg[15]); + +#define MASKED_LOAD_32_ROWS_AVX512( mask ) \ + a_reg[0] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[1] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 1 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[2] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 2 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[3] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 3 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[4] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 4 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[5] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 5 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[6] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 6 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[7] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 7 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[8] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 8 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[9] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 9 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[10] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 10 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[11] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 11 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[12] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 12 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[13] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 13 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[14] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 14 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[15] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 15 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[16] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 16 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[17] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 17 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[18] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 18 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[19] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 19 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[20] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 20 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[21] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 21 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[22] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 22 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[23] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 23 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[24] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 24 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[25] = 
_mm512_maskz_loadu_epi16( mask, a + ( ( ic + 25 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[26] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 26 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[27] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 27 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[28] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 28 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[29] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 29 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[30] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 30 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[31] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 31 ) * rs_a ) + ( kr * cs_a )); + +#define MASKED_STORE_32_ROWS_AVX512( mask ) \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr, mask, a_reg[0] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 1 ) * KC ) + kr, mask, a_reg[1] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 2 ) * KC ) + kr, mask, a_reg[2] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 3 ) * KC ) + kr, mask, a_reg[3] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 4 ) * KC ) + kr, mask, a_reg[4] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 5 ) * KC ) + kr, mask, a_reg[5] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 6 ) * KC ) + kr, mask, a_reg[6] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 7 ) * KC ) + kr, mask, a_reg[7] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 8 ) * KC ) + kr, mask, a_reg[8] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 9 ) * KC ) + kr, mask, a_reg[9] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 10 ) * KC ) + kr, mask, a_reg[10] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 11 ) * KC ) + kr, mask, a_reg[11] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 12 ) * KC ) + kr, mask, a_reg[12] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 13 ) * KC ) + kr, mask, a_reg[13] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 14 ) * KC ) + kr, mask, a_reg[14] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 15 ) * KC ) + kr, mask, a_reg[15] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 16 ) * KC ) + kr, mask, a_reg[16] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 17 ) * KC ) + kr, mask, a_reg[17] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 18 ) * KC ) + kr, mask, a_reg[18] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 19 ) * KC ) + kr, mask, a_reg[19] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 20 ) * KC ) + kr, mask, a_reg[20] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 21 ) * KC ) + kr, mask, a_reg[21] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 22 ) * KC ) + kr, mask, a_reg[22] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 23 ) * KC ) + kr, mask, a_reg[23] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 24 ) * KC ) + kr, mask, a_reg[24] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 25 ) * KC ) + kr, mask, a_reg[25] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 26 ) * KC ) + kr, mask, a_reg[26] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 27 ) * KC ) + kr, mask, a_reg[27] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 28 ) * KC ) + kr, mask, a_reg[28] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 29 ) * KC ) + kr, mask, a_reg[29] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 30 ) * KC ) + kr, mask, a_reg[30] ); \ + _mm512_mask_storeu_epi16( 
pack_a_buffer + ( ( ic + 31 ) * KC ) + kr, mask, a_reg[31] ); + + +#define MASKED_LOAD_16_ROWS_AVX512( mask ) \ + a_reg[0] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[1] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 1 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[2] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 2 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[3] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 3 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[4] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 4 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[5] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 5 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[6] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 6 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[7] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 7 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[8] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 8 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[9] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 9 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[10] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 10 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[11] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 11 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[12] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 12 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[13] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 13 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[14] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 14 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[15] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 15 ) * rs_a ) + ( kr * cs_a )); + +#define MASKED_STORE_16_ROWS_AVX512( mask ) \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr, mask, a_reg[0] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 1 ) * KC ) + kr, mask, a_reg[1] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 2 ) * KC ) + kr, mask, a_reg[2] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 3 ) * KC ) + kr, mask, a_reg[3] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 4 ) * KC ) + kr, mask, a_reg[4] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 5 ) * KC ) + kr, mask, a_reg[5] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 6 ) * KC ) + kr, mask, a_reg[6] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 7 ) * KC ) + kr, mask, a_reg[7] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 8 ) * KC ) + kr, mask, a_reg[8] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 9 ) * KC ) + kr, mask, a_reg[9] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 10 ) * KC ) + kr, mask, a_reg[10] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 11 ) * KC ) + kr, mask, a_reg[11] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 12 ) * KC ) + kr, mask, a_reg[12] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 13 ) * KC ) + kr, mask, a_reg[13] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 14 ) * KC ) + kr, mask, a_reg[14] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 15 ) * KC ) + kr, mask, a_reg[15] ); + + +#define MASKED_LOAD_8_ROWS_AVX512( mask ) \ + a_reg[0] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[1] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 1 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[2] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 2 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[3] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 3 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[4] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 4 ) * 
rs_a ) + ( kr * cs_a )); \ + a_reg[5] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 5 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[6] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 6 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[7] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 7 ) * rs_a ) + ( kr * cs_a )); + +#define MASKED_STORE_8_ROWS_AVX512( mask ) \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr, mask, a_reg[0] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 1 ) * KC ) + kr, mask, a_reg[1] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 2 ) * KC ) + kr, mask, a_reg[2] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 3 ) * KC ) + kr, mask, a_reg[3] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 4 ) * KC ) + kr, mask, a_reg[4] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 5 ) * KC ) + kr, mask, a_reg[5] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 6 ) * KC ) + kr, mask, a_reg[6] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 7 ) * KC ) + kr, mask, a_reg[7] ); + + +#define MASKED_LOAD_4_ROWS_AVX512( mask ) \ + a_reg[0] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[1] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 1 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[2] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 2 ) * rs_a ) + ( kr * cs_a )); \ + a_reg[3] = _mm512_maskz_loadu_epi16( mask, a + ( ( ic + 3 ) * rs_a ) + ( kr * cs_a )); + +#define MASKED_STORE_4_ROWS_AVX512( mask ) \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr, mask, a_reg[0] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 1 ) * KC ) + kr, mask, a_reg[1] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 2 ) * KC ) + kr, mask, a_reg[2] ); \ + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 3 ) * KC ) + kr, mask, a_reg[3] ); + +void packa_mr16_bf16bf16f32of32_row_major + ( + bfloat16* pack_a_buffer, + const bfloat16* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t MC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p + ); + +void packa_mr16_bf16bf16f32of32_col_major + ( + bfloat16* pack_a_buffer, + const bfloat16* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t MC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p + ); + +void packa_mr16_bf16bf16f32of32 + ( + bfloat16* pack_a_buffer, + const bfloat16* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t MC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p + ) +{ + if( cs_a == 1 ) + { + packa_mr16_bf16bf16f32of32_row_major + ( pack_a_buffer, a, rs_a, cs_a, MC, KC, rs_p, cs_p); + } + else + { + packa_mr16_bf16bf16f32of32_col_major + ( pack_a_buffer, a, rs_a, cs_a, MC, KC, rs_p, cs_p); + } +} + +void packa_mr16_bf16bf16f32of32_row_major + ( + bfloat16* pack_a_buffer, + const bfloat16* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t MC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p + ) +{ + dim_t MR = 32; + + __m512i a_reg[32]; + + dim_t ic = 0, kr = 0; + + for( ic = 0; ( ic + MR - 1 ) < MC; ic += MR ) + { + for( kr = 0; ( kr + 32 - 1) < KC; kr += 32 ) + { + a_reg[0] = _mm512_loadu_si512( a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[1] = _mm512_loadu_si512( a + ( ( ic + 1 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[2] = _mm512_loadu_si512( a + ( ( ic + 2 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[3] = _mm512_loadu_si512( a + ( ( ic + 3 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[4] = _mm512_loadu_si512( a + ( ( ic + 4 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[5] = _mm512_loadu_si512( a + ( ( ic + 5 ) * rs_a ) 
+ ( kr * cs_a ) ); + a_reg[6] = _mm512_loadu_si512( a + ( ( ic + 6 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[7] = _mm512_loadu_si512( a + ( ( ic + 7 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[8] = _mm512_loadu_si512( a + ( ( ic + 8 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[9] = _mm512_loadu_si512( a + ( ( ic + 9 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[10] = _mm512_loadu_si512( a + ( ( ic + 10 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[11] = _mm512_loadu_si512( a + ( ( ic + 11 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[12] = _mm512_loadu_si512( a + ( ( ic + 12 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[13] = _mm512_loadu_si512( a + ( ( ic + 13 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[14] = _mm512_loadu_si512( a + ( ( ic + 14 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[15] = _mm512_loadu_si512( a + ( ( ic + 15 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[16] = _mm512_loadu_si512( a + ( ( ic + 16 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[17] = _mm512_loadu_si512( a + ( ( ic + 17 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[18] = _mm512_loadu_si512( a + ( ( ic + 18 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[19] = _mm512_loadu_si512( a + ( ( ic + 19 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[20] = _mm512_loadu_si512( a + ( ( ic + 20 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[21] = _mm512_loadu_si512( a + ( ( ic + 21 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[22] = _mm512_loadu_si512( a + ( ( ic + 22 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[23] = _mm512_loadu_si512( a + ( ( ic + 23 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[24] = _mm512_loadu_si512( a + ( ( ic + 24 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[25] = _mm512_loadu_si512( a + ( ( ic + 25 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[26] = _mm512_loadu_si512( a + ( ( ic + 26 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[27] = _mm512_loadu_si512( a + ( ( ic + 27 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[28] = _mm512_loadu_si512( a + ( ( ic + 28 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[29] = _mm512_loadu_si512( a + ( ( ic + 29 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[30] = _mm512_loadu_si512( a + ( ( ic + 30 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[31] = _mm512_loadu_si512( a + ( ( ic + 31 ) * rs_a ) + ( kr * cs_a ) ); + + + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr , a_reg[0] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 1 ) * KC ) + kr , a_reg[1] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 2 ) * KC ) + kr , a_reg[2] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 3 ) * KC ) + kr , a_reg[3] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 4 ) * KC ) + kr , a_reg[4] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 5 ) * KC ) + kr , a_reg[5] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 6 ) * KC ) + kr , a_reg[6] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 7 ) * KC ) + kr , a_reg[7] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 8 ) * KC ) + kr , a_reg[8] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 9 ) * KC ) + kr , a_reg[9] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 10 ) * KC ) + kr , a_reg[10] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 11 ) * KC ) + kr , a_reg[11] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 12 ) * KC ) + kr , a_reg[12] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 13 ) * KC ) + kr , a_reg[13] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 14 ) * KC ) + kr , a_reg[14] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 15 ) * KC ) + kr , a_reg[15] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 16 ) * KC ) + kr , a_reg[16] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 17 ) * KC ) + kr , a_reg[17] ); + _mm512_storeu_si512( pack_a_buffer + 
( ( ic + 18 ) * KC ) + kr , a_reg[18] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 19 ) * KC ) + kr , a_reg[19] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 20 ) * KC ) + kr , a_reg[20] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 21 ) * KC ) + kr , a_reg[21] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 22 ) * KC ) + kr , a_reg[22] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 23 ) * KC ) + kr , a_reg[23] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 24 ) * KC ) + kr , a_reg[24] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 25 ) * KC ) + kr , a_reg[25] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 26 ) * KC ) + kr , a_reg[26] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 27 ) * KC ) + kr , a_reg[27] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 28 ) * KC ) + kr , a_reg[28] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 29 ) * KC ) + kr , a_reg[29] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 30 ) * KC ) + kr , a_reg[30] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 31 ) * KC ) + kr , a_reg[31] ); + } + for( ; ( kr + 15 ) < KC; kr += 16 ) + { + MASKED_LOAD_32_ROWS_AVX512( 0xFFFF ) + + MASKED_STORE_32_ROWS_AVX512( 0xFFFF ) + } + for( ; ( kr + 7 ) < KC; kr += 8 ) + { + MASKED_LOAD_32_ROWS_AVX512( 0xFF ) + + MASKED_STORE_32_ROWS_AVX512( 0xFF ) + } + for( ; ( kr + 3 ) < KC; kr += 4 ) + { + MASKED_LOAD_32_ROWS_AVX512( 0xF ) + + MASKED_STORE_32_ROWS_AVX512( 0xF ) + } + for( ; ( kr + 1 ) < KC; kr += 2 ) + { + MASKED_LOAD_32_ROWS_AVX512( 0x3 ) + + MASKED_STORE_32_ROWS_AVX512( 0x3 ) + } + for( ; ( kr ) < KC; kr += 1 ) + { + MASKED_LOAD_32_ROWS_AVX512( 0x1 ) + + MASKED_STORE_32_ROWS_AVX512( 0x1 ) + } + } + for( ; ( ic + 16 - 1 ) < MC; ic += 16 ) + { + for( kr = 0; ( kr + 32 - 1 ) < KC; kr += 32 ) + { + a_reg[0] = _mm512_loadu_si512( a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[1] = _mm512_loadu_si512( a + ( ( ic + 1 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[2] = _mm512_loadu_si512( a + ( ( ic + 2 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[3] = _mm512_loadu_si512( a + ( ( ic + 3 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[4] = _mm512_loadu_si512( a + ( ( ic + 4 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[5] = _mm512_loadu_si512( a + ( ( ic + 5 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[6] = _mm512_loadu_si512( a + ( ( ic + 6 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[7] = _mm512_loadu_si512( a + ( ( ic + 7 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[8] = _mm512_loadu_si512( a + ( ( ic + 8 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[9] = _mm512_loadu_si512( a + ( ( ic + 9 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[10] = _mm512_loadu_si512( a + ( ( ic + 10 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[11] = _mm512_loadu_si512( a + ( ( ic + 11 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[12] = _mm512_loadu_si512( a + ( ( ic + 12 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[13] = _mm512_loadu_si512( a + ( ( ic + 13 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[14] = _mm512_loadu_si512( a + ( ( ic + 14 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[15] = _mm512_loadu_si512( a + ( ( ic + 15 ) * rs_a ) + ( kr * cs_a ) ); + + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr , a_reg[0] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 1 ) * KC ) + kr , a_reg[1] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 2 ) * KC ) + kr , a_reg[2] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 3 ) * KC ) + kr , a_reg[3] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 4 ) * KC ) + kr , a_reg[4] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 5 ) * KC ) + kr , a_reg[5] ); + _mm512_storeu_si512( pack_a_buffer 
+ ( ( ic + 6 ) * KC ) + kr , a_reg[6] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 7 ) * KC ) + kr , a_reg[7] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 8 ) * KC ) + kr , a_reg[8] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 9 ) * KC ) + kr , a_reg[9] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 10 ) * KC ) + kr , a_reg[10] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 11 ) * KC ) + kr , a_reg[11] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 12 ) * KC ) + kr , a_reg[12] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 13 ) * KC ) + kr , a_reg[13] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 14 ) * KC ) + kr , a_reg[14] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 15 ) * KC ) + kr , a_reg[15] ); + } + for( ; ( kr + 16 - 1 ) < KC; kr += 16 ) + { + MASKED_LOAD_16_ROWS_AVX512( 0xFFFF ) + MASKED_STORE_16_ROWS_AVX512( 0xFFFF ) + } + for( ; ( kr + 7 ) < KC; kr += 8 ) + { + MASKED_LOAD_16_ROWS_AVX512( 0xFF ) + + MASKED_STORE_16_ROWS_AVX512( 0xFF ) + } + for( ; ( kr + 3 ) < KC; kr += 4 ) + { + MASKED_LOAD_16_ROWS_AVX512( 0xF ) + + MASKED_STORE_16_ROWS_AVX512( 0xF ) + } + for( ; ( kr + 1 ) < KC; kr += 2 ) + { + MASKED_LOAD_16_ROWS_AVX512( 0x3 ) + + MASKED_STORE_16_ROWS_AVX512( 0x3 ) + } + for( ; ( kr ) < KC; kr += 1 ) + { + MASKED_LOAD_16_ROWS_AVX512( 0x1 ) + + MASKED_STORE_16_ROWS_AVX512( 0x1 ) + } + } + for( ; ( ic + 7 - 1 ) < MC; ic += 8 ) + { + for( kr = 0; ( kr + 32 - 1 ) < KC; kr += 32 ) + { + a_reg[0] = _mm512_loadu_si512( a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[1] = _mm512_loadu_si512( a + ( ( ic + 1 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[2] = _mm512_loadu_si512( a + ( ( ic + 2 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[3] = _mm512_loadu_si512( a + ( ( ic + 3 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[4] = _mm512_loadu_si512( a + ( ( ic + 4 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[5] = _mm512_loadu_si512( a + ( ( ic + 5 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[6] = _mm512_loadu_si512( a + ( ( ic + 6 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[7] = _mm512_loadu_si512( a + ( ( ic + 7 ) * rs_a ) + ( kr * cs_a ) ); + + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr , a_reg[0] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 1 ) * KC ) + kr , a_reg[1] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 2 ) * KC ) + kr , a_reg[2] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 3 ) * KC ) + kr , a_reg[3] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 4 ) * KC ) + kr , a_reg[4] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 5 ) * KC ) + kr , a_reg[5] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 6 ) * KC ) + kr , a_reg[6] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 7 ) * KC ) + kr , a_reg[7] ); + } + for( ; ( kr + 16 - 1 ) < KC; kr += 16 ) + { + MASKED_LOAD_8_ROWS_AVX512( 0xFFFF ) + MASKED_STORE_8_ROWS_AVX512( 0xFFFF ) + } + for( ; ( kr + 7 ) < KC; kr += 8 ) + { + MASKED_LOAD_8_ROWS_AVX512( 0xFF ) + + MASKED_STORE_8_ROWS_AVX512( 0xFF ) + } + for( ; ( kr + 3 ) < KC; kr += 4 ) + { + MASKED_LOAD_8_ROWS_AVX512( 0xF ) + + MASKED_STORE_8_ROWS_AVX512( 0xF ) + } + for( ; ( kr + 1 ) < KC; kr += 2 ) + { + MASKED_LOAD_8_ROWS_AVX512( 0x3 ) + + MASKED_STORE_8_ROWS_AVX512( 0x3 ) + } + for( ; ( kr ) < KC; kr += 1 ) + { + MASKED_LOAD_8_ROWS_AVX512( 0x1 ) + + MASKED_STORE_8_ROWS_AVX512( 0x1 ) + } + } + for( ; ( ic + 4 - 1 ) < MC; ic += 4 ) + { + for( kr = 0; ( kr + 32 - 1 ) < KC; kr += 32 ) + { + a_reg[0] = _mm512_loadu_si512( a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[1] = _mm512_loadu_si512( a + ( ( ic + 1 ) * rs_a ) + ( kr * 
cs_a ) ); + a_reg[2] = _mm512_loadu_si512( a + ( ( ic + 2 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[3] = _mm512_loadu_si512( a + ( ( ic + 3 ) * rs_a ) + ( kr * cs_a ) ); + + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr , a_reg[0] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 1 ) * KC ) + kr , a_reg[1] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 2 ) * KC ) + kr , a_reg[2] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 3 ) * KC ) + kr , a_reg[3] ); + } + for( ; ( kr + 16 - 1 ) < KC; kr += 16 ) + { + MASKED_LOAD_4_ROWS_AVX512( 0xFFFF ) + MASKED_STORE_4_ROWS_AVX512( 0xFFFF ) + } + for( ; ( kr + 7 ) < KC; kr += 8 ) + { + MASKED_LOAD_4_ROWS_AVX512( 0xFF ) + + MASKED_STORE_4_ROWS_AVX512( 0xFF ) + } + for( ; ( kr + 3 ) < KC; kr += 4 ) + { + MASKED_LOAD_4_ROWS_AVX512( 0xF ) + + MASKED_STORE_4_ROWS_AVX512( 0xF ) + } + for( ; ( kr + 1 ) < KC; kr += 2 ) + { + MASKED_LOAD_4_ROWS_AVX512( 0x3 ) + + MASKED_STORE_4_ROWS_AVX512( 0x3 ) + } + for( ; ( kr ) < KC; kr += 1 ) + { + MASKED_LOAD_4_ROWS_AVX512( 0x1 ) + + MASKED_STORE_4_ROWS_AVX512( 0x1 ) + } + } + + for( ; ( ic + 2 - 1 ) < MC; ic += 2 ) + { + for( kr = 0; ( kr + 32 - 1 ) < KC; kr += 32 ) + { + a_reg[0] = _mm512_loadu_si512( a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[1] = _mm512_loadu_si512( a + ( ( ic + 1 ) * rs_a ) + ( kr * cs_a ) ); + + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr , a_reg[0] ); + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 1 ) * KC ) + kr , a_reg[1] ); + } + for( ; ( kr + 16 - 1 ) < KC; kr += 16 ) + { + a_reg[0] = _mm512_maskz_loadu_epi16( 0xFFFF, a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[1] = _mm512_maskz_loadu_epi16( 0xFFFF, a + ( ( ic + 1 ) * rs_a ) + ( kr * cs_a ) ); + + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr, 0xFFFF, a_reg[0] ); + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 1 ) * KC ) + kr, 0xFFFF, a_reg[1] ); + } + for( ; ( kr + 7 ) < KC; kr += 8 ) + { + a_reg[0] = _mm512_maskz_loadu_epi16( 0xFF, a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[1] = _mm512_maskz_loadu_epi16( 0xFF, a + ( ( ic + 1 ) * rs_a ) + ( kr * cs_a ) ); + + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr, 0xFF, a_reg[0] ); + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 1 ) * KC ) + kr, 0xFF, a_reg[1] ); + } + for( ; ( kr + 3 ) < KC; kr += 4 ) + { + a_reg[0] = _mm512_maskz_loadu_epi16( 0xF, a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[1] = _mm512_maskz_loadu_epi16( 0xF, a + ( ( ic + 1 ) * rs_a ) + ( kr * cs_a ) ); + + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr, 0xF, a_reg[0] ); + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 1 ) * KC ) + kr, 0xF, a_reg[1] ); + } + for( ; ( kr + 1 ) < KC; kr += 2 ) + { + a_reg[0] = _mm512_maskz_loadu_epi16( 0x3, a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[1] = _mm512_maskz_loadu_epi16( 0x3, a + ( ( ic + 1 ) * rs_a ) + ( kr * cs_a ) ); + + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr, 0x3, a_reg[0] ); + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 1 ) * KC ) + kr, 0x3, a_reg[1] ); + } + for( ; ( kr ) < KC; kr += 1 ) + { + a_reg[0] = _mm512_maskz_loadu_epi16( 0x1, a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + a_reg[1] = _mm512_maskz_loadu_epi16( 0x1, a + ( ( ic + 1 ) * rs_a ) + ( kr * cs_a ) ); + + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr, 0x1, a_reg[0] ); + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 1 ) * KC ) + kr, 0x1, a_reg[1] ); + } + } + for( ; ( ic ) < MC; ic += 1 ) + { + for( kr 
= 0; ( kr + 32 - 1 ) < KC; kr += 32 ) + { + a_reg[0] = _mm512_loadu_si512( a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + + _mm512_storeu_si512( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr , a_reg[0]); + } + for( ; ( kr + 16 - 1 ) < KC; kr += 16 ) + { + a_reg[0] = _mm512_maskz_loadu_epi16( 0xFFFF, a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr, 0xFFFF, a_reg[0] ); + } + for( ; ( kr + 7 ) < KC; kr += 8 ) + { + a_reg[0] = _mm512_maskz_loadu_epi16( 0xFF, a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr, 0xFF, a_reg[0] ); + } + for( ; ( kr + 3 ) < KC; kr += 4 ) + { + a_reg[0] = _mm512_maskz_loadu_epi16( 0xF, a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr, 0xF, a_reg[0] ); + } + for( ; ( kr + 1 ) < KC; kr += 2 ) + { + a_reg[0] = _mm512_maskz_loadu_epi16( 0x3, a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr, 0x3, a_reg[0] ); + } + for( ; ( kr ) < KC; kr += 1 ) + { + a_reg[0] = _mm512_maskz_loadu_epi16( 0x1, a + ( ( ic + 0 ) * rs_a ) + ( kr * cs_a ) ); + + _mm512_mask_storeu_epi16( pack_a_buffer + ( ( ic + 0 ) * KC ) + kr, 0x1, a_reg[0] ); + } + } + *rs_p = KC; + *cs_p = 2; + +} +void packa_mr16_bf16bf16f32of32_col_major + ( + bfloat16* pack_a_buffer, + const bfloat16* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t MC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p + ) +{ + dim_t MR = 16; + + dim_t m_left = MC % 4; + + __m256i a_reg[16], b_reg[16]; + + dim_t ic, kr; + + for( ic = 0; ( ic + MR - 1 ) < MC; ic += MR) + { + for( kr = 0; ( kr + 15 ) < KC; kr += 16) + { + a_reg[0] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) ); + a_reg[1] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) ); + a_reg[2] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ) ); + a_reg[3] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ) ); + a_reg[4] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ) ); + a_reg[5] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ) ); + a_reg[6] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ) ); + a_reg[7] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ) ); + a_reg[8] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 8 ) * cs_a ) ) ); + a_reg[9] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 9 ) * cs_a ) ) ); + a_reg[10] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 10 ) * cs_a ) ) ); + a_reg[11] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 11 ) * cs_a ) ) ); + a_reg[12] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 12 ) * cs_a ) ) ); + a_reg[13] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 13 ) * cs_a ) ) ); + a_reg[14] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 14 ) * cs_a ) ) ); + a_reg[15] = _mm256_loadu_si256( (__m256i const *) ( a + ( ic * rs_a ) + ( ( kr + 15 ) * cs_a ) ) ); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 
0 ) * KC + kr ), a_reg[0] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), a_reg[4] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), a_reg[2] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 3 ) * KC + kr ), a_reg[6] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 4 ) * KC + kr ), a_reg[8] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 5 ) * KC + kr ), a_reg[12] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 6 ) * KC + kr ), a_reg[10] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 7 ) * KC + kr ), a_reg[14] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 8 ) * KC + kr ), a_reg[1] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 9 ) * KC + kr ), a_reg[5] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 10 ) * KC + kr ), a_reg[3] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 11 ) * KC + kr ), a_reg[7] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 12 ) * KC + kr ), a_reg[9] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 13 ) * KC + kr ), a_reg[13] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 14 ) * KC + kr ), a_reg[11] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 15 ) * KC + kr ), a_reg[15] ); + } + + for( ; ( kr + 7 ) < KC; kr += 8) + { + a_reg[0] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) ); + a_reg[1] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) ); + a_reg[2] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ) ); + a_reg[3] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ) ); + a_reg[4] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ) ); + a_reg[5] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ) ); + a_reg[6] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ) ); + a_reg[7] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ) ); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + MASKED_STORE_EPI64(0x03) + + } + for( ; ( kr + 3 ) < KC; kr += 4) + { + a_reg[0] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) ); + a_reg[1] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) ); + a_reg[2] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ) ); + a_reg[3] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ) ); + a_reg[4] = _mm256_setzero_si256(); + a_reg[5] = _mm256_setzero_si256(); + a_reg[6] = _mm256_setzero_si256(); + a_reg[7] = _mm256_setzero_si256(); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + + 
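+      // Only 4 k-columns are loaded in this remainder step; the registers
+      // zeroed above let the same 16x16 unpack/shuffle transpose be reused,
+      // and the 0x01 epi64 store mask then writes just the 4 valid bf16
+      // elements of each packed row.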
UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + MASKED_STORE_EPI64(0x01) + } + for( ; ( kr + 1 ) < KC; kr += 2) + { + a_reg[0] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) ); + a_reg[1] = _mm256_loadu_si256( (__m256i const *)( a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ) ); + a_reg[2] = _mm256_setzero_si256(); + a_reg[3] = _mm256_setzero_si256(); + a_reg[4] = _mm256_setzero_si256(); + a_reg[5] = _mm256_setzero_si256(); + a_reg[6] = _mm256_setzero_si256(); + a_reg[7] = _mm256_setzero_si256(); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + MASKED_STORE_EPI32(0x01) + } + for( ; ( kr ) < KC; kr += 1) + { + a_reg[0] = _mm256_loadu_si256( (__m256i const *)(a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ) ); + a_reg[1] = _mm256_setzero_si256(); + a_reg[2] = _mm256_setzero_si256(); + a_reg[3] = _mm256_setzero_si256(); + a_reg[4] = _mm256_setzero_si256(); + a_reg[5] = _mm256_setzero_si256(); + a_reg[6] = _mm256_setzero_si256(); + a_reg[7] = _mm256_setzero_si256(); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + MASKED_STORE_EPI16(0x01) + } + } + + for( ; ( ic + 8 - 1) < MC; ic += 8) + { + for( kr = 0; ( kr + 15 ) < KC; kr += 16) + { + a_reg[0] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); + a_reg[3] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); + a_reg[4] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ); + a_reg[5] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ); + a_reg[6] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ); + a_reg[7] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ); + a_reg[8] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 8 ) * cs_a ) ); + a_reg[9] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 9 ) * cs_a ) ); + a_reg[10] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 10 ) * cs_a ) ); + a_reg[11] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 11 ) * cs_a ) ); + a_reg[12] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 12 ) * cs_a ) ); + a_reg[13] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 13 ) * cs_a ) ); + a_reg[14] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 14 ) * cs_a ) ); + a_reg[15] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 15 ) * cs_a ) ); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + 
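+      // The 0xFF masked loads above fetch the 8 rows of this block for each
+      // of the 16 k-columns; the 128-bit lane shuffle below completes the
+      // transpose so the stored registers each hold one packed row in
+      // contiguous k-order.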
SHUFFLE_64x2 + + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), a_reg[0] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), a_reg[4] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), a_reg[2] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 3 ) * KC + kr ), a_reg[6] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 4 ) * KC + kr ), a_reg[8] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 5 ) * KC + kr ), a_reg[12] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 6 ) * KC + kr ), a_reg[10] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 7 ) * KC + kr ), a_reg[14] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 8 ) * KC + kr ), a_reg[1] ); + } + + for( ; ( kr + 7 ) < KC; kr += 8) + { + a_reg[0] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); + a_reg[3] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); + a_reg[4] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ); + a_reg[5] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ); + a_reg[6] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ); + a_reg[7] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x03, a_reg[0] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x03, a_reg[4] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x03, a_reg[2] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), 0x03, a_reg[6] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 4 ) * KC + kr ), 0x03, a_reg[8] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 5 ) * KC + kr ), 0x03, a_reg[12] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 6 ) * KC + kr ), 0x03, a_reg[10] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 7 ) * KC + kr ), 0x03, a_reg[14] ); + } + for( ; ( kr + 3 ) < KC; kr += 4) + { + a_reg[0] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); + a_reg[3] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); + a_reg[4] = _mm256_setzero_si256(); + a_reg[5] = _mm256_setzero_si256(); + a_reg[6] = _mm256_setzero_si256(); + a_reg[7] = _mm256_setzero_si256(); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = 
_mm256_setzero_si256(); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[4] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x01, a_reg[2] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), 0x01, a_reg[6] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 4 ) * KC + kr ), 0x01, a_reg[8] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 5 ) * KC + kr ), 0x01, a_reg[12] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 6 ) * KC + kr ), 0x01, a_reg[10] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 7 ) * KC + kr ), 0x01, a_reg[14] ); + } + for( ; ( kr + 1 ) < KC; kr += 2) + { + a_reg[0] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = _mm256_setzero_si256(); + a_reg[3] = _mm256_setzero_si256(); + a_reg[4] = _mm256_setzero_si256(); + a_reg[5] = _mm256_setzero_si256(); + a_reg[6] = _mm256_setzero_si256(); + a_reg[7] = _mm256_setzero_si256(); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[4] ); + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x01, a_reg[2] ); + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), 0x01, a_reg[6] ); + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 4 ) * KC + kr ), 0x01, a_reg[8] ); + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 5 ) * KC + kr ), 0x01, a_reg[12] ); + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 6 ) * KC + kr ), 0x01, a_reg[10] ); + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 7 ) * KC + kr ), 0x01, a_reg[14] ); + } + for( ; ( kr ) < KC; kr += 1) + { + a_reg[0] = _mm256_maskz_loadu_epi16( 0xFF, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = _mm256_setzero_si256(); + a_reg[2] = _mm256_setzero_si256(); + a_reg[3] = _mm256_setzero_si256(); + a_reg[4] = _mm256_setzero_si256(); + a_reg[5] = _mm256_setzero_si256(); + a_reg[6] = _mm256_setzero_si256(); + a_reg[7] = _mm256_setzero_si256(); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[4] ); + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x01, a_reg[2] ); + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 3 ) * 
KC + kr ), 0x01, a_reg[6] ); + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 4 ) * KC + kr ), 0x01, a_reg[8] ); + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 5 ) * KC + kr ), 0x01, a_reg[12] ); + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 6 ) * KC + kr ), 0x01, a_reg[10] ); + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 7 ) * KC + kr ), 0x01, a_reg[14] ); + } + } + + for( ; ( ic + 4 - 1 ) < MC; ic += 4) + { + for( kr = 0; ( kr + 15 ) < KC; kr += 16) + { + a_reg[0] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); + a_reg[3] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); + a_reg[4] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ); + a_reg[5] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ); + a_reg[6] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ); + a_reg[7] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ); + a_reg[8] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 8 ) * cs_a ) ); + a_reg[9] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 9 ) * cs_a ) ); + a_reg[10] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 10 ) * cs_a ) ); + a_reg[11] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 11 ) * cs_a ) ); + a_reg[12] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 12 ) * cs_a ) ); + a_reg[13] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 13 ) * cs_a ) ); + a_reg[14] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 14 ) * cs_a ) ); + a_reg[15] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 15 ) * cs_a ) ); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), a_reg[0] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), a_reg[4] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), a_reg[2] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 3 ) * KC + kr ), a_reg[6] ); + } + + for( ; ( kr + 7 ) < KC; kr += 8) + { + a_reg[0] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); + a_reg[3] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); + a_reg[4] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ); + a_reg[5] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ); + a_reg[6] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ); + a_reg[7] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + 
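+      // kr remainder of 8 for the 4-row block: the 0x0F loads above keep only
+      // the 4 valid rows per k-column; after the epi64 unpack and lane shuffle
+      // below, the 0x03 epi64 mask stores 8 packed bf16 elements per row.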
UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x03, a_reg[0] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x03, a_reg[4] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x03, a_reg[2] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), 0x03, a_reg[6] ); + } + for( ; ( kr + 3 ) < KC; kr += 4) + { + a_reg[0] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); + a_reg[3] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); + a_reg[4] = _mm256_setzero_si256(); + a_reg[5] = _mm256_setzero_si256(); + a_reg[6] = _mm256_setzero_si256(); + a_reg[7] = _mm256_setzero_si256(); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[4] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x01, a_reg[2] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), 0x01, a_reg[6] ); + } + for( ; ( kr + 1 ) < KC; kr += 2) + { + a_reg[0] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = _mm256_setzero_si256(); + a_reg[3] = _mm256_setzero_si256(); + a_reg[4] = _mm256_setzero_si256(); + a_reg[5] = _mm256_setzero_si256(); + a_reg[6] = _mm256_setzero_si256(); + a_reg[7] = _mm256_setzero_si256(); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[4] ); + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x01, a_reg[2] ); + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), 0x01, a_reg[6] ); + } + for( ; ( kr ) < KC; kr += 1) + { + a_reg[0] = _mm256_maskz_loadu_epi16( 0x0F, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = _mm256_setzero_si256(); + a_reg[2] = _mm256_setzero_si256(); + a_reg[3] = _mm256_setzero_si256(); + a_reg[4] = _mm256_setzero_si256(); + a_reg[5] = _mm256_setzero_si256(); + a_reg[6] = _mm256_setzero_si256(); + a_reg[7] = _mm256_setzero_si256(); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = 
_mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[4] ); + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x01, a_reg[2] ); + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 3 ) * KC + kr ), 0x01, a_reg[6] ); + } + } + + if( m_left ) + { + __mmask16 mask = 0xFFFF >> ( 16 - m_left ); + for( kr = 0; ( kr + 15 ) < KC; kr += 16) + { + a_reg[0] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); + a_reg[3] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); + a_reg[4] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ); + a_reg[5] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ); + a_reg[6] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ); + a_reg[7] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ); + a_reg[8] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 8 ) * cs_a ) ); + a_reg[9] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 9 ) * cs_a ) ); + a_reg[10] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 10 ) * cs_a ) ); + a_reg[11] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 11 ) * cs_a ) ); + a_reg[12] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 12 ) * cs_a ) ); + a_reg[13] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 13 ) * cs_a ) ); + a_reg[14] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 14 ) * cs_a ) ); + a_reg[15] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 15 ) * cs_a ) ); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + + switch( m_left ) + { + case 3: + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), a_reg[0] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), a_reg[4] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 2 ) * KC + kr ), a_reg[2] ); + break; + case 2: + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), a_reg[0] ); + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 1 ) * KC + kr ), a_reg[4] ); + break; + case 1: + _mm256_storeu_si256( (__m256i *)( pack_a_buffer + ( ic + 0 ) * KC + kr ), a_reg[0] ); + break; + } + } + + for( ; ( kr + 7 ) < KC; kr += 8) + { + a_reg[0] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); + a_reg[3] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); + a_reg[4] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 4 ) * cs_a ) ); + a_reg[5] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 5 ) * cs_a ) ); + a_reg[6] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 6 ) * cs_a ) ); + 
a_reg[7] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 7 ) * cs_a ) ); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + + switch( m_left ) + { + case 3: + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x03, a_reg[0] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x03, a_reg[4] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x03, a_reg[2] ); + break; + case 2: + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x03, a_reg[0] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x03, a_reg[4] ); + break; + case 1: + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x03, a_reg[0] ); + break; + } + } + for( ; ( kr + 3 ) < KC; kr += 4) + { + a_reg[0] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 2 ) * cs_a ) ); + a_reg[3] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 3 ) * cs_a ) ); + a_reg[4] = _mm256_setzero_si256(); + a_reg[5] = _mm256_setzero_si256(); + a_reg[6] = _mm256_setzero_si256(); + a_reg[7] = _mm256_setzero_si256(); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + + switch( m_left ) + { + case 3: + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[4] ); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x01, a_reg[2] ); + break; + case 2: + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0]); + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[4]); + break; + case 1: + _mm256_mask_storeu_epi64( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0]); + break; + } + } + for( ; ( kr + 1 ) < KC; kr += 2) + { + a_reg[0] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 1 ) * cs_a ) ); + a_reg[2] = _mm256_setzero_si256(); + a_reg[3] = _mm256_setzero_si256(); + a_reg[4] = _mm256_setzero_si256(); + a_reg[5] = _mm256_setzero_si256(); + a_reg[6] = _mm256_setzero_si256(); + a_reg[7] = _mm256_setzero_si256(); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 
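+      // m_left (< 4) leftover rows with a kr remainder of 2: after the
+      // shuffle, the switch below writes one 32-bit chunk (2 bf16) for each
+      // remaining row only.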
+ SHUFFLE_64x2 + switch( m_left ) + { + case 3: + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[4] ); + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x01, a_reg[2] ); + break; + case 2: + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm256_mask_storeu_epi32( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[4] ); + break; + case 1: + _mm256_mask_storeu_epi32( (pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0] ); + break; + } + } + for( ; ( kr ) < KC; kr += 1) + { + a_reg[0] = _mm256_maskz_loadu_epi16( mask, a + ( ic * rs_a ) + ( ( kr + 0 ) * cs_a ) ); + a_reg[1] = _mm256_setzero_si256(); + a_reg[2] = _mm256_setzero_si256(); + a_reg[3] = _mm256_setzero_si256(); + a_reg[4] = _mm256_setzero_si256(); + a_reg[5] = _mm256_setzero_si256(); + a_reg[6] = _mm256_setzero_si256(); + a_reg[7] = _mm256_setzero_si256(); + a_reg[8] = _mm256_setzero_si256(); + a_reg[9] = _mm256_setzero_si256(); + a_reg[10] = _mm256_setzero_si256(); + a_reg[11] = _mm256_setzero_si256(); + a_reg[12] = _mm256_setzero_si256(); + a_reg[13] = _mm256_setzero_si256(); + a_reg[14] = _mm256_setzero_si256(); + a_reg[15] = _mm256_setzero_si256(); + + UNPACKLO_EPI16 + UNPACKHI_EPI16 + UNPACKLO_EPI32 + UNPACKHI_EPI32 + UNPACKLO_EPI64 + UNPACKHI_EPI64 + SHUFFLE_64x2 + switch( m_left ) + { + case 3: + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[4] ); + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 2 ) * KC + kr ), 0x01, a_reg[2] ); + break; + case 2: + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0] ); + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 1 ) * KC + kr ), 0x01, a_reg[4] ); + break; + case 1: + _mm256_mask_storeu_epi16( ( pack_a_buffer + ( ic + 0 ) * KC + kr ), 0x01, a_reg[0] ); + break; + } + } + } + + *rs_p = KC; + *cs_p = 2; +} +#endif diff --git a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c index fe39c8c038..54d0fb86b8 100644 --- a/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c +++ b/kernels/zen4/lpgemm/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
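For orientation, the bf16 A-packing fringe above (which finishes by reporting rs_p = KC and cs_p = 2) reduces to a strided copy once the masked loads, unpack/shuffle steps and masked stores are collapsed. The sketch below is a hypothetical scalar reference, not part of the patch; the function name is illustrative, it reuses the kernel's bfloat16/dim_t typedefs, and it assumes the same ( ic + r ) * KC + k destination indexing the masked stores use.

/* Hypothetical scalar equivalent of the m_left (< MR) A-pack fringe above.
 * Each leftover row r of the current ic block is written contiguously
 * along k, matching the packed layout ( rs_p = KC, cs_p = 2 ). */
static void packa_mleft_ref
     (
       bfloat16*       pack_a_buffer,
       const bfloat16* a,
       const dim_t     ic,
       const dim_t     rs_a,
       const dim_t     cs_a,
       const dim_t     KC,
       const dim_t     m_left
     )
{
    for ( dim_t r = 0; r < m_left; r += 1 )
    {
        for ( dim_t k = 0; k < KC; k += 1 )
        {
            // Same element the masked column loads + unpack/shuffle sequence
            // moves: element r of column k of the current block.
            pack_a_buffer[ ( ic + r ) * KC + k ] =
                a[ ( ic * rs_a ) + ( k * cs_a ) + r ];
        }
    }
}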
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,50 +38,116 @@ #ifdef BLIS_ADDON_LPGEMM -void packb_nrlt16_bf16bf16f32of32 + +void packb_nr64_bf16bf16f32of32_row_major + ( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t* rs_b, + dim_t* cs_b + ); + +void packb_nr64_bf16bf16f32of32_col_major + ( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t* rs_b, + dim_t* cs_b + ); + +void packb_nrlt16_bf16bf16f32of32_row_major ( bfloat16* pack_b_buffer_bf16bf16f32of32, const bfloat16* b, const dim_t ldb, - const dim_t KC, + const dim_t KC, const dim_t n0_partial_rem ); -void packb_nr16_bf16bf16f32of32 +void packb_nr16_bf16bf16f32of32_row_major ( bfloat16* pack_b_buffer_bf16bf16f32of32, const bfloat16* b, const dim_t ldb, - const dim_t KC + const dim_t KC ); -void packb_nr32_bf16bf16f32of32 +void packb_nr32_bf16bf16f32of32_row_major ( bfloat16* pack_b_buffer_bf16bf16f32of32, const bfloat16* b, const dim_t ldb, - const dim_t KC + const dim_t KC ); -void packb_nr48_bf16bf16f32of32 +void packb_nr48_bf16bf16f32of32_row_major ( bfloat16* pack_b_buffer_bf16bf16f32of32, const bfloat16* b, const dim_t ldb, - const dim_t KC + const dim_t KC ); + +void packb_nrlt16_bf16bf16f32of32_col_major + ( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t KC, + const dim_t n0_partial_rem + ); + +void packb_nr_mult_16_bf16bf16f32of32_col_major + ( + bfloat16* pack_b_buffer, + const bfloat16* b, + const dim_t NR, + const dim_t ldb, + const dim_t KC + ); + + void packb_nr64_bf16bf16f32of32 + ( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t rs_b, + const dim_t cs_b, + const dim_t NC, + const dim_t KC, + dim_t* rs_p, + dim_t* cs_p + ) +{ + if( cs_b == 1 ) + { + packb_nr64_bf16bf16f32of32_row_major( pack_b_buffer_bf16bf16f32of32, + b, rs_b, NC, KC, rs_p, cs_p ); + } + else + { + packb_nr64_bf16bf16f32of32_col_major( pack_b_buffer_bf16bf16f32of32, + b, cs_b, NC, KC, rs_p, cs_p ); + } +} +void packb_nr64_bf16bf16f32of32_row_major ( bfloat16* pack_b_buffer_bf16bf16f32of32, const bfloat16* b, const dim_t ldb, - const dim_t NC, - const dim_t KC, + const dim_t NC, + const dim_t KC, dim_t* rs_b, dim_t* cs_b ) -{ +{ dim_t NR = 64; // Used for permuting the mm512i elements for use in dpbf16_ps instruction. @@ -111,7 +177,7 @@ void packb_nr64_bf16bf16f32of32 } for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR ) - { + { for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) { // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. @@ -131,12 +197,12 @@ void packb_nr64_bf16bf16f32of32 a0 = _mm512_permutex2var_epi64( a01, selector1_1, a0 ); c0 = _mm512_permutex2var_epi64( c01, selector1_1, c0 ); - //store to pack_b buffer + //store to pack_b buffer _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( kr + 0 ) * NR ), b0 ); _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( kr + 0 ) * NR ) + 32, a0 ); _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( kr + 1 ) * NR ), d0 ); _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( kr + 1 ) * NR ) + 32, c0 ); - } + } // Handle k remainder. 
if( k_partial_pieces > 0) { @@ -156,12 +222,12 @@ void packb_nr64_bf16bf16f32of32 a0 = _mm512_permutex2var_epi64( a01, selector1_1, a0 ); c0 = _mm512_permutex2var_epi64( c01, selector1_1, c0 ); - //store to pack_b buffer + //store to pack_b buffer _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( k_full_pieces + 0 ) * NR ), b0 ); _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( k_full_pieces + 0 ) * NR ) + 32, a0 ); _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( k_full_pieces + 1 ) * NR ), d0 ); _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( k_full_pieces + 1 ) * NR ) + 32, c0 ); - } + } } if(n_partial_pieces > 0) @@ -178,64 +244,64 @@ void packb_nr64_bf16bf16f32of32 if ( n0_48 == 1 ) { - packb_nr48_bf16bf16f32of32 + packb_nr48_bf16bf16f32of32_row_major ( ( pack_b_buffer_bf16bf16f32of32 + ( n_full_pieces_loop_limit * KC_updated ) ), ( b + n_full_pieces_loop_limit ), ldb, KC ); n0_partial_pack = 48; - } + } else if ( n0_32 == 1 ) { - packb_nr32_bf16bf16f32of32 - ( + packb_nr32_bf16bf16f32of32_row_major + ( ( pack_b_buffer_bf16bf16f32of32 + ( n_full_pieces_loop_limit * KC_updated ) ), ( b + n_full_pieces_loop_limit ), ldb, KC ); n0_partial_pack = 32; - } + } else if ( n0_16 == 1 ) { - packb_nr16_bf16bf16f32of32 + packb_nr16_bf16bf16f32of32_row_major ( ( pack_b_buffer_bf16bf16f32of32 + ( n_full_pieces_loop_limit * KC_updated ) ), ( b + n_full_pieces_loop_limit ), ldb, KC ); n0_partial_pack = 16; - } + } if ( n0_partial_rem > 0 ) { - packb_nrlt16_bf16bf16f32of32 + packb_nrlt16_bf16bf16f32of32_row_major ( ( pack_b_buffer_bf16bf16f32of32 + ( n_full_pieces_loop_limit * KC_updated ) + ( n0_partial_pack * KC_updated ) ), ( b + n_full_pieces_loop_limit + n0_partial_pack ), ldb, KC, n0_partial_rem ); - } - } + } + } *rs_b = NR * 2; *cs_b = NR / 2; } -void packb_nr48_bf16bf16f32of32 +void packb_nr48_bf16bf16f32of32_row_major ( bfloat16* pack_b_buffer_bf16bf16f32of32, const bfloat16* b, const dim_t ldb, - const dim_t KC + const dim_t KC ) -{ +{ dim_t NR1 = 32; dim_t NR2 = 16; // Used for permuting the mm512i elements for use in dpbf16_ps instruction. __m512i selector1 = _mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB); - __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); __m512i a0x; __m512i b0x; @@ -256,21 +322,21 @@ void packb_nr48_bf16bf16f32of32 for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) { // Rearrange for dpbf16_ps, read 2 rows from B with 32 elements in each row. - a0x = _mm512_loadu_si512( b + ( ldb * ( kr + 0 ) ) ); + a0x = _mm512_loadu_si512( b + ( ldb * ( kr + 0 ) ) ); c0x = _mm512_loadu_si512( b + ( ldb * ( kr + 1 ) ) ); a01x = _mm512_unpacklo_epi16( a0x, c0x ); a0x = _mm512_unpackhi_epi16( a0x, c0x ); - b0x = _mm512_permutex2var_epi64( a01x, selector1, a0x ); + b0x = _mm512_permutex2var_epi64( a01x, selector1, a0x ); a0x = _mm512_permutex2var_epi64( a01x, selector1_1, a0x ); //First 2x32 elements - _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR1 ), b0x ); - _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR1 ), a0x ); + _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR1 ), b0x ); + _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR1 ), a0x ); // Rearrange for dpbf16_ps, read 2 rows from B with next 16 elements in each row. 
- a0 = _mm256_maskz_loadu_epi16( 0xFFFF, b + ( ldb * ( kr + 0 ) ) + NR1 ); + a0 = _mm256_maskz_loadu_epi16( 0xFFFF, b + ( ldb * ( kr + 0 ) ) + NR1 ); c0 = _mm256_maskz_loadu_epi16( 0xFFFF, b + ( ldb * ( kr + 1 ) ) + NR1 ); a01 = _mm256_unpacklo_epi16( a0, c0 ); @@ -279,7 +345,7 @@ void packb_nr48_bf16bf16f32of32 b0 = _mm256_permute2f128_si256(a01, a0, 0x20); a0 = _mm256_permute2f128_si256(a01, a0, 0x31); - //Last 2x16 elements + //Last 2x16 elements _mm256_mask_storeu_epi64 ( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 2 ) * NR1 ), @@ -296,20 +362,20 @@ void packb_nr48_bf16bf16f32of32 // Handle k remainder. if ( k_partial_pieces > 0 ) { - a0x = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 0 ) ) ); + a0x = _mm512_loadu_si512( b + ( ldb * ( k_full_pieces + 0 ) ) ); c0x = _mm512_setzero_si512(); a01x = _mm512_unpacklo_epi16( a0x, c0x ); a0x = _mm512_unpackhi_epi16( a0x, c0x ); - b0x = _mm512_permutex2var_epi64( a01x, selector1, a0x ); + b0x = _mm512_permutex2var_epi64( a01x, selector1, a0x ); a0x = _mm512_permutex2var_epi64( a01x, selector1_1, a0x ); //First 2x32 elements - _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR1 ), b0x ); - _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR1 ), a0x ); + _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR1 ), b0x ); + _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR1 ), a0x ); - a0 = _mm256_maskz_loadu_epi16( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) + NR1 ); + a0 = _mm256_maskz_loadu_epi16( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) + NR1 ); c0 = _mm256_setzero_si256(); a01 = _mm256_unpacklo_epi16( a0, c0 ); @@ -318,7 +384,7 @@ void packb_nr48_bf16bf16f32of32 b0 = _mm256_permute2f128_si256(a01, a0, 0x20); a0 = _mm256_permute2f128_si256(a01, a0, 0x31); - //Last 2x16 elements + //Last 2x16 elements _mm256_mask_storeu_epi64 ( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 2 ) * NR1 ), @@ -332,12 +398,12 @@ void packb_nr48_bf16bf16f32of32 } } -void packb_nr32_bf16bf16f32of32 +void packb_nr32_bf16bf16f32of32_row_major ( bfloat16* pack_b_buffer_bf16bf16f32of32, const bfloat16* b, const dim_t ldb, - const dim_t KC + const dim_t KC ) { dim_t NR = 32; @@ -373,7 +439,7 @@ void packb_nr32_bf16bf16f32of32 _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), a0 ); kr_new += 2; - } + } // Handle k remainder. if ( k_partial_pieces > 0 ) { @@ -389,14 +455,14 @@ void packb_nr32_bf16bf16f32of32 _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new ) * NR ), b0 ); _mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), a0 ); } -} +} -void packb_nr16_bf16bf16f32of32 +void packb_nr16_bf16bf16f32of32_row_major ( bfloat16* pack_b_buffer_bf16bf16f32of32, const bfloat16* b, const dim_t ldb, - const dim_t KC + const dim_t KC ) { dim_t NR = 16; @@ -413,12 +479,12 @@ void packb_nr16_bf16bf16f32of32 dim_t kr_new = 0; for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) - { + { // Rearrange for dpbf16_ps, read 2 rows from B with 16 elements in each row. 
a0 = _mm256_maskz_loadu_epi16( 0xFFFF, b + ( ldb * ( kr + 0 ) ) ); - c0 = _mm256_maskz_loadu_epi16( 0xFFFF, b + ( ldb * ( kr + 1 ) ) ); + c0 = _mm256_maskz_loadu_epi16( 0xFFFF, b + ( ldb * ( kr + 1 ) ) ); - a01 = _mm256_unpacklo_epi16( a0, c0 ); + a01 = _mm256_unpacklo_epi16( a0, c0 ); a0 = _mm256_unpackhi_epi16( a0, c0 ); b0 = _mm256_permute2f128_si256(a01, a0, 0x20); @@ -443,7 +509,7 @@ void packb_nr16_bf16bf16f32of32 a0 = _mm256_maskz_loadu_epi16( 0xFFFF, b + ( ldb * ( k_full_pieces + 0 ) ) ); c0 = _mm256_setzero_si256(); - a01 = _mm256_unpacklo_epi16( a0, c0 ); + a01 = _mm256_unpacklo_epi16( a0, c0 ); a0 = _mm256_unpackhi_epi16( a0, c0 ); b0 = _mm256_permute2f128_si256(a01, a0, 0x20); @@ -459,15 +525,15 @@ void packb_nr16_bf16bf16f32of32 pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), 0xFF, a0 ); - } -} + } +} -void packb_nrlt16_bf16bf16f32of32 +void packb_nrlt16_bf16bf16f32of32_row_major ( bfloat16* pack_b_buffer_bf16bf16f32of32, const bfloat16* b, const dim_t ldb, - const dim_t KC, + const dim_t KC, const dim_t n0_partial_rem ) { @@ -488,14 +554,14 @@ void packb_nrlt16_bf16bf16f32of32 bfloat16 buf1[16]; for ( int kr = 0; kr < k_full_pieces; kr += 2 ) - { + { memcpy( buf0, ( b + ( ldb * ( kr + 0 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) ); memcpy( buf1, ( b + ( ldb * ( kr + 1 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) ); // Rearrange for dpbf16_ps, read 2 rows from B with next 16 elements in each row. a0 = _mm256_maskz_loadu_epi16( 0xFFFF, buf0 ); c0 = _mm256_maskz_loadu_epi16( 0xFFFF, buf1 ); - a01 = _mm256_unpacklo_epi16( a0, c0 ); + a01 = _mm256_unpacklo_epi16( a0, c0 ); a0 = _mm256_unpackhi_epi16( a0, c0 ); b0 = _mm256_permute2f128_si256(a01, a0, 0x20); @@ -521,7 +587,7 @@ void packb_nrlt16_bf16bf16f32of32 a0 = _mm256_maskz_loadu_epi16( 0xFFFF, buf0 ); c0 = _mm256_setzero_si256(); - a01 = _mm256_unpacklo_epi16( a0, c0 ); + a01 = _mm256_unpacklo_epi16( a0, c0 ); a0 = _mm256_unpackhi_epi16( a0, c0 ); b0 = _mm256_permute2f128_si256(a01, a0, 0x20); @@ -537,6 +603,517 @@ void packb_nrlt16_bf16bf16f32of32 pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), 0xFF, a0 ); - } + } +} + +#define LOAD_16_COLS_AVX512 \ + a_reg[0] = _mm512_loadu_si512(b + ( ldb * ( jr + 0 ) ) + kr); \ + a_reg[1] = _mm512_loadu_si512(b + ( ldb * ( jr + 1 ) ) + kr); \ + a_reg[2] = _mm512_loadu_si512(b + ( ldb * ( jr + 2 ) ) + kr); \ + a_reg[3] = _mm512_loadu_si512(b + ( ldb * ( jr + 3 ) ) + kr); \ + a_reg[4] = _mm512_loadu_si512(b + ( ldb * ( jr + 4 ) ) + kr); \ + a_reg[5] = _mm512_loadu_si512(b + ( ldb * ( jr + 5 ) ) + kr); \ + a_reg[6] = _mm512_loadu_si512(b + ( ldb * ( jr + 6 ) ) + kr); \ + a_reg[7] = _mm512_loadu_si512(b + ( ldb * ( jr + 7 ) ) + kr); \ + a_reg[8] = _mm512_loadu_si512(b + ( ldb * ( jr + 8 ) ) + kr); \ + a_reg[9] = _mm512_loadu_si512(b + ( ldb * ( jr + 9 ) ) + kr); \ + a_reg[10] = _mm512_loadu_si512(b + ( ldb * ( jr + 10 ) ) + kr); \ + a_reg[11] = _mm512_loadu_si512(b + ( ldb * ( jr + 11 ) ) + kr); \ + a_reg[12] = _mm512_loadu_si512(b + ( ldb * ( jr + 12 ) ) + kr); \ + a_reg[13] = _mm512_loadu_si512(b + ( ldb * ( jr + 13 ) ) + kr); \ + a_reg[14] = _mm512_loadu_si512(b + ( ldb * ( jr + 14 ) ) + kr); \ + a_reg[15] = _mm512_loadu_si512(b + ( ldb * ( jr + 15 ) ) + kr); + +#define UNPACKHILO32_AVX512 \ + b_reg[0] = _mm512_unpacklo_epi32(a_reg[0], a_reg[1]); \ + b_reg[2] = _mm512_unpacklo_epi32(a_reg[2], a_reg[3]); \ + b_reg[4] = _mm512_unpacklo_epi32(a_reg[4], a_reg[5]); \ + b_reg[6] = _mm512_unpacklo_epi32(a_reg[6], a_reg[7]); \ + b_reg[8] = _mm512_unpacklo_epi32(a_reg[8], 
a_reg[9]); \ + b_reg[10] = _mm512_unpacklo_epi32(a_reg[10], a_reg[11]); \ + b_reg[12] = _mm512_unpacklo_epi32(a_reg[12], a_reg[13]); \ + b_reg[14] = _mm512_unpacklo_epi32(a_reg[14], a_reg[15]); \ +\ + b_reg[1] = _mm512_unpackhi_epi32(a_reg[0], a_reg[1]); \ + b_reg[3] = _mm512_unpackhi_epi32(a_reg[2], a_reg[3]); \ + b_reg[5] = _mm512_unpackhi_epi32(a_reg[4], a_reg[5]); \ + b_reg[7] = _mm512_unpackhi_epi32(a_reg[6], a_reg[7]); \ + b_reg[9] = _mm512_unpackhi_epi32(a_reg[8], a_reg[9]); \ + b_reg[11] = _mm512_unpackhi_epi32(a_reg[10], a_reg[11]); \ + b_reg[13] = _mm512_unpackhi_epi32(a_reg[12], a_reg[13]); \ + b_reg[15] = _mm512_unpackhi_epi32(a_reg[14], a_reg[15]); + +#define UNPACKHILO64_AVX512 \ + a_reg[0] = _mm512_unpacklo_epi64(b_reg[0], b_reg[2]); \ + a_reg[1] = _mm512_unpacklo_epi64(b_reg[4], b_reg[6]); \ + a_reg[2] = _mm512_unpacklo_epi64(b_reg[8], b_reg[10]); \ + a_reg[3] = _mm512_unpacklo_epi64(b_reg[12], b_reg[14]); \ + a_reg[4] = _mm512_unpacklo_epi64(b_reg[1], b_reg[3]); \ + a_reg[5] = _mm512_unpacklo_epi64(b_reg[5], b_reg[7]); \ + a_reg[6] = _mm512_unpacklo_epi64(b_reg[9], b_reg[11]); \ + a_reg[7] = _mm512_unpacklo_epi64(b_reg[13], b_reg[15]); \ +\ + a_reg[8] = _mm512_unpackhi_epi64(b_reg[0], b_reg[2]); \ + a_reg[9] = _mm512_unpackhi_epi64(b_reg[4], b_reg[6]); \ + a_reg[10] = _mm512_unpackhi_epi64(b_reg[8], b_reg[10]); \ + a_reg[11] = _mm512_unpackhi_epi64(b_reg[12], b_reg[14]); \ + a_reg[12] = _mm512_unpackhi_epi64(b_reg[1], b_reg[3]); \ + a_reg[13] = _mm512_unpackhi_epi64(b_reg[5], b_reg[7]); \ + a_reg[14] = _mm512_unpackhi_epi64(b_reg[9], b_reg[11]); \ + a_reg[15] = _mm512_unpackhi_epi64(b_reg[13], b_reg[15]); + +#define PERMUTEX2_VAR64_AVX512 \ + b_reg[0] = _mm512_permutex2var_epi64(a_reg[0], selector1, a_reg[1]); \ + b_reg[1] = _mm512_permutex2var_epi64(a_reg[2], selector1, a_reg[3]); \ + b_reg[2] = _mm512_permutex2var_epi64(a_reg[8], selector1, a_reg[9]); \ + b_reg[3] = _mm512_permutex2var_epi64(a_reg[10], selector1, a_reg[11]); \ + b_reg[4] = _mm512_permutex2var_epi64(a_reg[4], selector1, a_reg[5]); \ + b_reg[5] = _mm512_permutex2var_epi64(a_reg[6], selector1, a_reg[7]); \ + b_reg[6] = _mm512_permutex2var_epi64(a_reg[12], selector1, a_reg[13]); \ + b_reg[7] = _mm512_permutex2var_epi64(a_reg[14], selector1, a_reg[15]); \ + b_reg[8] = _mm512_permutex2var_epi64(a_reg[0], selector2, a_reg[1]); \ + b_reg[9] = _mm512_permutex2var_epi64(a_reg[2], selector2, a_reg[3]); \ + b_reg[10] = _mm512_permutex2var_epi64(a_reg[8], selector2, a_reg[9]); \ + b_reg[11] = _mm512_permutex2var_epi64(a_reg[10], selector2, a_reg[11]); \ + b_reg[12] = _mm512_permutex2var_epi64(a_reg[4], selector2, a_reg[5]); \ + b_reg[13] = _mm512_permutex2var_epi64(a_reg[6], selector2, a_reg[7]); \ + b_reg[14] = _mm512_permutex2var_epi64(a_reg[12], selector2, a_reg[13]); \ + b_reg[15] = _mm512_permutex2var_epi64(a_reg[14], selector2, a_reg[15]); + +#define SHUFFLE64x2_AVX512 \ + a_reg[0] = _mm512_shuffle_i64x2(b_reg[0], b_reg[1], 0x44); \ + a_reg[1] = _mm512_shuffle_i64x2(b_reg[2], b_reg[3], 0x44); \ + a_reg[2] = _mm512_shuffle_i64x2(b_reg[4], b_reg[5], 0x44); \ + a_reg[3] = _mm512_shuffle_i64x2(b_reg[6], b_reg[7], 0x44); \ + a_reg[4] = _mm512_shuffle_i64x2(b_reg[8], b_reg[9], 0x44); \ + a_reg[5] = _mm512_shuffle_i64x2(b_reg[10], b_reg[11], 0x44); \ + a_reg[6] = _mm512_shuffle_i64x2(b_reg[12], b_reg[13], 0x44); \ + a_reg[7] = _mm512_shuffle_i64x2(b_reg[14], b_reg[15], 0x44); \ + a_reg[8] = _mm512_shuffle_i64x2(b_reg[0], b_reg[1], 0xEE); \ + a_reg[9] = _mm512_shuffle_i64x2(b_reg[2], b_reg[3], 0xEE); \ + a_reg[10] = 
_mm512_shuffle_i64x2(b_reg[4], b_reg[5], 0xEE); \ + a_reg[11] = _mm512_shuffle_i64x2(b_reg[6], b_reg[7], 0xEE); \ + a_reg[12] = _mm512_shuffle_i64x2(b_reg[8], b_reg[9], 0xEE); \ + a_reg[13] = _mm512_shuffle_i64x2(b_reg[10], b_reg[11], 0xEE); \ + a_reg[14] = _mm512_shuffle_i64x2(b_reg[12], b_reg[13], 0xEE); \ + a_reg[15] = _mm512_shuffle_i64x2(b_reg[14], b_reg[15], 0xEE); + +#define MASK_LOAD_16_COLS_AVX512(mask) \ + a_reg[0] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 0 ) ) + kr); \ + a_reg[1] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 1 ) ) + kr); \ + a_reg[2] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 2 ) ) + kr); \ + a_reg[3] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 3 ) ) + kr); \ + a_reg[4] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 4 ) ) + kr); \ + a_reg[5] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 5 ) ) + kr); \ + a_reg[6] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 6 ) ) + kr); \ + a_reg[7] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 7 ) ) + kr); \ + a_reg[8] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 8 ) ) + kr); \ + a_reg[9] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 9 ) ) + kr); \ + a_reg[10] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 10 ) ) + kr); \ + a_reg[11] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 11 ) ) + kr); \ + a_reg[12] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 12 ) ) + kr); \ + a_reg[13] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 13 ) ) + kr); \ + a_reg[14] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 14 ) ) + kr); \ + a_reg[15] = _mm512_maskz_loadu_epi16( mask, b + ( ldb * ( jr + 15 ) ) + kr); + +void packb_nr64_bf16bf16f32of32_col_major + ( + bfloat16* pack_b_buffer, + const bfloat16* b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t* rs_b, + dim_t* cs_b + ) +{ + dim_t NR = 64; + + dim_t n_full_pieces = NC / NR; + dim_t n_full_pieces_loop_limit = n_full_pieces * NR; + dim_t n_partial_pieces = NC % NR; + + dim_t k_partial_pieces = KC % 2; + + dim_t KC_updated = KC; + if ( k_partial_pieces > 0 ) + { + KC_updated += ( 2 - k_partial_pieces ); + } + + for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR ) + { + packb_nr_mult_16_bf16bf16f32of32_col_major + ( pack_b_buffer + (jc * KC_updated), + b + (jc * ldb), 64, ldb, KC + ); + } + + if(n_partial_pieces > 0) + { + dim_t n0_partial_rem = n_partial_pieces % 16; + dim_t n0_partial_pack = 0; + + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(64) can be expressed + // as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16. 
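packb_nr64_bf16bf16f32of32 now dispatches to this column-major path whenever cs_b != 1, so column-major B gets the same NR = 64 blocking as the row-major path. The snippet below (hypothetical name, not part of the patch) merely restates the fringe split described in the comment above, with a few worked values.

/* Illustrative restatement of the n-fringe split used just below. */
static dim_t split_n_fringe_ref( dim_t n_partial_pieces )
{
    // Largest multiple-of-16 panel that fits, selected the same way as the code below.
    dim_t n0_partial_pack = 0;
    if      ( ( n_partial_pieces / 48 ) == 1 ) n0_partial_pack = 48;
    else if ( ( n_partial_pieces / 32 ) == 1 ) n0_partial_pack = 32;
    else if ( ( n_partial_pieces / 16 ) == 1 ) n0_partial_pack = 16;

    // Remainder handled by the nrlt16 column-major kernel.
    dim_t n0_partial_rem = n_partial_pieces % 16;

    // Invariant: n_partial_pieces == n0_partial_pack + n0_partial_rem,
    // e.g. 57 -> 48 + 9, 40 -> 32 + 8, 23 -> 16 + 7, 9 -> 0 + 9.
    return n0_partial_pack + n0_partial_rem;
}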
+ dim_t n0_48 = n_partial_pieces / 48; + dim_t n0_32 = n_partial_pieces / 32; + dim_t n0_16 = n_partial_pieces / 16; + + if ( n0_48 == 1 ) + { + packb_nr_mult_16_bf16bf16f32of32_col_major + ( + ( pack_b_buffer + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + n_full_pieces_loop_limit * ldb ), 48, ldb, KC + ); + + n0_partial_pack = 48; + } + else if ( n0_32 == 1 ) + { + packb_nr_mult_16_bf16bf16f32of32_col_major + ( + ( pack_b_buffer + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + n_full_pieces_loop_limit * ldb ), 32, ldb, KC + ); + + n0_partial_pack = 32; + } + else if ( n0_16 == 1 ) + { + packb_nr_mult_16_bf16bf16f32of32_col_major + ( + ( pack_b_buffer + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + n_full_pieces_loop_limit * ldb ), 16, ldb, KC + ); + + n0_partial_pack = 16; + } + + if ( n0_partial_rem > 0 ) + { + packb_nrlt16_bf16bf16f32of32_col_major + ( + ( pack_b_buffer + ( n_full_pieces_loop_limit * KC_updated ) + + ( n0_partial_pack * KC_updated ) ), + ( b + ( n_full_pieces_loop_limit + n0_partial_pack ) * ldb ), ldb, KC, + n0_partial_rem + ); + } + } + *rs_b = NR * 2; + *cs_b = NR / 2; +} + +void packb_nr_mult_16_bf16bf16f32of32_col_major + ( + bfloat16* pack_b_buffer, + const bfloat16* b, + const dim_t NR, + const dim_t ldb, + const dim_t KC + ) +{ + + // Used for permuting the mm512i elements for use in dpbf16_ps instruction. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x4, 0x5, 0xC, 0xD ); + __m512i selector2 = _mm512_setr_epi64( 0x2, 0x3, 0xA, 0xB, 0x6, 0x7, 0xE, 0xF ); + + __m512i a_reg[16]; + __m512i b_reg[16]; + + dim_t kr = 0; + for ( kr = 0; ( kr + 31 ) < KC; kr += 32 ) + { + for( dim_t jr = 0; jr < NR; jr += 16 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. + + LOAD_16_COLS_AVX512 + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 0 ) * NR ), a_reg[0] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 2 ) * NR ), a_reg[1] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 4 ) * NR ), a_reg[2] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 6 ) * NR ), a_reg[3] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 8 ) * NR ), a_reg[4] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 10 ) * NR ), a_reg[5] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 12 ) * NR ), a_reg[6] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 14 ) * NR ), a_reg[7] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 16 ) * NR ), a_reg[8] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 18 ) * NR ), a_reg[9] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 20 ) * NR ), a_reg[10] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 22 ) * NR ), a_reg[11] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 24 ) * NR ), a_reg[12] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 26 ) * NR ), a_reg[13] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 28 ) * NR ), a_reg[14] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 30 ) * NR ), a_reg[15] ); + + } + } + for ( ; ( kr + 15 ) < KC; kr += 16 ) + { + for( dim_t jr = 0; jr < NR; jr += 16 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. 
+ + MASK_LOAD_16_COLS_AVX512( 0xFFFF ) + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 0 ) * NR ), a_reg[0] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 2 ) * NR ), a_reg[1] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 4 ) * NR ), a_reg[2] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 6 ) * NR ), a_reg[3] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 8 ) * NR ), a_reg[4] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 10 ) * NR ), a_reg[5] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 12 ) * NR ), a_reg[6] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 14 ) * NR ), a_reg[7] ); + } + } + + for( ; ( kr +7 ) < KC; kr += 8 ) + { + for( dim_t jr = 0; jr < NR; jr += 16 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. + + MASK_LOAD_16_COLS_AVX512( 0xFF ) + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 0 ) * NR ), a_reg[0] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 2 ) * NR ), a_reg[1] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 4 ) * NR ), a_reg[2] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 6 ) * NR ), a_reg[3] ); + } + } + for( ; ( kr +3 ) < KC; kr += 4 ) + { + for( dim_t jr = 0; jr < NR; jr += 16 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512( 0x0F ) + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 0 ) * NR ), a_reg[0] ); + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr + 2 ) * NR ), a_reg[1] ); + } + } + for( ; ( kr +1 ) < KC; kr += 2 ) + { + for( dim_t jr = 0; jr < NR; jr += 16 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512( 0x03 ) + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( ( kr ) * NR ), a_reg[0] ); + } + } + for( ; kr < KC; kr += 1 ) + { + for( dim_t jr = 0; jr < NR; jr += 16 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. + MASK_LOAD_16_COLS_AVX512( 0x01 ) + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512( pack_b_buffer + ( jr * 2 ) + ( kr * NR ), a_reg[0] ); + } + } +} + + +void packb_nrlt16_bf16bf16f32of32_col_major + ( + bfloat16* pack_b_buffer, + const bfloat16* b, + const dim_t ldb, + const dim_t KC, + const dim_t n0_partial_rem + ) +{ + dim_t NR = 16; + + // Used for permuting the mm512i elements for use in dpbf16_ps instruction. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x4, 0x5, 0xC, 0xD ); + __m512i selector2 = _mm512_setr_epi64( 0x2, 0x3, 0xA, 0xB, 0x6, 0x7, 0xE, 0xF ); + + __m512i a_reg[16]; + __m512i b_reg[16]; + + dim_t kr = 0, jr = 0; + for ( kr = 0; ( kr + 31 ) < KC; kr += 32 ) + { + for( jr = 0; jr < n0_partial_rem; jr += 1 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. 
+ a_reg[jr] = _mm512_loadu_si512( b + ( ldb * ( jr + 0 ) ) + kr ); + } + for(; jr < NR; jr++) + { + a_reg[jr] = _mm512_setzero_si512(); + } + + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 0 ) * NR ), a_reg[0] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 2 ) * NR ), a_reg[1] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 4 ) * NR ), a_reg[2] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 6 ) * NR ), a_reg[3] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 8 ) * NR ), a_reg[4] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 10 ) * NR ), a_reg[5] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 12 ) * NR ), a_reg[6] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 14 ) * NR ), a_reg[7] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 16 ) * NR ), a_reg[8] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 18 ) * NR ), a_reg[9] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 20 ) * NR ), a_reg[10] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 22 ) * NR ), a_reg[11] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 24 ) * NR ), a_reg[12] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 26 ) * NR ), a_reg[13] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 28 ) * NR ), a_reg[14] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 30 ) * NR ), a_reg[15] ); + + } + for ( ; ( kr + 15 ) < KC; kr += 16 ) + { + for( jr = 0; jr < n0_partial_rem; jr += 1 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi16( 0xFFFF, b + ( ldb * ( jr + 0 ) ) + kr ); + } + for( ; jr < NR; jr++ ) + { + a_reg[jr] = _mm512_setzero_si512(); + } + + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 0 ) * NR ), a_reg[0] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 2 ) * NR ), a_reg[1] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 4 ) * NR ), a_reg[2] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 6 ) * NR ), a_reg[3] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 8 ) * NR ), a_reg[4] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 10 ) * NR ), a_reg[5] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 12 ) * NR ), a_reg[6] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 14 ) * NR ), a_reg[7] ); + } + + for ( ; ( kr + 7 ) < KC; kr += 8 ) + { + for( jr = 0; jr < n0_partial_rem; jr += 1 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi16( 0xFF, b + ( ldb * ( jr + 0 ) ) + kr ); + } + for( ; jr < NR; jr++ ) + { + a_reg[jr] = _mm512_setzero_si512(); + } + + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 0 ) * NR ), a_reg[0] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 2 ) * NR ), a_reg[1] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 4 ) * NR ), a_reg[2] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 6 ) * NR ), a_reg[3] ); + } + for ( ; (kr+3) < KC; kr += 4 ) + { + for( jr = 0; jr < n0_partial_rem; jr += 1 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. 
+ a_reg[jr] = _mm512_maskz_loadu_epi16( 0x0F, b + ( ldb * ( jr + 0 ) ) + kr ); + } + for( ; jr < NR; jr++ ) + { + a_reg[jr] = _mm512_setzero_si512(); + } + + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 0 ) * NR ), a_reg[0] ); + _mm512_storeu_si512( pack_b_buffer + ( ( kr + 2 ) * NR ), a_reg[1] ); + } + for ( ; ( kr + 1 ) < KC; kr += 2 ) + { + for( jr = 0; jr < n0_partial_rem; jr += 1 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi16( 0x03, b + ( ldb * ( jr + 0 ) ) + kr ); + } + for( ; jr < NR; jr++ ) + { + a_reg[jr] = _mm512_setzero_si512(); + } + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512( pack_b_buffer + ( kr * NR ), a_reg[0] ); + } + for ( ; kr < KC; kr += 1 ) + { + for( jr = 0; jr < n0_partial_rem; jr += 1 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. + a_reg[jr] = _mm512_maskz_loadu_epi16( 0x01, b + ( ldb * ( jr + 0 ) ) + kr ); + } + for( ; jr < NR; jr++ ) + { + a_reg[jr] = _mm512_setzero_si512(); + } + + UNPACKHILO32_AVX512 + UNPACKHILO64_AVX512 + PERMUTEX2_VAR64_AVX512 + SHUFFLE64x2_AVX512 + + // store to pack_b buffer + _mm512_storeu_si512( pack_b_buffer + ( kr * NR ), a_reg[0] ); + } } #endif diff --git a/kernels/zen4/lpgemm/math_utils_avx512.h b/kernels/zen4/lpgemm/math_utils_avx512.h index 82c9c5650b..dddfd58825 100644 --- a/kernels/zen4/lpgemm/math_utils_avx512.h +++ b/kernels/zen4/lpgemm/math_utils_avx512.h @@ -44,8 +44,8 @@ #define TBL_LN2 0x1.71547652b82fep+0 #define EXPF_HUGE 0x1.8p+23 -#define EXPF_MIN -88.7228393f -#define EXPF_MAX 88.7228393f +#define EXPF_MIN -88.0f +#define EXPF_MAX 88.0f #define inf 1.0/0.0 #define sign -2147483648 @@ -113,7 +113,9 @@ POLY_EVAL_HORNER_16_0_AVX512(r,x); \ \ x = (__m512)_mm512_mask_xor_epi32 ((__m512i)_mm512_set1_ps(1), _mm512_cmpnle_ps_mask \ - ( _mm512_set1_ps(3.9192059040069580078125f), r), (__m512i)x, _mm512_set1_epi32(0)); \ + ( _mm512_set1_ps(3.553f), r), (__m512i)x, _mm512_set1_epi32(0)); \ + x = (__m512)_mm512_mask_xor_epi32 ((__m512i)_mm512_set1_ps(1), _mm512_cmpnle_ps_mask \ + ( _mm512_set1_ps(1.0f), x), (__m512i)x, _mm512_set1_epi32(0)); \ x_erf = (__m512)_mm512_or_epi32(_mm512_and_epi32 ((__m512i)x_erf, _mm512_set1_epi32(~(0x7FFFFFFF))), (__m512i)x); #endif // AOCL_LPGEMM_MATH_UTILS_AVX512_H diff --git a/kernels/zen4/lpgemm/s8s8s32/lpgemm_6x64rowmajor_s8_amd512vnni.c b/kernels/zen4/lpgemm/s8s8s32/lpgemm_6x64rowmajor_s8_amd512vnni.c index a2e487bcb3..df5d29472c 100644 --- a/kernels/zen4/lpgemm/s8s8s32/lpgemm_6x64rowmajor_s8_amd512vnni.c +++ b/kernels/zen4/lpgemm/s8s8s32/lpgemm_6x64rowmajor_s8_amd512vnni.c @@ -1060,77 +1060,95 @@ LPGEMM_MAIN_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x64) _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + // int8_t zero point value. 
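The s8s8s32 downscale post-op below now reads a per-column int8 zero point from post_ops_list_temp->op_args1 and passes it to every CVT_MULRND_CVT32 invocation alongside the existing per-column scale factor. As a rough scalar model of the intended per-element requantization — assuming the zero point is added after the scale-and-round step and that saturation to int8 happens when C is finally stored — something like the following; the function name and the exact rounding mode are illustrative only, the real work is done by the macro on 16-lane __m512i registers.

#include <math.h>     // nearbyintf
#include <stdint.h>   // int32_t, int8_t

/* Hypothetical scalar model of the zero-point-aware downscale of one
 * accumulator element c[i, j]. */
static inline int32_t requant_col_ref
     (
       int32_t acc,         // int32 accumulator c[i, j]
       float   scale,       // per-column scale factor ( scale_factor[j] )
       int8_t  zero_point   // per-column zero point ( op_args1[j] )
     )
{
    // Scale in float, round to nearest, then shift by the zero point.
    float   scaled  = ( float )acc * scale;
    int32_t rounded = ( int32_t )nearbyintf( scaled );
    return rounded + ( int32_t )zero_point;
}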
+ __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + __m128i zero_point3 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[0, 48-63] - CVT_MULRND_CVT32(c_int32_0p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_0p3,a_int32_1,zero_point3); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[1, 48-63] - CVT_MULRND_CVT32(c_int32_1p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_1p3,a_int32_1,zero_point3); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); // c[2, 48-63] - CVT_MULRND_CVT32(c_int32_2p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_2p3,a_int32_1,zero_point3); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[3, 32-47] - CVT_MULRND_CVT32(c_int32_3p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); // c[3, 48-63] - CVT_MULRND_CVT32(c_int32_3p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_3p3,a_int32_1,zero_point3); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[4, 16-31] - CVT_MULRND_CVT32(c_int32_4p1,selector2); + CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); // c[4, 32-47] - CVT_MULRND_CVT32(c_int32_4p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_4p2,a_int32_0,zero_point2); // c[4, 48-63] - CVT_MULRND_CVT32(c_int32_4p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_4p3,a_int32_1,zero_point3); // c[5, 0-15] - CVT_MULRND_CVT32(c_int32_5p0,selector1); + CVT_MULRND_CVT32(c_int32_5p0,selector1,zero_point0); // c[5, 16-31] - CVT_MULRND_CVT32(c_int32_5p1,selector2); + CVT_MULRND_CVT32(c_int32_5p1,selector2,zero_point1); // c[5, 32-47] - CVT_MULRND_CVT32(c_int32_5p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_5p2,a_int32_0,zero_point2); // c[5, 48-63] - CVT_MULRND_CVT32(c_int32_5p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_5p3,a_int32_1,zero_point3); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/kernels/zen4/lpgemm/s8s8s32/lpgemm_m_fringe_s8_amd512vnni.c b/kernels/zen4/lpgemm/s8s8s32/lpgemm_m_fringe_s8_amd512vnni.c index a338484df6..53a0f51d17 100644 --- 
a/kernels/zen4/lpgemm/s8s8s32/lpgemm_m_fringe_s8_amd512vnni.c +++ b/kernels/zen4/lpgemm/s8s8s32/lpgemm_m_fringe_s8_amd512vnni.c @@ -825,66 +825,82 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x64) a_int32_1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + __m128i zero_point3 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[0, 48-63] - CVT_MULRND_CVT32(c_int32_0p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_0p3,a_int32_1,zero_point3); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[1, 48-63] - CVT_MULRND_CVT32(c_int32_1p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_1p3,a_int32_1,zero_point3); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); // c[2, 48-63] - CVT_MULRND_CVT32(c_int32_2p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_2p3,a_int32_1,zero_point3); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[3, 32-47] - CVT_MULRND_CVT32(c_int32_3p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); // c[3, 48-63] - CVT_MULRND_CVT32(c_int32_3p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_3p3,a_int32_1,zero_point3); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[4, 16-31] - CVT_MULRND_CVT32(c_int32_4p1,selector2); + CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); // c[4, 32-47] - CVT_MULRND_CVT32(c_int32_4p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_4p2,a_int32_0,zero_point2); // c[4, 48-63] - CVT_MULRND_CVT32(c_int32_4p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_4p3,a_int32_1,zero_point3); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1684,54 +1700,70 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x64) a_int32_1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* 
)post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + __m128i zero_point3 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[0, 48-63] - CVT_MULRND_CVT32(c_int32_0p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_0p3,a_int32_1,zero_point3); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[1, 48-63] - CVT_MULRND_CVT32(c_int32_1p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_1p3,a_int32_1,zero_point3); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); // c[2, 48-63] - CVT_MULRND_CVT32(c_int32_2p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_2p3,a_int32_1,zero_point3); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[3, 32-47] - CVT_MULRND_CVT32(c_int32_3p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); // c[3, 48-63] - CVT_MULRND_CVT32(c_int32_3p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_3p3,a_int32_1,zero_point3); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2381,42 +2413,58 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x64) a_int32_1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + __m128i zero_point3 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // 
c[0, 48-63] - CVT_MULRND_CVT32(c_int32_0p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_0p3,a_int32_1,zero_point3); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[1, 48-63] - CVT_MULRND_CVT32(c_int32_1p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_1p3,a_int32_1,zero_point3); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); // c[2, 48-63] - CVT_MULRND_CVT32(c_int32_2p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_2p3,a_int32_1,zero_point3); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2918,30 +2966,46 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x64) a_int32_1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + __m128i zero_point3 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[0, 48-63] - CVT_MULRND_CVT32(c_int32_0p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_0p3,a_int32_1,zero_point3); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[1, 48-63] - CVT_MULRND_CVT32(c_int32_1p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_1p3,a_int32_1,zero_point3); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -3292,18 +3356,34 @@ LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x64) a_int32_1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + __m128i zero_point3 = + _mm_loadu_si128( ( __m128i const* ) 
+ ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[0, 48-63] - CVT_MULRND_CVT32(c_int32_0p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_0p3,a_int32_1,zero_point3); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/kernels/zen4/lpgemm/s8s8s32/lpgemm_mn_fringe_s8_amd512vnni.c b/kernels/zen4/lpgemm/s8s8s32/lpgemm_mn_fringe_s8_amd512vnni.c index c009bdeaf3..ced733e131 100644 --- a/kernels/zen4/lpgemm/s8s8s32/lpgemm_mn_fringe_s8_amd512vnni.c +++ b/kernels/zen4/lpgemm/s8s8s32/lpgemm_mn_fringe_s8_amd512vnni.c @@ -432,21 +432,27 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5xlt16) ( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j ) ); + __m128i zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); // c[0, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); // c[1, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1,zero_point); // c[2, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1,zero_point); // c[3, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1,zero_point); // c[4, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_4p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_4p0,selector1,zero_point); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -835,18 +841,24 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4xlt16) ( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j ) ); + __m128i zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); // c[0, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); // c[1, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1,zero_point); // c[2, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1,zero_point); // c[3, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1,zero_point); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1177,15 +1189,21 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3xlt16) ( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j ) ); + __m128i zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); // c[0, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); // c[1, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1,zero_point); // c[2, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1,zero_point); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1458,12 +1476,18 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2xlt16) ( ( float* 
)post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j ) ); + __m128i zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); // c[0, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); // c[1, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1,zero_point); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1678,9 +1702,15 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1xlt16) ( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j ) ); + __m128i zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); // c[0, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2072,21 +2102,25 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x16) selector1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2451,18 +2485,22 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x16) selector1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2771,15 +2809,19 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x16) selector1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -3032,12 
+3074,16 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x16) selector1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -3234,9 +3280,13 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x16) selector1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -3760,36 +3810,44 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x32) selector2 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[4, 16-31] - CVT_MULRND_CVT32(c_int32_4p1,selector2); + CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -4294,30 +4352,38 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x32) selector2 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); 
// c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -4737,24 +4803,32 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x32) selector2 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -5089,18 +5163,26 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x32) selector2 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -5350,12 +5432,20 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x32) selector2 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + 
CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -6012,51 +6102,63 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x48) a_int32_0 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[3, 32-47] - CVT_MULRND_CVT32(c_int32_3p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[4, 16-31] - CVT_MULRND_CVT32(c_int32_4p1,selector2); + CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); // c[4, 32-47] - CVT_MULRND_CVT32(c_int32_4p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_4p2,a_int32_0,zero_point2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -6695,42 +6797,54 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x48) a_int32_0 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - 
CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[3, 32-47] - CVT_MULRND_CVT32(c_int32_3p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -7255,33 +7369,45 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x48) a_int32_0 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -7693,24 +7819,36 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x48) a_int32_0 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); // 
c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -8008,15 +8146,27 @@ LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x48) a_int32_0 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/kernels/zen4/lpgemm/s8s8s32/lpgemm_n_fringe_s8_amd512vnni.c b/kernels/zen4/lpgemm/s8s8s32/lpgemm_n_fringe_s8_amd512vnni.c index b88ef512d6..9669b638b5 100644 --- a/kernels/zen4/lpgemm/s8s8s32/lpgemm_n_fringe_s8_amd512vnni.c +++ b/kernels/zen4/lpgemm/s8s8s32/lpgemm_n_fringe_s8_amd512vnni.c @@ -524,24 +524,30 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6xlt16) ( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j ) ); + __m128i zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); // c[0, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); // c[1, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1,zero_point); // c[2, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1,zero_point); // c[3, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1,zero_point); // c[4, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_4p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_4p0,selector1,zero_point); // c[5, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_5p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_5p0,selector1,zero_point); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1129,24 +1135,28 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x16) selector1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); // c[0, 0-15] - 
CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[5, 0-15] - CVT_MULRND_CVT32(c_int32_5p0,selector1); + CVT_MULRND_CVT32(c_int32_5p0,selector1,zero_point0); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1884,42 +1894,50 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x32) selector2 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[4, 16-31] - CVT_MULRND_CVT32(c_int32_4p1,selector2); + CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); // c[5, 0-15] - CVT_MULRND_CVT32(c_int32_5p0,selector1); + CVT_MULRND_CVT32(c_int32_5p0,selector1,zero_point0); // c[5, 16-31] - CVT_MULRND_CVT32(c_int32_5p1,selector2); + CVT_MULRND_CVT32(c_int32_5p1,selector2,zero_point1); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2843,60 +2861,72 @@ LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x48) a_int32_0 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - 
CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[3, 32-47] - CVT_MULRND_CVT32(c_int32_3p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[4, 16-31] - CVT_MULRND_CVT32(c_int32_4p1,selector2); + CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); // c[4, 32-47] - CVT_MULRND_CVT32(c_int32_4p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_4p2,a_int32_0,zero_point2); // c[5, 0-15] - CVT_MULRND_CVT32(c_int32_5p0,selector1); + CVT_MULRND_CVT32(c_int32_5p0,selector1,zero_point0); // c[5, 16-31] - CVT_MULRND_CVT32(c_int32_5p1,selector2); + CVT_MULRND_CVT32(c_int32_5p1,selector2,zero_point1); // c[5, 32-47] - CVT_MULRND_CVT32(c_int32_5p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_5p2,a_int32_0,zero_point2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c index f79cd8775a..32bfc2c8af 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -906,77 +906,95 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + // int8_t zero point value. 
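The zero-point vectors loaded here feed the extra argument that every CVT_MULRND_CVT32 invocation in these hunks now carries. The macro definition itself is not part of this diff, so the following is only a minimal sketch of the presumed per-block math, assuming the downscale converts the int32 accumulator to float, multiplies by the per-channel scale, rounds back to int32, and then adds the sign-extended int8 zero point; the function name, parameter types, and rounding behaviour are illustrative assumptions, not the library's actual code.

#include <immintrin.h>

// Sketch only: presumed math behind CVT_MULRND_CVT32( acc, scale, zp ) for one
// 16-column block. The real macro lives in the LPGEMM headers.
static inline __m512i downscale_s32_sketch( __m512i acc, __m512 scale, __m128i zp )
{
	// int32 accumulator -> float, multiply by the per-channel scale, round back to int32.
	__m512i scaled = _mm512_cvtps_epi32( _mm512_mul_ps( _mm512_cvtepi32_ps( acc ), scale ) );

	// Add the per-channel int8 zero point, sign-extended to int32. The saturating
	// conversion to int8 is assumed to happen later, when C is stored.
	return _mm512_add_epi32( scaled, _mm512_cvtepi8_epi32( zp ) );
}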
+ __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + __m128i zero_point3 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); + // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[0, 48-63] - CVT_MULRND_CVT32(c_int32_0p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_0p3,a_int32_1,zero_point3); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[1, 48-63] - CVT_MULRND_CVT32(c_int32_1p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_1p3,a_int32_1,zero_point3); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); // c[2, 48-63] - CVT_MULRND_CVT32(c_int32_2p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_2p3,a_int32_1,zero_point3); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[3, 32-47] - CVT_MULRND_CVT32(c_int32_3p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); // c[3, 48-63] - CVT_MULRND_CVT32(c_int32_3p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_3p3,a_int32_1,zero_point3); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[4, 16-31] - CVT_MULRND_CVT32(c_int32_4p1,selector2); + CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); // c[4, 32-47] - CVT_MULRND_CVT32(c_int32_4p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_4p2,a_int32_0,zero_point2); // c[4, 48-63] - CVT_MULRND_CVT32(c_int32_4p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_4p3,a_int32_1,zero_point3); // c[5, 0-15] - CVT_MULRND_CVT32(c_int32_5p0,selector1); + CVT_MULRND_CVT32(c_int32_5p0,selector1,zero_point0); // c[5, 16-31] - CVT_MULRND_CVT32(c_int32_5p1,selector2); + CVT_MULRND_CVT32(c_int32_5p1,selector2,zero_point1); // c[5, 32-47] - CVT_MULRND_CVT32(c_int32_5p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_5p2,a_int32_0,zero_point2); // c[5, 48-63] - CVT_MULRND_CVT32(c_int32_5p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_5p3,a_int32_1,zero_point3); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_m_fringe_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemm_m_fringe_amd512vnni.c index bcaa2d81c3..23393cad4f 100644 --- 
a/kernels/zen4/lpgemm/u8s8s32/lpgemm_m_fringe_amd512vnni.c +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_m_fringe_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -752,66 +752,82 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64) a_int32_1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + __m128i zero_point3 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[0, 48-63] - CVT_MULRND_CVT32(c_int32_0p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_0p3,a_int32_1,zero_point3); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[1, 48-63] - CVT_MULRND_CVT32(c_int32_1p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_1p3,a_int32_1,zero_point3); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); // c[2, 48-63] - CVT_MULRND_CVT32(c_int32_2p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_2p3,a_int32_1,zero_point3); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[3, 32-47] - CVT_MULRND_CVT32(c_int32_3p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); // c[3, 48-63] - CVT_MULRND_CVT32(c_int32_3p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_3p3,a_int32_1,zero_point3); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[4, 16-31] - CVT_MULRND_CVT32(c_int32_4p1,selector2); + CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); // c[4, 32-47] - CVT_MULRND_CVT32(c_int32_4p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_4p2,a_int32_0,zero_point2); // c[4, 48-63] - CVT_MULRND_CVT32(c_int32_4p3,a_int32_1); + 
CVT_MULRND_CVT32(c_int32_4p3,a_int32_1,zero_point3); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1547,54 +1563,70 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64) a_int32_1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + __m128i zero_point3 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[0, 48-63] - CVT_MULRND_CVT32(c_int32_0p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_0p3,a_int32_1,zero_point3); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[1, 48-63] - CVT_MULRND_CVT32(c_int32_1p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_1p3,a_int32_1,zero_point3); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); // c[2, 48-63] - CVT_MULRND_CVT32(c_int32_2p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_2p3,a_int32_1,zero_point3); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[3, 32-47] - CVT_MULRND_CVT32(c_int32_3p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); // c[3, 48-63] - CVT_MULRND_CVT32(c_int32_3p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_3p3,a_int32_1,zero_point3); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2191,42 +2223,58 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64) a_int32_1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + __m128i zero_point3 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + 
post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[0, 48-63] - CVT_MULRND_CVT32(c_int32_0p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_0p3,a_int32_1,zero_point3); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[1, 48-63] - CVT_MULRND_CVT32(c_int32_1p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_1p3,a_int32_1,zero_point3); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); // c[2, 48-63] - CVT_MULRND_CVT32(c_int32_2p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_2p3,a_int32_1,zero_point3); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2685,30 +2733,46 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64) a_int32_1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + __m128i zero_point3 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[0, 48-63] - CVT_MULRND_CVT32(c_int32_0p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_0p3,a_int32_1,zero_point3); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[1, 48-63] - CVT_MULRND_CVT32(c_int32_1p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_1p3,a_int32_1,zero_point3); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -3027,18 +3091,34 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64) a_int32_1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 3 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 
0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); + __m128i zero_point3 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 3 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[0, 48-63] - CVT_MULRND_CVT32(c_int32_0p3,a_int32_1); + CVT_MULRND_CVT32(c_int32_0p3,a_int32_1,zero_point3); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_mn_fringe_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemm_mn_fringe_amd512vnni.c index 940d9e92fa..3dcb7eed07 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_mn_fringe_amd512vnni.c +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_mn_fringe_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -402,21 +402,27 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16) ( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j ) ); + __m128i zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); // c[0, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); // c[1, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1,zero_point); // c[2, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1,zero_point); // c[3, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1,zero_point); // c[4, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_4p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_4p0,selector1,zero_point); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -779,18 +785,24 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4xlt16) ( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j ) ); + __m128i zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); // c[0, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); // c[1, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1,zero_point); // c[2, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1,zero_point); // c[3, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1,zero_point); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1098,15 +1110,21 @@ 
LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16) ( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j ) ); + __m128i zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); // c[0, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); // c[1, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1,zero_point); // c[2, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1,zero_point); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1359,12 +1377,18 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16) ( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j ) ); + __m128i zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); // c[0, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); // c[1, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1,zero_point); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1562,9 +1586,15 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16) ( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j ) ); + __m128i zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); // c[0, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1928,21 +1958,25 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16) selector1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2282,18 +2316,22 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x16) selector1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + 
CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2580,15 +2618,19 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x16) selector1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2822,12 +2864,16 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x16) selector1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -3008,9 +3054,13 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x16) selector1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -3498,36 +3548,44 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x32) selector2 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[4, 16-31] - 
CVT_MULRND_CVT32(c_int32_4p1,selector2); + CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -4000,30 +4058,38 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x32) selector2 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -4415,24 +4481,32 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x32) selector2 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -4743,18 +4817,26 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x32) selector2 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + 
CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -4984,12 +5066,20 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x32) selector2 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -5602,51 +5692,63 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x48) a_int32_0 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[3, 32-47] - CVT_MULRND_CVT32(c_int32_3p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[4, 16-31] - CVT_MULRND_CVT32(c_int32_4p1,selector2); + CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); // c[4, 32-47] - CVT_MULRND_CVT32(c_int32_4p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_4p2,a_int32_0,zero_point2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -6246,42 +6348,54 @@ 
LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x48) a_int32_0 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[3, 32-47] - CVT_MULRND_CVT32(c_int32_3p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -6772,33 +6886,45 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48) a_int32_0 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); 
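As in the wider kernels above, the zero-point buffer is indexed purely by output column: op_args1 is treated as one int8 zero point per column of C, and each 16-column block b of a tile whose first output column is post_op_c_j reads 16 consecutive bytes at offset post_op_c_j + b * 16. A small helper expressing that addressing could look like the sketch below; the helper and its parameter names are illustrative and not part of this patch.

#include <immintrin.h>
#include <stdint.h>

// Illustrative only: load the 16 int8 zero points for 16-column block 'b' of a tile
// whose first output column is 'j'. 'zp_base' stands in for post_ops_list_temp->op_args1.
static inline __m128i load_zp_block( const int8_t* zp_base, int64_t j, int64_t b )
{
	return _mm_loadu_si128( ( const __m128i* )( zp_base + j + ( b * 16 ) ) );
}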
+ CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -7180,24 +7306,36 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48) a_int32_0 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -7470,15 +7608,27 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48) a_int32_0 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_extMR_fringe_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_extMR_fringe_amd512vnni.c index f59c82721c..bfe3fb6ce1 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_extMR_fringe_amd512vnni.c +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_extMR_fringe_amd512vnni.c @@ -783,42 +783,48 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12xlt16) ( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j ) ); + __m128i zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); // c[0, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); // c[1, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1,zero_point); // c[2, 0-15] - 
CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1,zero_point); // c[3, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1,zero_point); // c[4, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_4p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_4p0,selector1,zero_point); // c[5, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_5p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_5p0,selector1,zero_point); // c[6, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_6p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_6p0,selector1,zero_point); // c[7, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_7p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_7p0,selector1,zero_point); // c[8, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_8p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_8p0,selector1,zero_point); // c[9, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_9p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_9p0,selector1,zero_point); // c[10, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_10p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_10p0,selector1,zero_point); // c[11, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_11p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_11p0,selector1,zero_point); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1664,42 +1670,46 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12x16) selector1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[5, 0-15] - CVT_MULRND_CVT32(c_int32_5p0,selector1); + CVT_MULRND_CVT32(c_int32_5p0,selector1,zero_point0); // c[6, 0-15] - CVT_MULRND_CVT32(c_int32_6p0,selector1); + CVT_MULRND_CVT32(c_int32_6p0,selector1,zero_point0); // c[7, 0-15] - CVT_MULRND_CVT32(c_int32_7p0,selector1); + CVT_MULRND_CVT32(c_int32_7p0,selector1,zero_point0); // c[8, 0-15] - CVT_MULRND_CVT32(c_int32_8p0,selector1); + CVT_MULRND_CVT32(c_int32_8p0,selector1,zero_point0); // c[9, 0-15] - CVT_MULRND_CVT32(c_int32_9p0,selector1); + CVT_MULRND_CVT32(c_int32_9p0,selector1,zero_point0); // c[10, 0-15] - CVT_MULRND_CVT32(c_int32_10p0,selector1); + CVT_MULRND_CVT32(c_int32_10p0,selector1,zero_point0); // c[11, 0-15] - CVT_MULRND_CVT32(c_int32_11p0,selector1); + CVT_MULRND_CVT32(c_int32_11p0,selector1,zero_point0); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2569,60 +2579,68 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_9x32) selector2 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); // c[0, 0-15] - 
CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[4, 16-31] - CVT_MULRND_CVT32(c_int32_4p1,selector2); + CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); // c[5, 0-15] - CVT_MULRND_CVT32(c_int32_5p0,selector1); + CVT_MULRND_CVT32(c_int32_5p0,selector1,zero_point0); // c[5, 16-31] - CVT_MULRND_CVT32(c_int32_5p1,selector2); + CVT_MULRND_CVT32(c_int32_5p1,selector2,zero_point1); // c[6, 0-15] - CVT_MULRND_CVT32(c_int32_6p0,selector1); + CVT_MULRND_CVT32(c_int32_6p0,selector1,zero_point0); // c[6, 16-31] - CVT_MULRND_CVT32(c_int32_6p1,selector2); + CVT_MULRND_CVT32(c_int32_6p1,selector2,zero_point1); // c[7, 0-15] - CVT_MULRND_CVT32(c_int32_7p0,selector1); + CVT_MULRND_CVT32(c_int32_7p0,selector1,zero_point0); // c[7, 16-31] - CVT_MULRND_CVT32(c_int32_7p1,selector2); + CVT_MULRND_CVT32(c_int32_7p1,selector2,zero_point1); // c[8, 0-15] - CVT_MULRND_CVT32(c_int32_8p0,selector1); + CVT_MULRND_CVT32(c_int32_8p0,selector1,zero_point0); // c[8, 16-31] - CVT_MULRND_CVT32(c_int32_8p1,selector2); + CVT_MULRND_CVT32(c_int32_8p1,selector2,zero_point1); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_fringe_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_fringe_amd512vnni.c index d5f86338a6..f3574e5dc0 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_fringe_amd512vnni.c +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_n_fringe_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -473,24 +473,30 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) ( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j ) ); + __m128i zero_point = _mm_maskz_loadu_epi8 + ( + load_mask, + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j ) + ); // c[0, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_0p0,selector1,zero_point); // c[1, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_1p0,selector1,zero_point); // c[2, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_2p0,selector1,zero_point); // c[3, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_3p0,selector1,zero_point); // c[4, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_4p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_4p0,selector1,zero_point); // c[5, 0-15] - CVT_MULRND_CVT32_LT16(c_int32_5p0,selector1); + CVT_MULRND_CVT32_LT16(c_int32_5p0,selector1,zero_point); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1026,24 +1032,28 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) selector1 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 0 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[5, 0-15] - CVT_MULRND_CVT32(c_int32_5p0,selector1); + CVT_MULRND_CVT32(c_int32_5p0,selector1,zero_point0); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1724,42 +1734,50 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) selector2 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 1 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + 
CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[4, 16-31] - CVT_MULRND_CVT32(c_int32_4p1,selector2); + CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); // c[5, 0-15] - CVT_MULRND_CVT32(c_int32_5p0,selector1); + CVT_MULRND_CVT32(c_int32_5p0,selector1,zero_point0); // c[5, 16-31] - CVT_MULRND_CVT32(c_int32_5p1,selector2); + CVT_MULRND_CVT32(c_int32_5p1,selector2,zero_point1); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2609,60 +2627,72 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) a_int32_0 = _mm512_loadu_si512( ( float* )post_ops_list_temp->scale_factor + post_ops_attr.post_op_c_j + ( 2 * 16 ) ); + __m128i zero_point0 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 0 * 16 ) ) ); + __m128i zero_point1 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 1 * 16 ) ) ); + __m128i zero_point2 = + _mm_loadu_si128( ( __m128i const* ) + ( ( int8_t* )post_ops_list_temp->op_args1 + + post_ops_attr.post_op_c_j + ( 2 * 16 ) ) ); // c[0, 0-15] - CVT_MULRND_CVT32(c_int32_0p0,selector1); + CVT_MULRND_CVT32(c_int32_0p0,selector1,zero_point0); // c[0, 16-31] - CVT_MULRND_CVT32(c_int32_0p1,selector2); + CVT_MULRND_CVT32(c_int32_0p1,selector2,zero_point1); // c[0, 32-47] - CVT_MULRND_CVT32(c_int32_0p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_0p2,a_int32_0,zero_point2); // c[1, 0-15] - CVT_MULRND_CVT32(c_int32_1p0,selector1); + CVT_MULRND_CVT32(c_int32_1p0,selector1,zero_point0); // c[1, 16-31] - CVT_MULRND_CVT32(c_int32_1p1,selector2); + CVT_MULRND_CVT32(c_int32_1p1,selector2,zero_point1); // c[1, 32-47] - CVT_MULRND_CVT32(c_int32_1p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_1p2,a_int32_0,zero_point2); // c[2, 0-15] - CVT_MULRND_CVT32(c_int32_2p0,selector1); + CVT_MULRND_CVT32(c_int32_2p0,selector1,zero_point0); // c[2, 16-31] - CVT_MULRND_CVT32(c_int32_2p1,selector2); + CVT_MULRND_CVT32(c_int32_2p1,selector2,zero_point1); // c[2, 32-47] - CVT_MULRND_CVT32(c_int32_2p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_2p2,a_int32_0,zero_point2); // c[3, 0-15] - CVT_MULRND_CVT32(c_int32_3p0,selector1); + CVT_MULRND_CVT32(c_int32_3p0,selector1,zero_point0); // c[3, 16-31] - CVT_MULRND_CVT32(c_int32_3p1,selector2); + CVT_MULRND_CVT32(c_int32_3p1,selector2,zero_point1); // c[3, 32-47] - CVT_MULRND_CVT32(c_int32_3p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_3p2,a_int32_0,zero_point2); // c[4, 0-15] - CVT_MULRND_CVT32(c_int32_4p0,selector1); + CVT_MULRND_CVT32(c_int32_4p0,selector1,zero_point0); // c[4, 16-31] - CVT_MULRND_CVT32(c_int32_4p1,selector2); + CVT_MULRND_CVT32(c_int32_4p1,selector2,zero_point1); // c[4, 32-47] - CVT_MULRND_CVT32(c_int32_4p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_4p2,a_int32_0,zero_point2); // c[5, 0-15] - CVT_MULRND_CVT32(c_int32_5p0,selector1); + CVT_MULRND_CVT32(c_int32_5p0,selector1,zero_point0); // c[5, 16-31] - CVT_MULRND_CVT32(c_int32_5p1,selector2); + CVT_MULRND_CVT32(c_int32_5p1,selector2,zero_point1); // c[5, 32-47] - CVT_MULRND_CVT32(c_int32_5p2,a_int32_0); + CVT_MULRND_CVT32(c_int32_5p2,a_int32_0,zero_point2); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_packa_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemm_packa_amd512vnni.c index 32cd7aef3d..cdaf576172 
100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_packa_amd512vnni.c +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_packa_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_packb_amd512vnni.c b/kernels/zen4/lpgemm/u8s8s32/lpgemm_packb_amd512vnni.c index 539386f5d0..06a1c9ba52 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_packb_amd512vnni.c +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_packb_amd512vnni.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_kern_macros.h b/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_kern_macros.h index deb35e8e09..1e91381001 100644 --- a/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_kern_macros.h +++ b/kernels/zen4/lpgemm/u8s8s32/lpgemm_s32_kern_macros.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -124,7 +124,7 @@ reg = _mm512_mask_mullo_epi32( reg, relu_cmp_mask, reg, selector2 ); \ // Downscale macro -#define CVT_MULRND_CVT32(reg,selector) \ +#define CVT_MULRND_CVT32(reg,selector,zero_point) \ reg = \ _mm512_cvtps_epi32 \ ( \ @@ -134,7 +134,8 @@ ( __m512 )selector, \ ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \ ) \ - ) \ + ); \ + reg = _mm512_add_epi32( reg, _mm512_cvtepi8_epi32( zero_point ) ); \ // Downscale store macro #define CVT_STORE_S32_S8(reg,m_ind,n_ind) \ @@ -147,7 +148,7 @@ ) \ // Downscale n < 16 macro -#define CVT_MULRND_CVT32_LT16(reg,selector) \ +#define CVT_MULRND_CVT32_LT16(reg,selector,zero_point) \ reg = \ _mm512_cvtps_epi32 \ ( \ @@ -157,7 +158,8 @@ ( __m512 )selector, \ ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \ ) \ - ) \ + ); \ + reg = _mm512_add_epi32( reg, _mm512_cvtepi8_epi32( zero_point ) ); \ /* TANH GeLU (x) = 0.5* x * (1 + tanh ( 0.797884 * ( x + ( 0.044715 * x^3 ) ) ) ) */ #define GELU_TANH_S32_AVX512(reg, y, r, r2, x, z, dn, x_tanh, q) \ diff --git a/ref_kernels/1/CMakeLists.txt b/ref_kernels/1/CMakeLists.txt deleted file mode 100644 index c279113758..0000000000 --- a/ref_kernels/1/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. 
All rights reserved.## - -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_addv_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_amaxv_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_aminv_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpbyv_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyv_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_copyv_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotv_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotxv_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_invertv_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_scal2v_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_scalv_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_setv_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_subv_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_swapv_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_xpbyv_ref.c - ) diff --git a/ref_kernels/1f/CMakeLists.txt b/ref_kernels/1f/CMakeLists.txt deleted file mode 100644 index 1b54e5eb80..0000000000 --- a/ref_kernels/1f/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## - -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpy2v_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotaxpyv_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotxaxpyf_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotxf_ref.c - ) - diff --git a/ref_kernels/1m/CMakeLists.txt b/ref_kernels/1m/CMakeLists.txt deleted file mode 100644 index 34f15ae69f..0000000000 --- a/ref_kernels/1m/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## - -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_packm_cxk_1er_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_packm_cxk_3mis_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_packm_cxk_4mi_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_packm_cxk_bb_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_packm_cxk_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_packm_cxk_rih_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_unpackm_cxk_ref.c - ) diff --git a/ref_kernels/3/CMakeLists.txt b/ref_kernels/3/CMakeLists.txt deleted file mode 100644 index 3919189eb7..0000000000 --- a/ref_kernels/3/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## - -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmsup_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_ref.c - ) - -add_subdirectory(bb) diff --git a/ref_kernels/3/bb/CMakeLists.txt b/ref_kernels/3/bb/CMakeLists.txt deleted file mode 100644 index a3ce393621..0000000000 --- a/ref_kernels/3/bb/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## - -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmbb_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsmbb_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsmbb_ref.c - ) - diff --git a/ref_kernels/3/bli_gemmsup_ref.c b/ref_kernels/3/bli_gemmsup_ref.c index 0c3773c1c0..1d3303505f 100644 --- a/ref_kernels/3/bli_gemmsup_ref.c +++ b/ref_kernels/3/bli_gemmsup_ref.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/ref_kernels/CMakeLists.txt b/ref_kernels/CMakeLists.txt deleted file mode 100644 index d26bce06a5..0000000000 --- a/ref_kernels/CMakeLists.txt +++ /dev/null @@ -1,21 +0,0 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## - -if(${TARGET_ARCH} STREQUAL amdzen) -add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/generic ${CMAKE_BINARY_DIR}/ref_kernels/generic) -add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen ${CMAKE_BINARY_DIR}/ref_kernels/zen) -add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen2 ${CMAKE_BINARY_DIR}/ref_kernels/zen2) -add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen3 ${CMAKE_BINARY_DIR}/ref_kernels/zen3) -add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen4 ${CMAKE_BINARY_DIR}/ref_kernels/zen4) -else() -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_cntx_ref.c - ) - -set(SUBDIRECTORIES "1" "1f" "1m" "3" "ind") - -#Add all subdirectories -foreach(VAR ${SUBDIRECTORIES}) - add_subdirectory(${VAR}) -endforeach() -endif() diff --git a/ref_kernels/bli_cntx_ref.c b/ref_kernels/bli_cntx_ref.c index 00acdfd08d..b0d47d26f1 100644 --- a/ref_kernels/bli_cntx_ref.c +++ b/ref_kernels/bli_cntx_ref.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -482,6 +482,30 @@ void GENBARNAME(cntx_init) // -- Set level-3 small/unpacked micro-kernels and preferences ------------- + // -- Set SUP blocksizes ------------------------------------------------------- + // These blocksizes are copied from native blocksizes for ref + + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 4, 4, 4, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 256, 128, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4096, 4096, 4096, 4096 ); + + // Initialize the context with the default blocksize objects and their + // multiples. + bli_cntx_set_l3_sup_blkszs + ( + 5, + // level-3 + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); + funcs = bli_cntx_l3_sup_kers_buf( cntx ); mbools = bli_cntx_l3_sup_kers_prefs_buf( cntx ); @@ -529,6 +553,79 @@ void GENBARNAME(cntx_init) bli_mbool_init( &mbools[ BLIS_XXX ], TRUE, TRUE, TRUE, TRUE ); + // -- Set level-3 small/unpacked micro-kernels, preferences and blocksizes + // for matrices dealing with triangular matrices------------- + +// -- Set blocksizes ------------------------------------------------------- + + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 0, 0, 0, 0 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 0, 0, 0, 0 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 0, 0, 0, 0 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 0, 0, 0, 0 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 0, 0, 0, 0 ); + + // Initialize the context with the default blocksize objects and their + // multiples. 
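The reference context above now registers level-3 SUP blocksizes per datatype (MR, NR, MC, KC, NC), copied from the native blocksizes. These are the usual cache-blocking parameters: C is updated in MR x NR microtiles, A is packed in MC x KC blocks, and B in KC x NC panels. For orientation, here is a self-contained sketch of the loop nest those five numbers govern; it is illustrative only (plain C, no packing, column-major operands) and not the BLIS variant itself:

#include <stddef.h>

static inline size_t min_sz( size_t a, size_t b ) { return a < b ? a : b; }

/* Illustrative cache-blocked C += A*B showing what MC/KC/NC/MR/NR control. */
void gemm_blocked_sketch( size_t m, size_t n, size_t k,
                          const double* a, size_t lda,
                          const double* b, size_t ldb,
                          double*       c, size_t ldc,
                          size_t MC, size_t KC, size_t NC,
                          size_t MR, size_t NR )
{
    for ( size_t jc = 0; jc < n; jc += NC )                        /* NC-wide panel of B and C */
    for ( size_t pc = 0; pc < k; pc += KC )                        /* rank-KC update           */
    for ( size_t ic = 0; ic < m; ic += MC )                        /* MC x KC block of A       */
    for ( size_t jr = jc; jr < min_sz( jc + NC, n ); jr += NR )    /* NR columns of C          */
    for ( size_t ir = ic; ir < min_sz( ic + MC, m ); ir += MR )    /* MR x NR microtile        */
    {
        /* "Microkernel": naive update of one MR x NR tile of C. */
        for ( size_t j = jr; j < min_sz( jr + NR, n ); ++j )
        for ( size_t i = ir; i < min_sz( ir + MR, m ); ++i )
        for ( size_t p = pc; p < min_sz( pc + KC, k ); ++p )
            c[ i + j*ldc ] += a[ i + p*lda ] * b[ p + j*ldb ];
    }
}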
+ bli_cntx_set_l3_sup_tri_blkszs + ( + 5, + // level-3 + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); + + funcs = bli_cntx_l3_sup_tri_kers_buf( cntx ); + mbools = bli_cntx_l3_sup_tri_kers_prefs_buf( cntx ); + +#if 0 + // Adhere to the small/unpacked ukernel mappings: + // - rv -> rrr, rcr + // - rg -> rrc, rcc + // - cv -> ccr, ccc + // - cg -> crr, crc + gen_sup_func_init( &funcs[ BLIS_RRR ], + &funcs[ BLIS_RCR ], gemmsup_rv_ukr_name ); + gen_sup_func_init( &funcs[ BLIS_RRC ], + &funcs[ BLIS_RCC ], gemmsup_rg_ukr_name ); + gen_sup_func_init( &funcs[ BLIS_CCR ], + &funcs[ BLIS_CCC ], gemmsup_cv_ukr_name ); + gen_sup_func_init( &funcs[ BLIS_CRR ], + &funcs[ BLIS_CRC ], gemmsup_cg_ukr_name ); +#endif + gen_func_init( &funcs[ BLIS_RRR ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_RRC ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_RCR ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_RCC ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_CRR ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_CRC ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_CCR ], gemmsup_rv_ukr_name ); + gen_func_init( &funcs[ BLIS_CCC ], gemmsup_rv_ukr_name ); + + // Register the general-stride/generic ukernel to the "catch-all" slot + // associated with the BLIS_XXX enum value. This slot will be queried if + // *any* operand is stored with general stride. + gen_func_init( &funcs[ BLIS_XXX ], gemmsup_gx_ukr_name ); + + + // Set the l3 sup ukernel storage preferences. + // s d c z + bli_mbool_init( &mbools[ BLIS_RRR ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_RRC ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_RCR ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_RCC ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_CRR ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_CRC ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_CCR ], TRUE, TRUE, TRUE, TRUE ); + bli_mbool_init( &mbools[ BLIS_CCC ], TRUE, TRUE, TRUE, TRUE ); + + bli_mbool_init( &mbools[ BLIS_XXX ], TRUE, TRUE, TRUE, TRUE ); + + // -- Set level-1f kernels ------------------------------------------------- funcs = bli_cntx_l1f_kers_buf( cntx ); @@ -607,10 +704,6 @@ void GENBARNAME(cntx_init) bli_cntx_set_schema_a_block( BLIS_PACKED_ROW_PANELS, cntx ); bli_cntx_set_schema_b_panel( BLIS_PACKED_COL_PANELS, cntx ); bli_cntx_set_schema_c_panel( BLIS_NOT_PACKED, cntx ); - - //bli_cntx_set_anti_pref( FALSE, cntx ); - - //bli_cntx_set_membrk( bli_membrk_query(), cntx ); } // ----------------------------------------------------------------------------- diff --git a/ref_kernels/ind/CMakeLists.txt b/ref_kernels/ind/CMakeLists.txt deleted file mode 100644 index 0a02584b1a..0000000000 --- a/ref_kernels/ind/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. 
All rights reserved.## - -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm1m_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm3m1_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm3mh_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm4m1_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm4mb_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm4mh_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm1m_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm3m1_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm4m1_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm1m_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm3m1_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm4m1_ref.c - ) diff --git a/sandbox/gemmlike/bli_sandbox.h b/sandbox/gemmlike/bli_sandbox.h index d6e6522e8c..f3782b3dbc 100644 --- a/sandbox/gemmlike/bli_sandbox.h +++ b/sandbox/gemmlike/bli_sandbox.h @@ -44,12 +44,15 @@ // made available to applications (or the framework) during compilation. #include "bls_gemm.h" +#include "bls_gemm_check.h" #include "bls_gemm_var.h" #include "bls_l3_packm_a.h" #include "bls_l3_packm_b.h" #include "bls_l3_packm_var.h" +#include "bls_packm_cxk.h" + #include "bls_l3_decor.h" diff --git a/sandbox/gemmlike/bls_gemm.c b/sandbox/gemmlike/bls_gemm.c index 3e4c9b2a33..4ee3a773f2 100644 --- a/sandbox/gemmlike/bls_gemm.c +++ b/sandbox/gemmlike/bls_gemm.c @@ -94,7 +94,7 @@ void bls_gemm_ex // Check parameters. if ( bli_error_checking_is_enabled() ) { - bli_gemm_check( alpha, a, b, beta, c, cntx ); + bls_gemm_check( alpha, a, b, beta, c, cntx ); } // If C has a zero dimension, return early. diff --git a/sandbox/gemmlike/bls_gemm_bp_var1.c b/sandbox/gemmlike/bls_gemm_bp_var1.c index 330a94801b..ae695ce34f 100644 --- a/sandbox/gemmlike/bls_gemm_bp_var1.c +++ b/sandbox/gemmlike/bls_gemm_bp_var1.c @@ -230,9 +230,6 @@ void PASTECH2(bls_,ch,varname) \ thrinfo_t* restrict thread_pa = NULL; \ thrinfo_t* restrict thread_jr = NULL; \ thrinfo_t* restrict thread_ir = NULL; \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, ct, rs_ct, cs_ct ); \ \ /* Identify the current thrinfo_t node and then grow the tree. */ \ thread_jc = thread; \ diff --git a/sandbox/gemmlike/bls_gemm_bp_var2.c b/sandbox/gemmlike/bls_gemm_bp_var2.c index 22df767aea..957cd57944 100644 --- a/sandbox/gemmlike/bls_gemm_bp_var2.c +++ b/sandbox/gemmlike/bls_gemm_bp_var2.c @@ -538,12 +538,6 @@ void PASTECH2(bls_,ch,varname) \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ ctype zero = *PASTEMAC(ch,0); \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. - NOTE: This initialization should really be done statically since - var2 executes this microkernel wrapper many times, and the overhead - of touching the temporary microtile adds up. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, ct, rs_ct, cs_ct ); \ \ /* Handle interior and edge cases separately. */ \ if ( mr_cur == MR && nr_cur == NR ) \ diff --git a/sandbox/gemmlike/bls_gemm_check.c b/sandbox/gemmlike/bls_gemm_check.c new file mode 100644 index 0000000000..bd6c2647e2 --- /dev/null +++ b/sandbox/gemmlike/bls_gemm_check.c @@ -0,0 +1,122 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bls_gemm_check + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx + ) +{ + //bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED ); + + err_t e_val; + + // Check object datatypes. + + e_val = bli_check_noninteger_object( alpha ); + bli_check_error_code( e_val ); + + e_val = bli_check_noninteger_object( beta ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( a ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( b ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( c ); + bli_check_error_code( e_val ); + + // Check scalar/vector/matrix type. + + e_val = bli_check_scalar_object( alpha ); + bli_check_error_code( e_val ); + + e_val = bli_check_scalar_object( beta ); + bli_check_error_code( e_val ); + + e_val = bli_check_matrix_object( a ); + bli_check_error_code( e_val ); + + e_val = bli_check_matrix_object( b ); + bli_check_error_code( e_val ); + + e_val = bli_check_matrix_object( c ); + bli_check_error_code( e_val ); + + // Check object buffers (for non-NULLness). + + e_val = bli_check_object_buffer( alpha ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( a ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( b ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( beta ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( c ); + bli_check_error_code( e_val ); + + // Check for sufficiently sized stack buffers + + e_val = bli_check_sufficient_stack_buf_size( bli_obj_dt( a ), cntx ); + bli_check_error_code( e_val ); + + // Check object dimensions. + + e_val = bli_check_level3_dims( a, b, c ); + bli_check_error_code( e_val ); + + // Check for consistent datatypes. + // NOTE: We only perform these tests when mixed datatype support is + // disabled. 
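bls_gemm_check() mirrors the framework's bli_gemm_check(): it validates operand datatypes, scalar/matrix shapes, non-NULL buffers, stack-buffer sizes, level-3 dimension conformance, and datatype consistency, and it is only invoked when runtime error checking is enabled (see the bls_gemm_ex() hunk above). A hedged usage sketch; the object-creation and error-checking calls are the standard BLIS API, and the bls_gemm() signature is assumed to mirror bli_gemm():

#include "blis.h"

int main( void )
{
    obj_t a, b, c, alpha, beta;
    dim_t m = 4, n = 6, k = 5;

    bli_obj_create( BLIS_DOUBLE, m, k, 0, 0, &a );
    bli_obj_create( BLIS_DOUBLE, k, n, 0, 0, &b );
    bli_obj_create( BLIS_DOUBLE, m, n, 0, 0, &c );
    bli_obj_create_1x1( BLIS_DOUBLE, &alpha );
    bli_obj_create_1x1( BLIS_DOUBLE, &beta );

    bli_randm( &a ); bli_randm( &b ); bli_randm( &c );
    bli_setsc( 1.0, 0.0, &alpha );
    bli_setsc( 0.0, 0.0, &beta );

    /* Enable full error checking so bls_gemm_check() is exercised; a failing
       check reports a diagnostic and aborts under the default error handler. */
    bli_error_checking_level_set( BLIS_FULL_ERROR_CHECKING );

    /* Sandbox entry point; operand order assumed to match bli_gemm(). */
    bls_gemm( &alpha, &a, &b, &beta, &c );

    bli_obj_free( &a ); bli_obj_free( &b ); bli_obj_free( &c );
    bli_obj_free( &alpha ); bli_obj_free( &beta );
    return 0;
}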
+ + e_val = bli_check_consistent_object_datatypes( c, a ); + bli_check_error_code( e_val ); + + e_val = bli_check_consistent_object_datatypes( c, b ); + bli_check_error_code( e_val ); +} + diff --git a/sandbox/gemmlike/bls_gemm_check.h b/sandbox/gemmlike/bls_gemm_check.h new file mode 100644 index 0000000000..8b97069911 --- /dev/null +++ b/sandbox/gemmlike/bls_gemm_check.h @@ -0,0 +1,49 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype object-based check functions. +// + +void bls_gemm_check + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx + ); + diff --git a/sandbox/gemmlike/bls_l3_packm_a.c b/sandbox/gemmlike/bls_l3_packm_a.c index c55a19c7b7..0dcc531fdb 100644 --- a/sandbox/gemmlike/bls_l3_packm_a.c +++ b/sandbox/gemmlike/bls_l3_packm_a.c @@ -67,7 +67,7 @@ void PASTECH2(bls_,ch,opname) \ siz_t size_needed = sizeof( ctype ) * m_pack * k_pack; \ \ /* Check the mem_t entry provided by the caller. If it is unallocated, - then we need to acquire a block from the memory broker. */ \ + then we need to acquire a block from the packed block allocator. */ \ if ( bli_mem_is_unalloc( mem ) ) \ { \ if ( bli_thread_am_ochief( thread ) ) \ @@ -79,7 +79,7 @@ void PASTECH2(bls_,ch,opname) \ the current function before the other threads have a chance to copy from it. (A barrier would fix that race condition, but then again, I prefer to keep barriers to a minimum.) */ \ - bli_membrk_acquire_m \ + bli_pba_acquire_m \ ( \ rntm, \ size_needed, \ @@ -104,8 +104,8 @@ void PASTECH2(bls_,ch,opname) \ else /* if ( bli_mem_is_alloc( mem ) ) */ \ { \ /* If the mem_t entry provided by the caller does NOT contain a NULL - buffer, then a block has already been acquired from the memory - broker and cached by the caller. */ \ + buffer, then a block has already been acquired from the packed + block allocator and cached by the caller. 
*/ \ \ /* As a sanity check, we should make sure that the mem_t object isn't associated with a block that is too small compared to the size of @@ -123,12 +123,12 @@ void PASTECH2(bls_,ch,opname) \ above for why the acquisition needs to be directly to the chief thread's passed-in mem_t and not a local (temporary) mem_t. */ \ - bli_membrk_release \ + bli_pba_release \ ( \ rntm, \ mem \ ); \ - bli_membrk_acquire_m \ + bli_pba_acquire_m \ ( \ rntm, \ size_needed, \ @@ -182,7 +182,7 @@ void PASTECH2(bls_,ch,opname) \ is allocated, which it should be. */ \ if ( bli_mem_is_alloc( mem ) ) \ { \ - bli_membrk_release \ + bli_pba_release \ ( \ rntm, \ mem \ diff --git a/sandbox/gemmlike/bls_l3_packm_b.c b/sandbox/gemmlike/bls_l3_packm_b.c index cae93df012..9d563109a6 100644 --- a/sandbox/gemmlike/bls_l3_packm_b.c +++ b/sandbox/gemmlike/bls_l3_packm_b.c @@ -67,7 +67,7 @@ void PASTECH2(bls_,ch,opname) \ siz_t size_needed = sizeof( ctype ) * k_pack * n_pack; \ \ /* Check the mem_t entry provided by the caller. If it is unallocated, - then we need to acquire a block from the memory broker. */ \ + then we need to acquire a block from the packed block allocator. */ \ if ( bli_mem_is_unalloc( mem ) ) \ { \ if ( bli_thread_am_ochief( thread ) ) \ @@ -79,7 +79,7 @@ void PASTECH2(bls_,ch,opname) \ the current function before the other threads have a chance to copy from it. (A barrier would fix that race condition, but then again, I prefer to keep barriers to a minimum.) */ \ - bli_membrk_acquire_m \ + bli_pba_acquire_m \ ( \ rntm, \ size_needed, \ @@ -104,8 +104,8 @@ void PASTECH2(bls_,ch,opname) \ else /* if ( bli_mem_is_alloc( mem ) ) */ \ { \ /* If the mem_t entry provided by the caller does NOT contain a NULL - buffer, then a block has already been acquired from the memory - broker and cached by the caller. */ \ + buffer, then a block has already been acquired from the packed + block allocator and cached by the caller. */ \ \ /* As a sanity check, we should make sure that the mem_t object isn't associated with a block that is too small compared to the size of @@ -123,12 +123,12 @@ void PASTECH2(bls_,ch,opname) \ above for why the acquisition needs to be directly to the chief thread's passed-in mem_t and not a local (temporary) mem_t. */ \ - bli_membrk_release \ + bli_pba_release \ ( \ rntm, \ mem \ ); \ - bli_membrk_acquire_m \ + bli_pba_acquire_m \ ( \ rntm, \ size_needed, \ @@ -182,7 +182,7 @@ void PASTECH2(bls_,ch,opname) \ is allocated, which it should be. 
*/ \ if ( bli_mem_is_alloc( mem ) ) \ { \ - bli_membrk_release \ + bli_pba_release \ ( \ rntm, \ mem \ diff --git a/sandbox/gemmlike/bls_l3_packm_var.h b/sandbox/gemmlike/bls_l3_packm_var.h index 0e8eb9ee8a..98300536bc 100644 --- a/sandbox/gemmlike/bls_l3_packm_var.h +++ b/sandbox/gemmlike/bls_l3_packm_var.h @@ -61,3 +61,14 @@ GENTPROT( double, d, packm_var1 ) GENTPROT( scomplex, c, packm_var1 ) GENTPROT( dcomplex, z, packm_var1 ) +//INSERT_GENTPROT_BASIC0( packm_var2 ) +GENTPROT( float, s, packm_var2 ) +GENTPROT( double, d, packm_var2 ) +GENTPROT( scomplex, c, packm_var2 ) +GENTPROT( dcomplex, z, packm_var2 ) + +//INSERT_GENTPROT_BASIC0( packm_var3 ) +GENTPROT( float, s, packm_var3 ) +GENTPROT( double, d, packm_var3 ) +GENTPROT( scomplex, c, packm_var3 ) +GENTPROT( dcomplex, z, packm_var3 ) diff --git a/sandbox/gemmlike/bls_l3_packm_var.c b/sandbox/gemmlike/bls_l3_packm_var1.c similarity index 90% rename from sandbox/gemmlike/bls_l3_packm_var.c rename to sandbox/gemmlike/bls_l3_packm_var1.c index 3265ef834d..c0649a9ec4 100644 --- a/sandbox/gemmlike/bls_l3_packm_var.c +++ b/sandbox/gemmlike/bls_l3_packm_var1.c @@ -35,7 +35,7 @@ #include "blis.h" // -// Define BLAS-like interfaces to the variants. +// Variant 1 provides basic support for packing by calling packm_cxk(). // #undef GENTFUNC @@ -66,13 +66,11 @@ void PASTECH2(bls_,ch,varname) \ dim_t it, ic; \ dim_t ic0; \ doff_t ic_inc; \ - dim_t panel_len_full; \ - dim_t panel_len_i; \ + dim_t panel_len; \ dim_t panel_len_max; \ - dim_t panel_len_max_i; \ - dim_t panel_dim_i; \ + dim_t panel_dim; \ dim_t panel_dim_max; \ - inc_t vs_c; \ + inc_t incc; \ inc_t ldc; \ inc_t ldp; \ conj_t conjc; \ @@ -95,10 +93,10 @@ void PASTECH2(bls_,ch,varname) \ { \ /* Prepare to pack to row-stored column panels. */ \ iter_dim = n; \ - panel_len_full = m; \ + panel_len = m; \ panel_len_max = m_max; \ panel_dim_max = pd_p; \ - vs_c = cs_c; \ + incc = cs_c; \ ldc = rs_c; \ ldp = rs_p; \ } \ @@ -106,10 +104,10 @@ void PASTECH2(bls_,ch,varname) \ { \ /* Prepare to pack to column-stored row panels. */ \ iter_dim = m; \ - panel_len_full = n; \ + panel_len = n; \ panel_len_max = n_max; \ panel_dim_max = pd_p; \ - vs_c = rs_c; \ + incc = rs_c; \ ldc = cs_c; \ ldp = cs_p; \ } \ @@ -147,31 +145,28 @@ void PASTECH2(bls_,ch,varname) \ for ( ic = ic0, it = 0; it < n_iter; \ ic += ic_inc, it += 1 ) \ { \ - panel_dim_i = bli_min( panel_dim_max, iter_dim - ic ); \ + panel_dim = bli_min( panel_dim_max, iter_dim - ic ); \ \ - ctype* restrict c_begin = c_cast + (ic )*vs_c; \ + ctype* restrict c_begin = c_cast + (ic )*incc; \ \ ctype* restrict c_use = c_begin; \ ctype* restrict p_use = p_begin; \ -\ - panel_len_i = panel_len_full; \ - panel_len_max_i = panel_len_max; \ \ /* The definition of bli_packm_my_iter() will depend on whether slab or round-robin partitioning was requested at configure-time. (The default is slab.) */ \ if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ { \ - PASTEMAC(ch,packm_cxk) \ + PASTECH2(bls_,ch,packm_cxk) \ ( \ conjc, \ schema, \ - panel_dim_i, \ + panel_dim, \ panel_dim_max, \ - panel_len_i, \ - panel_len_max_i, \ + panel_len, \ + panel_len_max, \ kappa_cast, \ - c_use, vs_c, ldc, \ + c_use, incc, ldc, \ p_use, ldp, \ cntx \ ); \ diff --git a/sandbox/gemmlike/bls_l3_packm_var2.c b/sandbox/gemmlike/bls_l3_packm_var2.c new file mode 100644 index 0000000000..8d2b90cac1 --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_var2.c @@ -0,0 +1,244 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Variant 2 is similar to variant 1, but inlines the contents of packm_cxk(). +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict c_cast = c; \ + ctype* restrict p_cast = p; \ +\ + dim_t iter_dim; \ + dim_t n_iter; \ + dim_t it, ic; \ + dim_t ic0; \ + doff_t ic_inc; \ + dim_t panel_len; \ + dim_t panel_len_max; \ + dim_t panel_dim; \ + dim_t panel_dim_max; \ + inc_t incc; \ + inc_t ldc; \ + inc_t ldp; \ + conj_t conjc; \ +\ +\ + /* Extract the conjugation bit from the transposition argument. */ \ + conjc = bli_extract_conj( transc ); \ +\ + /* Create flags to incidate row or column storage. Note that the + schema bit that encodes row or column is describing the form of + micro-panel, not the storage in the micro-panel. Hence the + mismatch in "row" and "column" semantics. */ \ + bool row_stored = bli_is_col_packed( schema ); \ + /*bool col_stored = bli_is_row_packed( schema );*/ \ +\ + /* If the row storage flag indicates row storage, then we are packing + to column panels; otherwise, if the strides indicate column storage, + we are packing to row panels. */ \ + if ( row_stored ) \ + { \ + /* Prepare to pack to row-stored column panels. */ \ + iter_dim = n; \ + panel_len = m; \ + panel_len_max = m_max; \ + panel_dim_max = pd_p; \ + incc = cs_c; \ + ldc = rs_c; \ + ldp = rs_p; \ + } \ + else /* if ( col_stored ) */ \ + { \ + /* Prepare to pack to column-stored row panels. 
*/ \ + iter_dim = m; \ + panel_len = n; \ + panel_len_max = n_max; \ + panel_dim_max = pd_p; \ + incc = rs_c; \ + ldc = cs_c; \ + ldp = cs_p; \ + } \ +\ + /* Compute the total number of iterations we'll need. */ \ + n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ +\ + /* Set the initial values and increments for indices related to C and P + based on whether reverse iteration was requested. */ \ + { \ + ic0 = 0; \ + ic_inc = panel_dim_max; \ + } \ +\ + ctype* restrict p_begin = p_cast; \ +\ + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ \ + const dim_t nt = bli_thread_n_way( thread ); \ + const dim_t tid = bli_thread_work_id( thread ); \ +\ + /* Suppress warnings in case tid isn't used (ie: as in slab partitioning). */ \ + ( void )nt; \ + ( void )tid; \ +\ + dim_t it_start, it_end, it_inc; \ +\ + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ \ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ +\ + /* Iterate over every logical micropanel in the source matrix. */ \ + for ( ic = ic0, it = 0; it < n_iter; \ + ic += ic_inc, it += 1 ) \ + { \ + panel_dim = bli_min( panel_dim_max, iter_dim - ic ); \ +\ + ctype* restrict c_begin = c_cast + (ic )*incc; \ +\ + ctype* restrict c_use = c_begin; \ + ctype* restrict p_use = p_begin; \ +\ + /* The definition of bli_packm_my_iter() will depend on whether slab + or round-robin partitioning was requested at configure-time. (The + default is slab.) */ \ + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + { \ + /* NOTE: We assume here that kappa = 1 and therefore ignore it. If + we're wrong, this will get someone's attention. */ \ + if ( !PASTEMAC(ch,eq1)( *kappa_cast ) ) \ + bli_abort(); \ +\ + /* Perform the packing, taking conjc into account. */ \ + if ( bli_is_conj( conjc ) ) \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* cli = c_use + (l )*ldc + (i )*incc; \ + ctype* pli = p_use + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,copyjs)( *cli, *pli ); \ + } \ + } \ + } \ + else \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* cli = c_use + (l )*ldc + (i )*incc; \ + ctype* pli = p_use + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,copys)( *cli, *pli ); \ + } \ + } \ + } \ +\ + /* If panel_dim < panel_dim_max, then we zero those unused rows. */ \ + if ( panel_dim < panel_dim_max ) \ + { \ + const dim_t i = panel_dim; \ + const dim_t m_edge = panel_dim_max - panel_dim; \ + const dim_t n_edge = panel_len_max; \ + ctype* restrict p_edge = p_use + (i )*1; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ +\ + /* If panel_len < panel_len_max, then we zero those unused columns. 
*/ \ + if ( panel_len < panel_len_max ) \ + { \ + const dim_t j = panel_len; \ + const dim_t m_edge = panel_dim_max; \ + const dim_t n_edge = panel_len_max - panel_len; \ + ctype* restrict p_edge = p_use + (j )*ldp; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ + } \ +\ +/* +if ( !row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +else \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ \ +\ + p_begin += ps_p; \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_var1 ) +GENTFUNC( float, s, packm_var2 ) +GENTFUNC( double, d, packm_var2 ) +GENTFUNC( scomplex, c, packm_var2 ) +GENTFUNC( dcomplex, z, packm_var2 ) + diff --git a/sandbox/gemmlike/bls_l3_packm_var3.c b/sandbox/gemmlike/bls_l3_packm_var3.c new file mode 100644 index 0000000000..5ea80ff424 --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_var3.c @@ -0,0 +1,200 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Variant 3 is similar to variant 1, except that it parallelizes packing +// along the k dimension. (Our current hypothesis is that this method of +// parallelizing the operation may perform better on some NUMA systems.) 
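As the comment above notes, variant 3 differs from variants 1 and 2 by splitting each micropanel's length (the k dimension) across threads rather than handing out whole micropanels: every thread packs its contiguous slice of every micropanel, and only the last thread zero-fills the panel_len_max - panel_len tail. The split itself comes from bli_thread_range_sub() on the packm thrinfo_t node; the standalone sketch below only illustrates that kind of contiguous partition and is not the BLIS implementation:

#include <stdio.h>

/* Illustrative even split of a length-`len` range among `nt` threads:
   thread `tid` gets [start, end), with remainder units going to the lowest
   thread ids. Mimics the intent of bli_thread_range_sub(), not its code. */
static void range_sub_sketch( int tid, int nt, int len, int* start, int* end )
{
    int base = len / nt;
    int rem  = len % nt;
    *start = tid * base + ( tid < rem ? tid : rem );
    *end   = *start + base + ( tid < rem ? 1 : 0 );
}

int main( void )
{
    /* e.g. panel_len = 10 split across 4 threads -> slices of 3, 3, 2, 2. */
    for ( int tid = 0; tid < 4; ++tid )
    {
        int s, e;
        range_sub_sketch( tid, 4, 10, &s, &e );
        printf( "tid %d packs panel columns [%d, %d)\n", tid, s, e );
    }
    return 0;
}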
+// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict c_cast = c; \ + ctype* restrict p_cast = p; \ +\ + dim_t iter_dim; \ + dim_t n_iter; \ + dim_t it, ic; \ + dim_t ic0; \ + doff_t ic_inc; \ + dim_t panel_len; \ + dim_t panel_len_max; \ + dim_t panel_dim; \ + dim_t panel_dim_max; \ + inc_t incc; \ + inc_t ldc; \ + inc_t ldp; \ + conj_t conjc; \ +\ +\ + /* Extract the conjugation bit from the transposition argument. */ \ + conjc = bli_extract_conj( transc ); \ +\ + /* Create flags to incidate row or column storage. Note that the + schema bit that encodes row or column is describing the form of + micro-panel, not the storage in the micro-panel. Hence the + mismatch in "row" and "column" semantics. */ \ + bool row_stored = bli_is_col_packed( schema ); \ + /*bool col_stored = bli_is_row_packed( schema );*/ \ +\ + /* If the row storage flag indicates row storage, then we are packing + to column panels; otherwise, if the strides indicate column storage, + we are packing to row panels. */ \ + if ( row_stored ) \ + { \ + /* Prepare to pack to row-stored column panels. */ \ + iter_dim = n; \ + panel_len = m; \ + panel_len_max = m_max; \ + panel_dim_max = pd_p; \ + incc = cs_c; \ + ldc = rs_c; \ + ldp = rs_p; \ + } \ + else /* if ( col_stored ) */ \ + { \ + /* Prepare to pack to column-stored row panels. */ \ + iter_dim = m; \ + panel_len = n; \ + panel_len_max = n_max; \ + panel_dim_max = pd_p; \ + incc = rs_c; \ + ldc = cs_c; \ + ldp = cs_p; \ + } \ +\ + /* Compute the total number of iterations we'll need. */ \ + n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ +\ + /* Set the initial values and increments for indices related to C and P + based on whether reverse iteration was requested. */ \ + { \ + ic0 = 0; \ + ic_inc = panel_dim_max; \ + } \ +\ + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ \ + const dim_t nt = bli_thread_n_way( thread ); \ + const dim_t tid = bli_thread_work_id( thread ); \ +\ + /* Suppress warnings in case tid isn't used (ie: as in slab partitioning). */ \ + ( void )nt; \ + ( void )tid; \ +\ + dim_t pr_start, pr_end; \ +\ + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. */ \ + bli_thread_range_sub( thread, panel_len, 1, FALSE, &pr_start, &pr_end ); \ +\ + /* Define instances of panel_len and panel_len_max that are specific to + the local thread. */ \ + dim_t panel_len_loc = pr_end - pr_start; \ + dim_t panel_len_max_loc = panel_len_loc; \ +\ + /* If panel_len_max > panel_len, then there are some columns in p that + need to be zeroed. Of course, only the last thread will be responsible + for this edge region. */ \ + dim_t panel_len_zero = panel_len_max - panel_len; \ + if ( tid == nt - 1 ) panel_len_max_loc += panel_len_zero; \ +\ + /* Shift the pointer for c and p to the appropriate locations within the + first micropanel. 
*/ \ + dim_t off_loc = pr_start; \ + ctype* restrict c_begin_loc = c_cast + off_loc * ldc; \ + ctype* restrict p_begin_loc = p_cast + off_loc * ldp; \ +\ + /* Iterate over every logical micropanel in the source matrix. */ \ + for ( ic = ic0, it = 0; it < n_iter; \ + ic += ic_inc, it += 1 ) \ + { \ + panel_dim = bli_min( panel_dim_max, iter_dim - ic ); \ +\ + ctype* restrict c_use = c_begin_loc + (ic )*incc; \ + ctype* restrict p_use = p_begin_loc + (it )*ps_p; \ +\ + { \ + PASTECH2(bls_,ch,packm_cxk) \ + ( \ + conjc, \ + schema, \ + panel_dim, \ + panel_dim_max, \ + panel_len_loc, \ + panel_len_max_loc, \ + kappa_cast, \ + c_use, incc, ldc, \ + p_use, ldp, \ + cntx \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_var3 ) +GENTFUNC( float, s, packm_var3 ) +GENTFUNC( double, d, packm_var3 ) +GENTFUNC( scomplex, c, packm_var3 ) +GENTFUNC( dcomplex, z, packm_var3 ) + +/* +if ( !row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_var3: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +else \ +PASTEMAC(ch,fprintm)( stdout, "packm_var3: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ + diff --git a/sandbox/gemmlike/bls_packm_cxk.c b/sandbox/gemmlike/bls_packm_cxk.c new file mode 100644 index 0000000000..ca11c207c0 --- /dev/null +++ b/sandbox/gemmlike/bls_packm_cxk.c @@ -0,0 +1,161 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + conj_t conja, \ + pack_t schema, \ + dim_t panel_dim, \ + dim_t panel_dim_max, \ + dim_t panel_len, \ + dim_t panel_len_max, \ + ctype* kappa, \ + ctype* a, inc_t inca, inc_t lda, \ + ctype* p, inc_t ldp, \ + cntx_t* cntx \ + ) \ +{ \ + /* Note that we use panel_dim_max, not panel_dim, to query the packm + kernel function pointer. This means that we always use the same + kernel, even for edge cases. 
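The point of keying the kernel lookup on `panel_dim_max` rather than `panel_dim` is that one packing routine, written for the full register blocking, also covers partial edge panels: it copies only the `panel_dim` live rows and zero-fills the rest. A minimal illustration of that shape follows, using a hypothetical `pack_4xk_d` helper; the real kernels are queried through the context and use the BLIS packm kernel signature.

```c
#include <stdio.h>

// Hypothetical fixed-width packing helper: always lays out 4 rows per column
// of the packed panel, copying panel_dim live rows and zeroing the remainder.
static void pack_4xk_d( int panel_dim, int panel_len,
                        const double* a, int inca, int lda,
                        double* p, int ldp )
{
    for ( int l = 0; l < panel_len; ++l )
    {
        int i = 0;
        for ( ; i < panel_dim; ++i ) p[ l*ldp + i ] = a[ l*lda + i*inca ];
        for ( ; i < 4;         ++i ) p[ l*ldp + i ] = 0.0;   // edge padding
    }
}

int main( void )
{
    // A 3x2 edge panel (panel_dim = 3 < 4) taken from a column-stored matrix.
    double a[ 3*2 ] = { 1, 2, 3, 4, 5, 6 };   // column-major: inca = 1, lda = 3
    double p[ 4*2 ];

    pack_4xk_d( 3, 2, a, 1, 3, p, 4 );

    for ( int l = 0; l < 2; ++l )
        for ( int i = 0; i < 4; ++i )
            printf( "p[%d][%d] = %g\n", l, i, p[ l*4 + i ] );
    return 0;
}
```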
*/ \ + num_t dt = PASTEMAC(ch,type); \ + l1mkr_t ker_id = panel_dim_max; \ +\ + PASTECH2(ch,opname,_ker_ft) f; \ +\ + /* Query the context for the packm kernel corresponding to the current + panel dimension, or kernel id. If the id is invalid, the function will + return NULL. */ \ + f = bli_cntx_get_packm_ker_dt( dt, ker_id, cntx ); \ +\ + /* If there exists a kernel implementation for the micro-panel dimension + provided, we invoke the implementation. Otherwise, we use scal2m. */ \ + /* NOTE: We've disabled calling packm micro-kernels from the context for + this implementation. To re-enable, change FALSE to TRUE in the + conditional below. */ \ + if ( f != NULL && FALSE ) \ + { \ + f \ + ( \ + conja, \ + schema, \ + panel_dim, \ + panel_len, \ + panel_len_max, \ + kappa, \ + a, inca, lda, \ + p, ldp, \ + cntx \ + ); \ + } \ + else \ + { \ + /* NOTE: We assume here that kappa = 1 and therefore ignore it. If + we're wrong, this will get someone's attention. */ \ + if ( !PASTEMAC(ch,eq1)( *kappa ) ) \ + bli_abort(); \ +\ + /* Perform the packing, taking conja into account. */ \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,copyjs)( *ali, *pli ); \ + } \ + } \ + } \ + else \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,copys)( *ali, *pli ); \ + } \ + } \ + } \ +\ + /* If panel_dim < panel_dim_max, then we zero those unused rows. */ \ + if ( panel_dim < panel_dim_max ) \ + { \ + const dim_t i = panel_dim; \ + const dim_t m_edge = panel_dim_max - panel_dim; \ + const dim_t n_edge = panel_len_max; \ + ctype* restrict p_edge = p + (i )*1; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ +\ + /* If panel_len < panel_len_max, then we zero those unused columns. 
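For reference, the `set0s_mxn` invocations in this file zero an `m_edge` x `n_edge` block addressed with row stride 1 and column stride `ldp`; conceptually they expand to nothing more than the doubly nested loop below (shown for `double`; the actual macro is generated per datatype).

```c
// Conceptual expansion of set0s_mxn for doubles: zero an m_edge x n_edge block
// whose element (i, j) lives at p_edge[ i*rs + j*cs ].
void set0s_mxn_d( int m_edge, int n_edge,
                  double* p_edge, int rs, int cs )
{
    for ( int j = 0; j < n_edge; ++j )
        for ( int i = 0; i < m_edge; ++i )
            p_edge[ i*rs + j*cs ] = 0.0;
}
```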
*/ \ + if ( panel_len < panel_len_max ) \ + { \ + const dim_t j = panel_len; \ + const dim_t m_edge = panel_dim_max; \ + const dim_t n_edge = panel_len_max - panel_len; \ + ctype* restrict p_edge = p + (j )*ldp; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_cxk ) +GENTFUNC( float, s, packm_cxk ) +GENTFUNC( double, d, packm_cxk ) +GENTFUNC( scomplex, c, packm_cxk ) +GENTFUNC( dcomplex, z, packm_cxk ) + diff --git a/sandbox/power10/gemm_pack.h b/sandbox/gemmlike/bls_packm_cxk.h similarity index 70% rename from sandbox/power10/gemm_pack.h rename to sandbox/gemmlike/bls_packm_cxk.h index 89a81d7683..f6582d64a7 100644 --- a/sandbox/power10/gemm_pack.h +++ b/sandbox/gemmlike/bls_packm_cxk.h @@ -32,33 +32,27 @@ */ -// Templates for packing routines prototypes -#include "bli_sandbox.h" - -#define PACK_FUNC_NAME_(ch, mat) ch ## _pack ## mat -#define PACK_FUNC_NAME(ch, mat) PACK_FUNC_NAME_(ch, mat) - -#define PACK_MACRO_PROTO(ch, DTYPE_IN) \ -\ -void PACK_FUNC_NAME(ch, A) \ - ( \ - dim_t MR, \ - int m, int k, \ - DTYPE_IN* ap, int rs_a, int cs_a, \ - DTYPE_IN* apack \ - ); \ +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ \ -void PACK_FUNC_NAME(ch, B) \ - ( \ - dim_t NR, \ - int k, int n, \ - DTYPE_IN* bp, int rs_b, int cs_b, \ - DTYPE_IN* bpack \ - ); +void PASTECH2(bls_,ch,varname) \ + ( \ + conj_t conja, \ + pack_t schema, \ + dim_t panel_dim, \ + dim_t panel_dim_max, \ + dim_t panel_len, \ + dim_t panel_len_max, \ + ctype* kappa, \ + ctype* a, inc_t inca, inc_t lda, \ + ctype* p, inc_t ldp, \ + cntx_t* cntx \ + ); + +//INSERT_GENTPROT_BASIC0( packm_cxk ) +GENTPROT( float, s, packm_cxk ) +GENTPROT( double, d, packm_cxk ) +GENTPROT( scomplex, c, packm_cxk ) +GENTPROT( dcomplex, z, packm_cxk ) -PACK_MACRO_PROTO(sb, bfloat16) -PACK_MACRO_PROTO(sh, float16) -PACK_MACRO_PROTO(i16, int16_t) -PACK_MACRO_PROTO(i8, int8_t) -PACK_MACRO_PROTO(i4, nibbles) diff --git a/sandbox/gemmlike/thread/bls_l3_decor_openmp.c b/sandbox/gemmlike/thread/bls_l3_decor_openmp.c index 851a29e52b..0086a48e8f 100644 --- a/sandbox/gemmlike/thread/bls_l3_decor_openmp.c +++ b/sandbox/gemmlike/thread/bls_l3_decor_openmp.c @@ -75,7 +75,7 @@ void bls_l3_thread_decorator // Set the packing block allocator field of the rntm. This will be // inherited by all of the child threads when they make local copies of // the rntm below. - bli_membrk_rntm_set_membrk( rntm ); + bli_pba_rntm_set_pba( rntm ); // Allcoate a global communicator for the root thrinfo_t structures. thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); @@ -92,7 +92,7 @@ void bls_l3_thread_decorator // Query the thread's id from OpenMP. const dim_t tid = omp_get_thread_num(); - // Check for a somewhat obscure OpenMP thread-mistmatch issue. + // Check for a somewhat obscure OpenMP thread-mismatch issue. // NOTE: This calls the same function used for the conventional/large // code path. bli_l3_thread_decorator_thread_check( n_threads, tid, gl_comm, rntm_p ); diff --git a/sandbox/gemmlike/thread/bls_l3_decor_pthreads.c b/sandbox/gemmlike/thread/bls_l3_decor_pthreads.c index f87d79fd6c..ff723a4ce4 100644 --- a/sandbox/gemmlike/thread/bls_l3_decor_pthreads.c +++ b/sandbox/gemmlike/thread/bls_l3_decor_pthreads.c @@ -121,6 +121,8 @@ void bls_l3_thread_decorator rntm_t* rntm ) { + err_t r_val; + // Query the total number of threads from the context. 
const dim_t n_threads = bli_rntm_num_threads( rntm ); @@ -140,7 +142,7 @@ void bls_l3_thread_decorator // Set the packing block allocator field of the rntm. This will be // inherited by all of the child threads when they make local copies of // the rntm below. - bli_membrk_rntm_set_membrk( rntm ); + bli_pba_rntm_set_pba( rntm ); // Allocate a global communicator for the root thrinfo_t structures. thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); @@ -151,12 +153,12 @@ void bls_l3_thread_decorator #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_l3_thread_decorator().pth: " ); #endif - bli_pthread_t* pthreads = bli_malloc_intl( sizeof( bli_pthread_t ) * n_threads ); + bli_pthread_t* pthreads = bli_malloc_intl( sizeof( bli_pthread_t ) * n_threads, &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_l3_thread_decorator().pth: " ); #endif - thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads ); + thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads, &r_val ); // NOTE: We must iterate backwards so that the chief thread (thread id 0) // can spawn all other threads before proceeding with its own computation. diff --git a/sandbox/gemmlike/thread/bls_l3_decor_single.c b/sandbox/gemmlike/thread/bls_l3_decor_single.c index 7d9017dcd5..8bb04817fb 100644 --- a/sandbox/gemmlike/thread/bls_l3_decor_single.c +++ b/sandbox/gemmlike/thread/bls_l3_decor_single.c @@ -68,7 +68,7 @@ void bls_l3_thread_decorator bli_sba_rntm_set_pool( 0, array, rntm ); // Set the packing block allocator field of the rntm. - bli_membrk_rntm_set_membrk( rntm ); + bli_pba_rntm_set_pba( rntm ); #ifndef SKIP_THRINFO_TREE // Allcoate a global communicator for the root thrinfo_t structures. diff --git a/sandbox/power10/POWER10.md b/sandbox/power10/POWER10.md index a5071159f3..a6d7d65c3b 100644 --- a/sandbox/power10/POWER10.md +++ b/sandbox/power10/POWER10.md @@ -1,24 +1,20 @@ ### Low Precision POWER10 Kernels -This is a special BLIS Sandbox that allows users to call low precision POWER10 `gemm` kernels. +This is a special BLIS Sandbox that allows users to call POWER10 reduced precision/integer `GEMM` kernels. + +Supported kernels: `IEEE float16 (bli_shgemm), bfloat16 (bli_sbgemm), int16 (bli_i16gemm), int8 (bli_i8gemm), int4 (bli_i4gemm)`. #### Introduction -This document describes how the low precision POWER10 `gemm` kernels are implemented. The document will also demonstrate how to call the `gemm` kernels. +This document describes how the low precision POWER10 `gemm` kernels are implemented and explains how to call the POWER10 `GEMM` kernels. -**Important: This sandbox does not have the full functionality of BLIS. This sandbox can only perform single threaded, no transpose, GEMM. At this time, full functioning POWER10 hardware has not be released. Once hardware has been released, the kernels will be further optimized in areas such as prefetching and cache blocksizes.** +**Important: These kernels does not have the full functionality of BLIS. The kernels can only perform single threaded, no transpose, GEMM.** #### Implementation -The kernels are implemented in `generic_gemm.c`. They are instantiated with macro templates. The main template is called `GENERIC_GEMM`. This template is used to create the 5-loop `gemm` function. - -The API points are created in `gemm_api.c`. In this file, the API points are wrappers for the functions that are created by the templates in `generic_gemm.c`. 
- -#### Kernels - -The following low precision datatypes have POWER10 `gemm` kernels: `IEEE float16, bfloat16, int16, int8, int4`. +The kernels are implemented in `gemm.c`. They are instantiated with macro templates. The main template is called `GENERIC_GEMM`. This template is used to create the 5-loop `gemm` function. -#### Low Precision Types +#### Reduced precision/integer Types | BLIS type | BLIS char | Type definition | Used to represent... | |:-----------|:----------|:---------------------------------------|:-------------------------------------| @@ -28,9 +24,9 @@ The following low precision datatypes have POWER10 `gemm` kernels: `IEEE float16 | `int8` | `i8` | `int8_t` | 8 bit integers | | `int4` | `i4` | `typedef union{ uint8_t v; struct { uint8_t nib1:4; uint8_t nib2:4; } bits; }` | 4 bit integers | -#### Low Precision API +#### Reduced Precision/Integer API -The API that is used for the low precision POWER10 `gemm` kernels is similar to the existing [BLIS basic typed API](https://github.com/flame/blis/blob/master/docs/BLISTypedAPI.md). The main difference between the two is that in the existing BLIS typed API, there is only one type for the input and output matrices. However in the low precision API, there is a input and output type. +The API that is used for the reduced precision/integer POWER10 `GEMM` kernels is similar to the existing [BLIS basic typed API](https://github.com/flame/blis/blob/master/docs/BLISTypedAPI.md). The main difference is the POWER10 kernels expect two types: `ctype_in` and `ctype_out`. Thus the new `gemm` call looks like the following: @@ -50,10 +46,7 @@ void bli_??gemm ); ``` -The first `?` is for the output type. The second `?` is for the input type. - -At this time for IEEE float16 and bfloat16, the only output type is single precision float. For int16, int8, and int4, the only output type is 32 bit int. - +`??` is meant to replaced with the kernel prefix. #### How To Build The Sandbox @@ -64,6 +57,9 @@ Add the following flags when running the configure script to build BLIS correctl Ensure that you have GCC 10.2 or greater. +#### P10 Testsuite + +In `p10_testsuite`, there are performance gathering and correctness checking programs for the POWER10 reduced precision/integer `GEMM` kernels. By default, the performance gathering and correctness checking is done over square matrices ranging from 80 to 4000 in increments of 80. Performance is measured in GFLOPs, and correctness is measured using the BLIS method (detailed in `blis/testsuite/test_gemm.c`). #### References diff --git a/sandbox/power10/bli_sandbox.h b/sandbox/power10/bli_sandbox.h index 77c5fe2cb5..22d293d130 100644 --- a/sandbox/power10/bli_sandbox.h +++ b/sandbox/power10/bli_sandbox.h @@ -36,14 +36,12 @@ #define BLIS_SANDBOX_H #include "blis.h" -#include "gemm_api.h" +#include "gemm_prototypes.h" // NOTE: This header is the only header required to be present in the sandbox // implementation directory. 
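Since `bli_sandbox.h` carries the reduced-precision typedefs and prototypes, a call through the API documented in POWER10.md above looks like ordinary typed-API BLIS except that A and B use the input type while alpha, beta, and C use the output type. The sketch below is a hedged usage example for `bli_sbgemm` (bfloat16 in, float out) with row-major operands; it assumes the sandbox headers are on the include path (as in `p10_testsuite`), and it leaves the bfloat16 buffers zero-initialized, since filling them with meaningful values would go through helpers such as the testsuite's `cast_f32_to_bf16()`.

```c
#include <stdlib.h>
#include "blis.h"
#include "bli_sandbox.h"   // bfloat16 typedef and bli_sbgemm prototype (path as in p10_testsuite)

int main( void )
{
    dim_t m = 8, n = 8, k = 8;
    float alpha = 1.0f, beta = 1.0f;   // the testsuite also uses alpha = beta = 1

    // Zero-initialized operands; only no-transpose, single-threaded GEMM is supported.
    bfloat16* a = calloc( m * k, sizeof( bfloat16 ) );
    bfloat16* b = calloc( k * n, sizeof( bfloat16 ) );
    float*    c = calloc( m * n, sizeof( float ) );

    // C := beta*C + alpha*A*B, with A and B in bfloat16 and C in float.
    bli_sbgemm( BLIS_NO_TRANSPOSE, BLIS_NO_TRANSPOSE,
                m, n, k,
                &alpha,
                a, k, 1,     // row-major A: rs = k, cs = 1
                b, n, 1,     // row-major B: rs = n, cs = 1
                &beta,
                c, n, 1 );   // row-major C: rs = n, cs = 1

    free( a ); free( b ); free( c );
    return 0;
}
```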
-// This header is used to create the typedefs needed for low precision - -// int4 type +// int4 typedef union { uint8_t v; @@ -54,7 +52,7 @@ typedef union } bits; } nibbles; -// bfloat16 +// brain float16 typedef union { uint16_t v; @@ -80,36 +78,25 @@ typedef union #define P10_PG_SIZE 4096 +// microkernel prototypes GEMM_UKR_PROT2( bfloat16, float, sb, gemm_power10_mma_8x16 ) GEMM_UKR_PROT2( float16, float, sh, gemm_power10_mma_8x16 ) GEMM_UKR_PROT2( int16_t, int32_t, i16, gemm_power10_mma_8x16 ) GEMM_UKR_PROT2( int8_t, int32_t, i8, gemm_power10_mma_8x16 ) GEMM_UKR_PROT2( nibbles, int32_t, i4, gemm_power10_mma_8x16 ) -/* Creates a function that initializes a matrix of type ctype with random vals */ -#define RandomMatrixMacro(ch, ctype, rand_func) \ - RM_PROT(ch, ctype) \ - { \ - for ( int i=0; i0)), // innermost loop iterations + i8_pack_a, // pack kernel for A + i8_pack_b, // pack kernel for B + bli_i8gemm_power10_mma_8x16, // microkernel function name + 4, // K_MMA + 8, // MR + 16, // NR + 384, // MC + 6656, // KC + 4096, // NC + 0, // A_ALIGN + 0 // B_ALIGN +); + +GENERIC_GEMM( + i4, // kernel name prefix + nibbles, // input type + int, // output type + (pb/8 + (pb%8>0)), // innermost loop iterations + i4_pack_a, // pack kernel for A + i4_pack_b, // pack kernel for B + bli_i4gemm_power10_mma_8x16, // microkernel function name + 8, // K_MMA + 8, // MR + 16, // NR + 384, // MC + 6656, // KC + 4096, // NC + 0, // A_ALIGN + 0 // B_ALIGN +); + diff --git a/sandbox/power10/gemm_pack.c b/sandbox/power10/gemm_pack.c deleted file mode 100644 index 3834b6d7ce..0000000000 --- a/sandbox/power10/gemm_pack.c +++ /dev/null @@ -1,889 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
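The `GENERIC_GEMM` instantiations above pin down the blocking parameters (MR/NR microtile, MC/KC/NC cache blocks) that the generated 5-loop `gemm` uses. As a point of reference, the loop structure those parameters drive looks like the sketch below; this is an illustrative single-precision version without packing, alpha/beta scaling, or the reduced-precision MMA microkernels, so it shows the blocking only.

```c
#include <stddef.h>

enum { MR = 8, NR = 16, MC = 384, KC = 6656, NC = 4096 };  // values used above

static int imin( int a, int b ) { return a < b ? a : b; }

// Reference microtile update: C[mr x nr] += A[mr x kc] * B[kc x nr] (row-major).
static void ukr_ref( int mr, int nr, int kc,
                     const float* a, int rsa,
                     const float* b, int rsb,
                     float*       c, int rsc )
{
    for ( int i = 0; i < mr; ++i )
        for ( int j = 0; j < nr; ++j )
        {
            float acc = 0.0f;
            for ( int p = 0; p < kc; ++p )
                acc += a[ i*rsa + p ] * b[ p*rsb + j ];
            c[ i*rsc + j ] += acc;
        }
}

// Five-loop cache blocking around the microtile (packing omitted for brevity).
void gemm_blocked_sketch( int m, int n, int k,
                          const float* a, int rsa,
                          const float* b, int rsb,
                          float*       c, int rsc )
{
    for ( int jc = 0; jc < n; jc += NC ) {            // 5th loop: NC-wide column blocks
        const int nc = imin( NC, n - jc );
        for ( int pc = 0; pc < k; pc += KC ) {        // 4th loop: KC-deep rank-k updates
            const int kc = imin( KC, k - pc );
            for ( int ic = 0; ic < m; ic += MC ) {    // 3rd loop: MC-tall row blocks
                const int mc = imin( MC, m - ic );
                for ( int jr = 0; jr < nc; jr += NR )     // 2nd loop: NR microtile columns
                    for ( int ir = 0; ir < mc; ir += MR ) // 1st loop: MR microtile rows
                        ukr_ref( imin( MR, mc - ir ), imin( NR, nc - jr ), kc,
                                 a + (size_t)(ic + ir)*rsa + pc,        rsa,
                                 b + (size_t)pc*rsb + (jc + jr),        rsb,
                                 c + (size_t)(ic + ir)*rsc + (jc + jr), rsc );
            }
        }
    }
}
```

In the sandbox itself the innermost loops run over packed panels of the input type, and the inner trip count is derived from the k block size and the K_MMA parameter listed in the instantiations above.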
- -*/ - -// Templates for different packing routine - -#include "gemm_pack.h" - -/* - - Details on bit16_dt vector data structure - - Vector X = [ X[0,0] X[0,1] X[1,0] X[1,1] X[2,0] X[2,1] X[3,0] X[3,1] ] - Vector Y = [ Y[0,0] Y[0,1] Y[1,0] Y[1,1] Y[2,0] Y[2,1] Y[3,0] Y[3,1] ] - - These bit16_dt vectors represent a 4x2 matrix. Hence, in matrix form it - looks like the following: - - X = [ X[0,0] X[0,1] - X[1,0] X[1,1] - X[2,0] X[2,1] - X[3,0] X[3,1] ] - - The outer product instruction: xvbf16ger2 (bfloat16 outer product) - - Syntax: - - xvbf16ger2 ACCUMULATOR A, VECTOR X, VECTOR Y - - Semantics: - - A = X * Y^T - - The generic packing routine would load 8 elements from the same column. - This causes an issue since the instruction expects the vector to be a - 4x2 matrix where the data is packed in contiguous order. Thus, we must make - a packing routine that will interleave the matrix data. Making it so - that when we load the 8 contiguous elements from A, it will represent - a 4x2 section of the matrix. - -*/ - -#define k_even_apack_16(ir) \ - *adest++ = ap[ (i+ir)*rs_a + p_idx*cs_a ]; \ - *adest++ = ap[ (i+ir)*rs_a + (p_idx+1)*cs_a ]; - -#define k_odd_apack_16(ir) \ - *adest++ = ap[ (i+ir)*rs_a + (k-1)*cs_a ]; \ - memset(adest, 0, 2); \ - adest++; - -#define pad_macro_16(dest_matrix) \ - memset(dest_matrix, 0, 4); \ - dest_matrix+=2; - -#define BIT16_PACK_A(ch, DTYPE_IN) \ -\ -void PACK_FUNC_NAME(ch, A) \ - ( \ - dim_t MR, \ - int m, int k, \ - DTYPE_IN* ap, int rs_a, int cs_a, \ - DTYPE_IN* apack \ - ) \ -{ \ - int k_odd = k%2; \ - int p_idx; \ -\ - DTYPE_IN* adest = apack; \ - for (int i=0; i0)), 4, bli_i8gemm_power10_mma_8x16); -GENERIC_GEMM( i4, nibbles, int, (pb/8 + (pb%8>0)), 8, bli_i4gemm_power10_mma_8x16); diff --git a/sandbox/power10/p10_testsuite/Makefile b/sandbox/power10/p10_testsuite/Makefile new file mode 100644 index 0000000000..b8a72c90cf --- /dev/null +++ b/sandbox/power10/p10_testsuite/Makefile @@ -0,0 +1,31 @@ +BLIS_PATH := ../../.. + +BLIS_INC := $(BLIS_PATH)/include/power10 +BLIS_LIB := $(BLIS_PATH)/lib/power10/libblis.a + +CC := gcc +LINKER := $(CC) + +CFLAGS := -I $(BLIS_INC) +LDFLAGS := -lpthread -lm + +OBJS := $(patsubst %.c,%.o, $(wildcard *.c)) +PERF_OBJS := performance.o +COR_OBJS := correctness.o cast_funcs.o + +all: performance correctness + +$(OBJS): %.o: %.c + $(CC) $(CFLAGS) -c $< -o $@ + +performance: $(PERF_OBJS) + $(LINKER) $(PERF_OBJS) $(BLIS_LIB) -o ./gather_perf.x $(LDFLAGS) + +correctness: $(COR_OBJS) + $(LINKER) $(COR_OBJS) $(BLIS_LIB) -o ./test_correctness.x $(LDFLAGS) + +csv_clean: + rm -rf *.csv + +clean: + rm -rf *.x *.o diff --git a/sandbox/power10/p10_testsuite/cast_funcs.c b/sandbox/power10/p10_testsuite/cast_funcs.c new file mode 100644 index 0000000000..8108602c53 --- /dev/null +++ b/sandbox/power10/p10_testsuite/cast_funcs.c @@ -0,0 +1,180 @@ +#include "cast_funcs.h" +#include "../bli_sandbox.h" + +// bit map used for casting float to bfloat16 +typedef union +{ + float v; + struct + { + uint32_t m:23; + uint32_t e:8; + uint32_t s:1; + } bits; +} float32_s; + + +// cast float16 into float +float cast_f16_to_f32(float16 val) +{ + uint16_t in = val.v; + float out; + uint32_t t1; + uint32_t t2; + uint32_t t3; + + t1 = in & 0x7fff; // Non-sign bits + t2 = in & 0x8000; // Sign bit + t3 = in & 0x7c00; // Exponent + + t1 <<= 13; // Align mantissa on MSB + t2 <<= 16; // Shift sign bit into position + + t1 += 0x38000000; // Adjust bias + + t1 = (t3 == 0 ? 
0 : t1); // Denormals-as-zero + + t1 |= t2; // Re-insert sign bit + + *((uint32_t*)&out) = t1; + return out; +} + +// cast float to float16 +float16 cast_f32_to_f16(const float in) +{ + float16 f16_out; + + uint32_t inu = *((uint32_t*)&in); + uint32_t t1; + uint32_t t2; + uint32_t t3; + + t1 = inu & 0x7fffffff; // Non-sign bits + t2 = inu & 0x80000000; // Sign bit + t3 = inu & 0x7f800000; // Exponent + + t1 >>= 13; // Align mantissa on MSB + t2 >>= 16; // Shift sign bit into position + + t1 -= 0x1c000; // Adjust bias + + t1 = (t3 < 0x38800000) ? 0 : t1; + t1 = (t3 > 0x47000000) ? 0x7bff : t1; + t1 = (t3 == 0 ? 0 : t1); // Denormals-as-zero + + t1 |= t2; // Re-insert sign bit + + f16_out.v = t1; + return f16_out; +} + + +// cast float to bfloat16 +bfloat16 cast_f32_to_bf16 (float val) +{ + bfloat16 bf16; + float32_s f32; + f32.v = val; + bf16.bits.s = f32.bits.s; + bf16.bits.e = f32.bits.e; + bf16.bits.m = f32.bits.m >> 16; + return bf16; +} + +// cast bfloat16 to float +float cast_bf16_to_f32(bfloat16 val) +{ + float32_s f32; + f32.bits.s = val.bits.s; + f32.bits.e = val.bits.e; + f32.bits.m = val.bits.m << 16; + return f32.v; +} + +// cast a nibbles struct to a float array +void cast_i4_to_f32(float *fvals, nibbles vals) +{ + int8_t val0 = vals.bits.nib1; + int8_t val1 = vals.bits.nib2; + + val0 = (val0 >= 8 ? val0 - 16 : val0); + val1 = (val1 >= 8 ? val1 - 16 : val1); + + fvals[0] = (float) val0; + fvals[1] = (float) val1; +} + +// condense two float vals to a nibbles struct +nibbles cast_f32_to_i4(float val0, float val1) +{ + nibbles vals; + + int8_t val0_ = ((int8_t)val0) & 0xf0; + int8_t val1_ = ((int8_t)val1) & 0xf0; + + vals.bits.nib1 = val0_; + vals.bits.nib2 = val1_; + + return vals; +} + +// cast float matrix to float nibbles +void cast_f32_to_i4m(float *a_float, nibbles *a, int num_elems) +{ + int j=0; + for(int i=0; i +// print kernel name +const char* get_kernel_name(int kernel_id) +{ + switch (kernel_id) + { + case FLOAT16 : return "bli_shgemm"; + case BFLOAT16: return "bli_sbgemm"; + case INT16 : return "bli_i16gemm"; + case INT8 : return "bli_i8gemm"; + case INT4 : return "bli_i4gemm"; + default: printf("INCORRECT KERNEL ID\n"); exit(-1); + } +} + +// normalize the vector using the forbenious norm +void normalize_vec(float *t, int n) +{ + // normalize t + float norm_factor; + bli_snormfv(n, t, 1, &norm_factor); + // round up to closest power of 2 + norm_factor = 1 / (pow( 2.0, ceil( log2( norm_factor ) ) )); + bli_sscalv(BLIS_NO_CONJUGATE, n, &norm_factor, t, 1); +} + + // Pre-conditions: + // - a is randomized. + // - b is randomized. + // - c_orig is randomized. + // Note: + // - alpha and beta should have non-zero imaginary components in the + // complex cases in order to more fully exercise the implementation. 
+ // + // Under these conditions, we assume that the implementation for + // + // C := beta * C_orig + alpha * transa(A) * transb(B) + // + // is functioning correctly if + // + // normfv( v - z ) + // + // is negligible, where + // + // v = C * t + // z = ( beta * C_orig + alpha * transa(A) * transb(B) ) * t + // = beta * C_orig * t + alpha * transa(A) * transb(B) * t + // = beta * C_orig * t + alpha * transa(A) * w + // = beta * C_orig * t + z +float get_resid( + int m, int n, int k, + float *a, int rsa, int csa, + float *b, int rsb, int csb, + float *c, int rsc, int csc, + float *c_orig, + float *alpha, float *beta +) +{ + + float t[n], v[m], w[k], z[m]; + float one = 1.0, zero = 0.0; + + bli_srandv(n, t, 1); + + // normalize so that the values are at the same precision of the input values + normalize_vec(t, n); + + // v = C * t + bli_sgemv( + BLIS_NO_TRANSPOSE, + BLIS_NO_CONJUGATE, + m, + n, + &one, + c, rsc, csc, + t, 1, + &zero, + v, 1 + ); + + // w = B * t + bli_sgemv( + BLIS_NO_TRANSPOSE, + BLIS_NO_CONJUGATE, + k, + n, + &one, + b, rsb, csb, + t, 1, + &zero, + w, 1 + ); + + // z = alpha * A * w + bli_sgemv( + BLIS_NO_TRANSPOSE, + BLIS_NO_CONJUGATE, + m, + k, + alpha, + a, rsa, csa, + w, 1, + &zero, + z, 1 + ); + + // z += beta * C_orig * t + bli_sgemv( + BLIS_NO_TRANSPOSE, + BLIS_NO_CONJUGATE, + m, + n, + beta, + c_orig, rsc, csc, + t, 1, + &one, + z, 1 + ); + + // v = v - z + bli_ssubv ( + BLIS_NO_CONJUGATE, + m, + z, 1, + v, 1 + ); + + // norm = normfv(v) + float norm; + bli_snormfv ( + m, + v, 1, + &norm + ); + + return norm; +} + + +// test to see if the result from a BLIS GEMM kernel is correct for a given m x n x k mat-mul +// assumes the matrices are of type float +// assumes the matrices were randomized and normalized +void correctness_checker( + int m, int n, int k, + float *a, int rsa, int csa, + float *b, int rsb, int csb, + float *c_orig, int rsc, int csc, + float *c_ans, + float alpha, float beta +) +{ + double start, end; + + start = bli_clock(); + float resid = get_resid ( + m, n, k, + a, rsa, csa, + b, rsb, csb, + c_ans, rsc, csc, + c_orig, + &alpha, &beta + ); + end = bli_clock(); + + printf("%d, %d, %d, %8.4le\n", m,n,k, resid); +} + + +// create all the correctness checking functions for each kernel +GEN_FP_COR_KERNEL(sb, bli_sbgemm, bfloat16, cast_f32_to_bf16m, cast_bf16_to_f32m); +GEN_FP_COR_KERNEL(sh, bli_shgemm, float16, cast_f32_to_f16m, cast_f16_to_f32m); +GEN_I_COR_KERNEL(i16, bli_i16gemm, int16_t, cast_f32_to_i16m, cast_i16_to_f32m); +GEN_I_COR_KERNEL(i8, bli_i8gemm, int8_t, cast_f32_to_i8m, cast_i8_to_f32m); + +// correctness template for int types +void i4correctness_kernel (int m, int n, int k) +{ + if(n%2 != 0) + { + printf("int4 can't handle odd sizes in the data-order dimension"); + exit(-1); + } + + int rsa = k, csa = 1, + rsb = n, csb = 1, + rsc = n, csc = 1; + + nibbles *a, *b; + + int32_t *c_ans, *c_orig, alpha, beta; + + float *a_float, *b_float, + *c_ans_float, *c_orig_float; + + /* buffers that will be passed into the kernel */ + // int4 buffers only need half the space to store all the elements + a = (nibbles *) malloc (m * (k/2) * sizeof(nibbles)); + b = (nibbles *) malloc (k * (n/2) * sizeof(nibbles)); + + c_ans = (int32_t *) malloc (m * n * sizeof(int32_t)); + c_orig = (int32_t *) malloc (m * n * sizeof(int32_t)); + + /* std format buffers that will be used by the correctness checker */ + a_float = (float *) malloc (m * k * sizeof(float)); + b_float = (float *) malloc (k * n * sizeof(float)); + c_ans_float = (float *) malloc (m * n * 
sizeof(float)); + c_orig_float = (float *) malloc (m * n * sizeof(float)); + + /* randomize matrices with float vals */ + bli_srandv(m*k, a_float, 1); + bli_srandv(k*n, b_float, 1); + bli_srandv(m*n, c_orig_float, 1); + + /* normalize the matrices */ + normalize_vec(a_float, m*k); + normalize_vec(b_float, k*n); + normalize_vec(c_orig_float, m*n); + + /* cast the float buffers into the buffers for the kernel */ + cast_f32_to_i4m (a_float, a, m*k); + cast_f32_to_i4m (b_float, b, k*n); + + /* cast float buffers to support int values */ + cast_f32_to_i32m(c_orig_float, c_orig, m*n); + cast_i32_to_f32m(c_orig, c_orig_float, m*n); + + /* cast the kernel buffers into the float buffers to ensure that the values match */ + cast_i4_to_f32m (a, a_float, m*k); + cast_i4_to_f32m (b, b_float, k*n); + + /* init alpha and beta */ + alpha = 1; + beta = 1; + + /* run kernel to get result in c_ans */ + // strides need to be adjusted since 1 element stores 2 values + memcpy(c_ans, c_orig, m * n * sizeof(int)); + bli_i4gemm( + BLIS_NO_TRANSPOSE, + BLIS_NO_TRANSPOSE, + m, + n, + k, + &alpha, + a, rsa/2, csa, + b, rsb/2, csb, + &beta, + c_ans, rsc, csc + ); + + /* cast integer result into float buffer since float is our std format for correctness checking */ + cast_i32_to_f32m(c_ans, c_ans_float, m*n); + + /* using the BLIS GEMM correctness check method, get the resid */ + correctness_checker( + m, n, k, + a_float, rsa, csa, + b_float, rsb, csb, + c_orig_float, rsc, csc, + c_ans_float, + (float) alpha, (float) beta + ); + + free(a); + free(b); + free(c_ans); + free(c_orig); + free(a_float); + free(b_float); + free(c_ans_float); + free(c_orig_float); +} + +// using the DATATYPE enum, gather test the correctness of the respective GEMM kernel +void run_correctness_kernel(int kernel_id, int m, int n, int k) +{ + switch (kernel_id) + { + case FLOAT16 : shcorrectness_kernel(m, n, k); break; + case BFLOAT16: sbcorrectness_kernel(m, n, k); break; + case INT16 : i16correctness_kernel(m, n, k); break; + case INT8 : i8correctness_kernel(m, n, k); break; + case INT4 : i4correctness_kernel(m, n, k); break; + default: break; + } +} + +void test_correctness(int kernel_id, int start, int end, int inc) +{ + printf("%s correctness test\n", get_kernel_name(kernel_id)); + printf("m, n, k, resid\n"); + int m,n,k; + for (int p=start; p<=end; p+=inc) + { + m=n=k=p; + run_correctness_kernel(kernel_id, m, n, k); + } +} + +// correctness test for bfloat16 gemm +int main(int argc, char *argv[]) +{ + + test_correctness(FLOAT16, 80, 4000, 80); + test_correctness(BFLOAT16, 80, 4000, 80); + test_correctness(INT16, 80, 4000, 80); + test_correctness(INT8, 80, 4000, 80); + test_correctness(INT4, 80, 4000, 80); +} diff --git a/sandbox/power10/p10_testsuite/correctness.h b/sandbox/power10/p10_testsuite/correctness.h new file mode 100644 index 0000000000..aea647848a --- /dev/null +++ b/sandbox/power10/p10_testsuite/correctness.h @@ -0,0 +1,176 @@ +// templates for generating correctness checking functions that check the correctness of GEMM kernels +// using the BLIS GEMM correctness method + +#define COR_KERNEL_NAME_(ch) ch ## correctness_kernel +#define COR_KERNEL_NAME(ch) COR_KERNEL_NAME_(ch) + + +// correctness template for float types +#define GEN_FP_COR_KERNEL(ch, kernel, input_t, DOWN_CAST, UP_CAST) \ +void COR_KERNEL_NAME(ch) (int m, int n, int k) \ +{ \ + int rsa = k, csa = 1, \ + rsb = n, csb = 1, \ + rsc = n, csc = 1; \ +\ + input_t *a, *b; \ +\ + float *a_float, *b_float, \ + *c_ans_float, *c_orig_float, \ + alpha, beta; \ +\ + /* 
buffers that will be passed into the kernel */ \ + a = (input_t *) malloc (m * k * sizeof(input_t)); \ + b = (input_t *) malloc (k * n * sizeof(input_t)); \ +\ + /* std format buffers that will be used by the correctness checker */ \ + a_float = (float *) malloc (m * k * sizeof(float)); \ + b_float = (float *) malloc (k * n * sizeof(float)); \ + c_ans_float = (float *) malloc (m * n * sizeof(float)); \ + c_orig_float = (float *) malloc (m * n * sizeof(float)); \ +\ + /* randomize matrices with float vals */ \ + bli_srandv(m*k, a_float, 1); \ + bli_srandv(k*n, b_float, 1); \ + bli_srandv(m*n, c_orig_float, 1); \ +\ + /* normalize the matrices */ \ + normalize_vec(a_float, m*k); \ + normalize_vec(b_float, k*n); \ + normalize_vec(c_orig_float, m*n); \ +\ + /* cast the float buffers into the buffers for the kernel */ \ + DOWN_CAST (a_float, a, m*k); \ + DOWN_CAST (b_float, b, k*n); \ +\ + /* cast the kernel buffers into the float buffers to ensure that the values match */ \ + UP_CAST (a, a_float, m*k); \ + UP_CAST (b, b_float, k*n); \ +\ + /* init alpha and beta */ \ + alpha = 1; \ + beta = 1; \ +\ + memcpy(c_ans_float, c_orig_float, m * n * sizeof(float)); \ + kernel( \ + BLIS_NO_TRANSPOSE, \ + BLIS_NO_TRANSPOSE, \ + m, \ + n, \ + k, \ + &alpha, \ + a, rsa, csa, \ + b, rsb, csb, \ + &beta, \ + c_ans_float, rsc, csc \ + ); \ +\ + correctness_checker( \ + m, n, k, \ + a_float, rsa, csa, \ + b_float, rsb, csb, \ + c_orig_float, rsc, csc, \ + c_ans_float, \ + alpha, beta \ + ); \ +\ + free(a); \ + free(b); \ + free(a_float); \ + free(b_float); \ + free(c_ans_float); \ + free(c_orig_float); \ +\ +} + +// correctness template for int types +#define GEN_I_COR_KERNEL(ch, kernel, input_t, DOWN_CAST, UP_CAST) \ +void COR_KERNEL_NAME(ch) (int m, int n, int k) \ +{ \ + int rsa = k, csa = 1, \ + rsb = n, csb = 1, \ + rsc = n, csc = 1; \ +\ + input_t *a, *b; \ +\ + int32_t *c_ans, *c_orig, alpha, beta; \ +\ + float *a_float, *b_float, \ + *c_ans_float, *c_orig_float; \ +\ + /* buffers that will be passed into the kernel */ \ + a = (input_t *) malloc (m * k * sizeof(input_t)); \ + b = (input_t *) malloc (k * n * sizeof(input_t)); \ + c_ans = (int32_t *) malloc (m * n * sizeof(int32_t)); \ + c_orig = (int32_t *) malloc (m * n * sizeof(int32_t)); \ +\ + /* std format buffers that will be used by the correctness checker */ \ + a_float = (float *) malloc (m * k * sizeof(float)); \ + b_float = (float *) malloc (k * n * sizeof(float)); \ + c_ans_float = (float *) malloc (m * n * sizeof(float)); \ + c_orig_float = (float *) malloc (m * n * sizeof(float)); \ +\ + /* randomize matrices with float vals */ \ + bli_srandv(m*k, a_float, 1); \ + bli_srandv(k*n, b_float, 1); \ + bli_srandv(m*n, c_orig_float, 1); \ +\ + /* normalize the matrices */ \ + normalize_vec(a_float, m*k); \ + normalize_vec(b_float, k*n); \ + normalize_vec(c_orig_float, m*n); \ +\ + /* cast the float buffers into the buffers for the kernel */ \ + DOWN_CAST (a_float, a, m*k); \ + DOWN_CAST (b_float, b, k*n); \ +\ + /* cast float buffers to support int values */ \ + cast_f32_to_i32m(c_orig_float, c_orig, m*n); \ + cast_i32_to_f32m(c_orig, c_orig_float, m*n); \ +\ + /* cast the kernel buffers into the float buffers to ensure that the values match */ \ + UP_CAST (a, a_float, m*k); \ + UP_CAST (b, b_float, k*n); \ +\ + /* init alpha and beta */ \ + alpha = 1; \ + beta = 1; \ +\ + /* run kernel to get result in c_ans */ \ + memcpy(c_ans, c_orig, m * n * sizeof(int)); \ + kernel( \ + BLIS_NO_TRANSPOSE, \ + BLIS_NO_TRANSPOSE, \ + m, \ + n, \ + k, \ + 
&alpha, \ + a, rsa, csa, \ + b, rsb, csb, \ + &beta, \ + c_ans, rsc, csc \ + ); \ +\ + /* cast integer result into float buffer since float is our std format for correctness checking */ \ + cast_i32_to_f32m(c_ans, c_ans_float, m*n); \ +\ + /* using the BLIS GEMM correctness check method, get the resid */ \ + correctness_checker( \ + m, n, k, \ + a_float, rsa, csa, \ + b_float, rsb, csb, \ + c_orig_float, rsc, csc, \ + c_ans_float, \ + (float) alpha, (float) beta \ + ); \ +\ + free(a); \ + free(b); \ + free(c_ans); \ + free(c_orig); \ + free(a_float); \ + free(b_float); \ + free(c_ans_float); \ + free(c_orig_float); \ +\ +} diff --git a/sandbox/power10/p10_testsuite/performance.c b/sandbox/power10/p10_testsuite/performance.c new file mode 100644 index 0000000000..25f1c3ff2a --- /dev/null +++ b/sandbox/power10/p10_testsuite/performance.c @@ -0,0 +1,103 @@ +/* + + This program is designed to gather the performance data of the POWER10 + GEMM kernels in `blis/sandbox/power10`. + + By default, the performance of the kernels is gather over a set of square + matrices. The perfromance results are reported in GFLOPS, and outputted in + CSV format. + +*/ + +#include "performance.h" +#include "blis.h" +#include "../bli_sandbox.h" +#include "common.h" + +#include +// print kernel name +const char* get_kernel_name(int kernel_id) +{ + switch (kernel_id) + { + case FLOAT16 : return "bli_shgemm"; + case BFLOAT16: return "bli_sbgemm"; + case INT16 : return "bli_i16gemm"; + case INT8 : return "bli_i8gemm"; + case INT4 : return "bli_i4gemm"; + default: printf("INCORRECT KERNEL ID\n"); exit(-1); + } +} + +// create all the performance gathering functions for each kernel +GET_PERF_API_TEMP(sb, bli_sbgemm, bfloat16, float); +GET_PERF_API_TEMP(sh, bli_shgemm, float16, float); +GET_PERF_API_TEMP(i16, bli_i16gemm, int16_t, int); +GET_PERF_API_TEMP(i8, bli_i8gemm, int8_t, int); +GET_PERF_API_TEMP(i4, bli_i4gemm, nibbles, int); + + +// using the DATATYPE enum, gather the performance of the respective GEMM kernel +double run_kernel(int kernel_id, int nreps, int m, int n, int k) +{ + switch (kernel_id) + { + case FLOAT16 : return test_shapi(nreps, m, n, k); + case BFLOAT16: return test_sbapi(nreps, m, n, k); + case INT16 : return test_i16api(nreps, m, n, k); + case INT8 : return test_i8api(nreps, m, n, k); + case INT4 : return test_i4api(nreps, m, n, k); + default: return -1.0; + } +} + +// print the performance data in CSV format +// performance is measured in terms of GFLOPs +void print_perf_data(int m, int n, int k, double best_time) +{ + double GFLOPS = (2.0 * m * n * k) / (1e9 * best_time); + printf("%d, %d, %d, %.2f\n", m, n, k, GFLOPS); +} + +// get performance data +void get_perf(int kernel_id, int nreps, int start, int end, int inc) +{ + // csv header + printf("%s performance\n", get_kernel_name(kernel_id)); + printf("m, n, k, GFLOPS\n"); + + int m,n,k; + + // run over all problem sizes + for (int p=start; p<=end; p+=inc) + { + // change here to adjust problem size + m = p, + n = p, + k = p; + + double best_run_time = run_kernel(kernel_id, nreps, m, n, k); + + print_perf_data(m, n, k, best_run_time); + } +} + +int main(int argc, char *argv[]) +{ + // initialize a square problem set range + int start = 80; + int end = 4000; + int inc = 80; + + // number of times the kernel will be run + int nreps = 5; + + // run a respective kernel + get_perf( FLOAT16, nreps, start, end, inc); + get_perf(BFLOAT16, nreps, start, end, inc); + get_perf( INT16, nreps, start, end, inc); + get_perf( INT8, nreps, start, end, inc); + 
get_perf( INT4, nreps, start, end, inc); + + return 0; +} diff --git a/sandbox/power10/p10_testsuite/performance.h b/sandbox/power10/p10_testsuite/performance.h new file mode 100644 index 0000000000..26c36f6155 --- /dev/null +++ b/sandbox/power10/p10_testsuite/performance.h @@ -0,0 +1,58 @@ + +// function name template +// each function that will gather perform will be named test_api +#define GEN_PERF_FUNC_NAME_(ch) test_ ## ch ## api +#define GEN_PERF_FUNC_NAME(ch) GEN_PERF_FUNC_NAME_(ch) + +/* + Macro template for getting the best GEMM kernel runtime out of `num_runs` + for matrices of size (m x n x k). +*/ +#define GET_PERF_API_TEMP(ch, kernel, input_t, output_t) \ +double GEN_PERF_FUNC_NAME(ch) ( \ + int num_runs, \ + int m, \ + int n, \ + int k \ +) \ +{ \ + input_t *A,*B; \ + output_t *C; \ + output_t alpha,beta; \ +\ + A = (input_t*) malloc(m*k*sizeof(input_t)); \ + B = (input_t*) malloc(n*k*sizeof(input_t)); \ + C = (output_t*) malloc(m*n*sizeof(output_t)); \ + \ + alpha = 1; \ + beta = 1; \ + \ + double best = 1e9; \ + \ + for (int irep=0; irep ${CMAKE_CURRENT_BINARY_DIR}/output.testsuite${dotflavour} + COMMENT "Running test_libblis.x ${printflavour} with output redirected to ${CMAKE_CURRENT_BINARY_DIR}/output.testsuite${dotflavour}" + DEPENDS test_libblis.x ${CMAKE_CURRENT_SOURCE_DIR}/input.general${dotflavour} ${CMAKE_CURRENT_SOURCE_DIR}/input.operations${dotflavour} + BYPRODUCTS ${CMAKE_CURRENT_BINARY_DIR}/output.testsuite${dotflavour} + WORKING_DIRECTORY $ + VERBATIM + ) + # Check the results of the BLIS testsuite. + add_custom_target(checkblis${dashflavour} + COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/build/cmake/check-blistest.py ${CMAKE_CURRENT_BINARY_DIR}/output.testsuite${dotflavour} + DEPENDS testblis${dashflavour} + ) +endfunction() +# Add testing targets using functions above for all input file options. +add_testblis("") +add_testblis("fast") +add_testblis("mixed") +add_testblis("salt") +add_custom_target(checkblis-md DEPENDS checkblis-mixed) +add_custom_target(testblis-md DEPENDS testblis-mixed) +add_custom_target(testsuite DEPENDS testblis) +# Put all those targets under testsuite-targets folder name so that they appear all together in IDE. +set_target_properties(test_libblis.x testblis checkblis testblis-fast checkblis-fast testblis-md checkblis-md testblis-mixed checkblis-mixed testblis-salt checkblis-salt + PROPERTIES FOLDER testsuite-targets) diff --git a/testsuite/src/CMakeLists.txt b/testsuite/src/CMakeLists.txt deleted file mode 100644 index 7180ac1ca6..0000000000 --- a/testsuite/src/CMakeLists.txt +++ /dev/null @@ -1,60 +0,0 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. 
All rights reserved## - -target_sources(test_libblis - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/test_addm.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_addv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_amaxv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_axpbyv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_axpy2v.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_axpyf.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_axpym.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_axpyv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_copym.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_copyv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_dotaxpyv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_dotv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_dotxaxpyf.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_dotxf.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_dotxv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_gemm.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_gemmt.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_ukr.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_gemmtrsm_ukr.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_gemv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_ger.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_hemm.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_hemv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_her.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_her2.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_her2k.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_herk.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_libblis.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_normfm.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_normfv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_randm.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_randv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_scal2m.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_scal2v.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_scalm.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_scalv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_setm.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_setv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_subm.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_subv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_symm.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_symv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_syr.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_syr2.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_syr2k.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_syrk.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_trmm.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_trmm3.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_trmv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_trsm.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_trsm_ukr.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_trsv.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_xpbym.c - ${CMAKE_CURRENT_SOURCE_DIR}/test_xpbyv.c - ) - diff --git a/testsuite/src/test_addm.c b/testsuite/src/test_addm.c index f7c21b733d..64b169cd6d 100644 --- a/testsuite/src/test_addm.c +++ b/testsuite/src/test_addm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_addm.h b/testsuite/src/test_addm.h index 0dbdbfa2ee..edccd19e23 100644 --- a/testsuite/src/test_addm.h +++ b/testsuite/src/test_addm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_addv.c b/testsuite/src/test_addv.c index 9e216ab4d7..a0e2d05b67 100644 --- a/testsuite/src/test_addv.c +++ b/testsuite/src/test_addv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_addv.h b/testsuite/src/test_addv.h index eba5a9220e..bc09bb076c 100644 --- a/testsuite/src/test_addv.h +++ b/testsuite/src/test_addv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_amaxv.c b/testsuite/src/test_amaxv.c index fd6bad5f7f..29369ee63c 100644 --- a/testsuite/src/test_amaxv.c +++ b/testsuite/src/test_amaxv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_amaxv.h b/testsuite/src/test_amaxv.h index 46d87b37f4..3b5a6b6b5d 100644 --- a/testsuite/src/test_amaxv.h +++ b/testsuite/src/test_amaxv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_axpbyv.c b/testsuite/src/test_axpbyv.c index a82ff6e256..398d9134e3 100644 --- a/testsuite/src/test_axpbyv.c +++ b/testsuite/src/test_axpbyv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_axpbyv.h b/testsuite/src/test_axpbyv.h index 9b318dba10..a31d196c0b 100644 --- a/testsuite/src/test_axpbyv.h +++ b/testsuite/src/test_axpbyv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_axpy2v.c b/testsuite/src/test_axpy2v.c index eeebf15e73..e7bc5f6b23 100644 --- a/testsuite/src/test_axpy2v.c +++ b/testsuite/src/test_axpy2v.c @@ -5,7 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_axpy2v.h b/testsuite/src/test_axpy2v.h index c695a643bb..1ed1df1347 100644 --- a/testsuite/src/test_axpy2v.h +++ b/testsuite/src/test_axpy2v.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_axpyf.c b/testsuite/src/test_axpyf.c index 7a85b22123..39085d59cb 100644 --- a/testsuite/src/test_axpyf.c +++ b/testsuite/src/test_axpyf.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_axpyf.h b/testsuite/src/test_axpyf.h index 9dd1dadc29..e936e2c797 100644 --- a/testsuite/src/test_axpyf.h +++ b/testsuite/src/test_axpyf.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_axpym.c b/testsuite/src/test_axpym.c index 222fda33db..dcd51d46e5 100644 --- a/testsuite/src/test_axpym.c +++ b/testsuite/src/test_axpym.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_axpym.h b/testsuite/src/test_axpym.h index 632720284d..429ad4a312 100644 --- a/testsuite/src/test_axpym.h +++ b/testsuite/src/test_axpym.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_axpyv.c b/testsuite/src/test_axpyv.c index 81d4f37706..155299eb8b 100644 --- a/testsuite/src/test_axpyv.c +++ b/testsuite/src/test_axpyv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_axpyv.h b/testsuite/src/test_axpyv.h index c96a9096bb..376913e23e 100644 --- a/testsuite/src/test_axpyv.h +++ b/testsuite/src/test_axpyv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_copym.c b/testsuite/src/test_copym.c index 1aab1d287b..7cd1a9808b 100644 --- a/testsuite/src/test_copym.c +++ b/testsuite/src/test_copym.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_copym.h b/testsuite/src/test_copym.h index 560de0e9a6..a82e781245 100644 --- a/testsuite/src/test_copym.h +++ b/testsuite/src/test_copym.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_copyv.c b/testsuite/src/test_copyv.c index 4350e95ee6..d2ae09c24e 100644 --- a/testsuite/src/test_copyv.c +++ b/testsuite/src/test_copyv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_copyv.h b/testsuite/src/test_copyv.h index 2beb3212d0..0fe3082aa2 100644 --- a/testsuite/src/test_copyv.h +++ b/testsuite/src/test_copyv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_dotaxpyv.c b/testsuite/src/test_dotaxpyv.c index 391c119bbd..e3ac4cc676 100644 --- a/testsuite/src/test_dotaxpyv.c +++ b/testsuite/src/test_dotaxpyv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_dotaxpyv.h b/testsuite/src/test_dotaxpyv.h index ce82227f49..dcbb870eda 100644 --- a/testsuite/src/test_dotaxpyv.h +++ b/testsuite/src/test_dotaxpyv.h @@ -5,7 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_dotv.c b/testsuite/src/test_dotv.c index 347ce9e620..24b0672e68 100644 --- a/testsuite/src/test_dotv.c +++ b/testsuite/src/test_dotv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_dotv.h b/testsuite/src/test_dotv.h index 2f000128b1..b458fb8f42 100644 --- a/testsuite/src/test_dotv.h +++ b/testsuite/src/test_dotv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_dotxaxpyf.c b/testsuite/src/test_dotxaxpyf.c index a2c3ef3e94..cf42fb2e38 100644 --- a/testsuite/src/test_dotxaxpyf.c +++ b/testsuite/src/test_dotxaxpyf.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_dotxaxpyf.h b/testsuite/src/test_dotxaxpyf.h index 6bfcd2655e..b81ddac290 100644 --- a/testsuite/src/test_dotxaxpyf.h +++ b/testsuite/src/test_dotxaxpyf.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_dotxf.c b/testsuite/src/test_dotxf.c index 8a1eca4eba..7045eaf4eb 100644 --- a/testsuite/src/test_dotxf.c +++ b/testsuite/src/test_dotxf.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_dotxf.h b/testsuite/src/test_dotxf.h index 06cac584e3..34264a2df6 100644 --- a/testsuite/src/test_dotxf.h +++ b/testsuite/src/test_dotxf.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_dotxv.c b/testsuite/src/test_dotxv.c index da42e6ae4d..f9fec4f57d 100644 --- a/testsuite/src/test_dotxv.c +++ b/testsuite/src/test_dotxv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_dotxv.h b/testsuite/src/test_dotxv.h index a3e2ca48f2..5e464a61ad 100644 --- a/testsuite/src/test_dotxv.h +++ b/testsuite/src/test_dotxv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_gemm.c b/testsuite/src/test_gemm.c index 0fbf54df36..a4a62d24c4 100644 --- a/testsuite/src/test_gemm.c +++ b/testsuite/src/test_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -405,7 +405,7 @@ void libblis_test_gemm_md } // Estimate the performance of the best experiment repeat. - //*perf = ( 2.0 * m * n * k ) / time_min / FLOPS_PER_UNIT_PERF; + // *perf = ( 2.0 * m * n * k ) / time_min / FLOPS_PER_UNIT_PERF; //if ( bli_obj_is_complex( &c ) ) *perf *= 4.0; *perf = libblis_test_gemm_flops( &a, &b, &c ) / time_min / FLOPS_PER_UNIT_PERF; @@ -438,7 +438,7 @@ void libblis_test_gemm_impl { case BLIS_TEST_SEQ_FRONT_END: bli_gemm( alpha, a, b, beta, c ); - break; + break; default: libblis_test_printf_error( "Invalid interface type.\n" ); diff --git a/testsuite/src/test_gemm.h b/testsuite/src/test_gemm.h index 78364bc249..dc73b59a2e 100644 --- a/testsuite/src/test_gemm.h +++ b/testsuite/src/test_gemm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_gemm_ukr.c b/testsuite/src/test_gemm_ukr.c index 48996f28e7..71824290f8 100644 --- a/testsuite/src/test_gemm_ukr.c +++ b/testsuite/src/test_gemm_ukr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -284,10 +284,10 @@ void libblis_test_gemm_ukr_experiment // allocated so we can re-store it to the object afterward. 
void* buf_ap = bli_obj_buffer( &ap ); void* buf_bp = bli_obj_buffer( &bp ); - bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_PACKED_ROW_PANELS, + bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_GEMM, BLIS_PACKED_ROW_PANELS, BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER, BLIS_MR, BLIS_KR, &a, &ap, cntx ); - bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_PACKED_COL_PANELS, + bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_GEMM, BLIS_PACKED_COL_PANELS, BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER, BLIS_KR, BLIS_NR, &b, &bp, cntx ); bli_obj_set_buffer( buf_ap, &ap ); diff --git a/testsuite/src/test_gemm_ukr.h b/testsuite/src/test_gemm_ukr.h index cd09ef3f69..bdb9e62222 100644 --- a/testsuite/src/test_gemm_ukr.h +++ b/testsuite/src/test_gemm_ukr.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_gemmt.c b/testsuite/src/test_gemmt.c index af61eff6e2..c99bd24944 100644 --- a/testsuite/src/test_gemmt.c +++ b/testsuite/src/test_gemmt.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_gemmt.h b/testsuite/src/test_gemmt.h index 5468bf48dc..dbcb28dc29 100644 --- a/testsuite/src/test_gemmt.h +++ b/testsuite/src/test_gemmt.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_gemmtrsm_ukr.c b/testsuite/src/test_gemmtrsm_ukr.c index 7ce7034453..34e45645d3 100644 --- a/testsuite/src/test_gemmtrsm_ukr.c +++ b/testsuite/src/test_gemmtrsm_ukr.c @@ -209,38 +209,6 @@ void libblis_test_gemmtrsm_ukr_experiment // Query a context. cntx = bli_gks_query_cntx(); - // If TRSM and GEMM have different blocksizes and blocksizes - // are changed in global cntx object, when GEMM and TRSM are - // called in parallel, blocksizes in global cntx object will - // not be correct - // to fix this a local copy of cntx is created, so that - // overriding the blocksizes does not impact the global cntx - // object. - // This is a temporary fix, a better fix is to create a - // separate blocksz_trsm array in cntx. - cntx_t cntx_trsm = *cntx; - -#if defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_ZEN4) - /* Zen4 TRSM Fixme: - * - * TRSM and GEMM used different values of MR and NR, we need to ensure that - * Values used for packing are as per the MR and NR values expected by the kernels - * For now this issue exists only for zen4 hence override the values here if - * the family is BLIS_TRSM and architecture is zen4 - * - * We need to override the values here as well as the packing and compute - * kernels are invoked directly from here (instead of BLIS/BLAS call.) - * - * We need to revisit this when TRSM AVX-512 kernels are implemented. 
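Note on the packing hunks above and below: bli_packm_init_pack() now takes an additional operation-family argument immediately after the diagonal-inversion flag (BLIS_GEMM for the gemm and trsm micro-kernel tests, BLIS_TRSM for the gemmtrsm micro-kernel test), which lets the pack schema pick up family-specific blocksizes. A minimal sketch of the updated calling sequence, assuming the a/ap and b/bp objects and the queried cntx from the test driver above:

    /* Pack A into row panels and B into column panels for a GEMM-family
       operation; the second argument (the operation family) is the new parameter. */
    bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_GEMM, BLIS_PACKED_ROW_PANELS,
                         BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER,
                         BLIS_MR, BLIS_KR, &a, &ap, cntx );
    bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_GEMM, BLIS_PACKED_COL_PANELS,
                         BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER,
                         BLIS_KR, BLIS_NR, &b, &bp, cntx );

    /* The packing itself is still performed by bli_packm_blk_var1(). */
    bli_packm_blk_var1( &a, &ap, cntx, NULL, &BLIS_PACKM_SINGLE_THREADED );
    bli_packm_blk_var1( &b, &bp, cntx, NULL, &BLIS_PACKM_SINGLE_THREADED );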
- */ - if ( (bli_arch_query_id() == BLIS_ARCH_ZEN4) && - ((dc_str[0] == 's') || (dc_str[0] == 'd') || - (dc_str[0] == 'S') || (dc_str[0] == 'D')) ) - { - bli_zen4_override_trsm_blkszs(&cntx_trsm); - } -#endif - // Use the datatype of the first char in the datatype combination string. bli_param_map_char_to_blis_dt( dc_str[0], &datatype ); @@ -248,14 +216,25 @@ void libblis_test_gemmtrsm_ukr_experiment k = libblis_test_get_dim_from_prob_size( op->dim_spec[0], p_cur ); - m = bli_cntx_get_blksz_def_dt( datatype, BLIS_MR, &cntx_trsm ); - n = bli_cntx_get_blksz_def_dt( datatype, BLIS_NR, &cntx_trsm ); + m = bli_cntx_get_trsm_blksz_def_dt( datatype, BLIS_MR, cntx ); + n = bli_cntx_get_trsm_blksz_def_dt( datatype, BLIS_NR, cntx ); // Also query PACKMR and PACKNR as the leading dimensions to ap and bp, // respectively. - ldap = bli_cntx_get_blksz_max_dt( datatype, BLIS_MR, &cntx_trsm ); - ldbp = bli_cntx_get_blksz_max_dt( datatype, BLIS_NR, &cntx_trsm); + ldap = bli_cntx_get_trsm_blksz_max_dt( datatype, BLIS_MR, cntx ); + ldbp = bli_cntx_get_trsm_blksz_max_dt( datatype, BLIS_NR, cntx); + // if trsm block sizes are not set use global block sizes + if( m == 0 || n == 0) + { + m = bli_cntx_get_blksz_def_dt( datatype, BLIS_MR, cntx ); + n = bli_cntx_get_blksz_def_dt( datatype, BLIS_NR, cntx ); + + // Also query PACKMR and PACKNR as the leading dimensions to ap and bp, + // respectively. + ldap = bli_cntx_get_blksz_max_dt( datatype, BLIS_MR, cntx ); + ldbp = bli_cntx_get_blksz_max_dt( datatype, BLIS_NR, cntx); + } // Store the register blocksizes so that the driver can retrieve the // values later when printing results. @@ -372,12 +351,12 @@ void libblis_test_gemmtrsm_ukr_experiment // allocated so we can re-store it to the object afterward. void* buf_ap = bli_obj_buffer( &ap ); void* buf_bp = bli_obj_buffer( &bp ); - bli_packm_init_pack( BLIS_INVERT_DIAG, BLIS_PACKED_ROW_PANELS, + bli_packm_init_pack( BLIS_INVERT_DIAG, BLIS_TRSM, BLIS_PACKED_ROW_PANELS, BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER, - BLIS_MR, BLIS_KR, &a, &ap, &cntx_trsm ); - bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_PACKED_COL_PANELS, + BLIS_MR, BLIS_KR, &a, &ap, cntx ); + bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_TRSM, BLIS_PACKED_COL_PANELS, BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER, - BLIS_KR, BLIS_NR, &b, &bp, &cntx_trsm ); + BLIS_KR, BLIS_NR, &b, &bp, cntx ); bli_obj_set_buffer( buf_ap, &ap ); bli_obj_set_buffer( buf_bp, &bp ); @@ -391,8 +370,8 @@ void libblis_test_gemmtrsm_ukr_experiment bli_obj_set_uplo( uploa, &ap ); // Pack the data from the source objects. - bli_packm_blk_var1( &a, &ap, &cntx_trsm, NULL, &BLIS_PACKM_SINGLE_THREADED ); - bli_packm_blk_var1( &b, &bp, &cntx_trsm, NULL, &BLIS_PACKM_SINGLE_THREADED ); + bli_packm_blk_var1( &a, &ap, cntx, NULL, &BLIS_PACKM_SINGLE_THREADED ); + bli_packm_blk_var1( &b, &bp, cntx, NULL, &BLIS_PACKM_SINGLE_THREADED ); // Create subpartitions from the a and b panels. bli_gemmtrsm_ukr_make_subparts( k, &ap, &bp, @@ -415,13 +394,13 @@ bli_printm( "ap", &ap, "%5.2f", "" ); // Re-pack (restore) the contents of b to bp. 
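The gemmtrsm_ukr hunk above also drops the local cntx_trsm copy (and the Zen4-specific blocksize override) in favor of TRSM-specific blocksize queries that fall back to the global register blocksizes when a configuration does not define them. A compact sketch of that fallback, assuming the datatype and cntx variables already set up by the test driver:

    /* Prefer TRSM-specific register blocksizes; a value of 0 means the
       context does not define them, so fall back to the usual blocksizes. */
    m    = bli_cntx_get_trsm_blksz_def_dt( datatype, BLIS_MR, cntx );
    n    = bli_cntx_get_trsm_blksz_def_dt( datatype, BLIS_NR, cntx );
    ldap = bli_cntx_get_trsm_blksz_max_dt( datatype, BLIS_MR, cntx );
    ldbp = bli_cntx_get_trsm_blksz_max_dt( datatype, BLIS_NR, cntx );

    if ( m == 0 || n == 0 )
    {
        m    = bli_cntx_get_blksz_def_dt( datatype, BLIS_MR, cntx );
        n    = bli_cntx_get_blksz_def_dt( datatype, BLIS_NR, cntx );
        ldap = bli_cntx_get_blksz_max_dt( datatype, BLIS_MR, cntx );
        ldbp = bli_cntx_get_blksz_max_dt( datatype, BLIS_NR, cntx );
    }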
//bli_packm_blk_var1( &b, &bp, &cntx, cntl_b, &BLIS_PACKM_SINGLE_THREADED ); - bli_packm_blk_var1( &b, &bp, &cntx_trsm, NULL, &BLIS_PACKM_SINGLE_THREADED ); + bli_packm_blk_var1( &b, &bp, cntx, NULL, &BLIS_PACKM_SINGLE_THREADED ); time = bli_clock(); libblis_test_gemmtrsm_ukr_impl( iface, side, &alpha, &a1xp, &a11p, &bx1p, &b11p, &c11, - &cntx_trsm ); + cntx ); time_min = bli_clock_min_diff( time_min, time ); } diff --git a/testsuite/src/test_gemmtrsm_ukr.h b/testsuite/src/test_gemmtrsm_ukr.h index 5fd3cc0ba0..bf26bb3938 100644 --- a/testsuite/src/test_gemmtrsm_ukr.h +++ b/testsuite/src/test_gemmtrsm_ukr.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_gemv.c b/testsuite/src/test_gemv.c index e6090e1c5b..721152ef53 100644 --- a/testsuite/src/test_gemv.c +++ b/testsuite/src/test_gemv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_gemv.h b/testsuite/src/test_gemv.h index 8e7284486a..051e951c30 100644 --- a/testsuite/src/test_gemv.h +++ b/testsuite/src/test_gemv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_ger.c b/testsuite/src/test_ger.c index b44fe6ba64..e61af0bcd4 100644 --- a/testsuite/src/test_ger.c +++ b/testsuite/src/test_ger.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_ger.h b/testsuite/src/test_ger.h index 5b75babe60..99e00b5cb2 100644 --- a/testsuite/src/test_ger.h +++ b/testsuite/src/test_ger.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_hemm.c b/testsuite/src/test_hemm.c index 0145dd0dfd..faff475969 100644 --- a/testsuite/src/test_hemm.c +++ b/testsuite/src/test_hemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
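For context on the loop shown above: each repeat restores bp from b, times a single micro-kernel invocation, and keeps the best (minimum) elapsed time via bli_clock()/bli_clock_min_diff(). A minimal sketch of that pattern, with n_repeats standing in for the testsuite's repeat parameter:

    double time_min = DBL_MAX;  /* from <float.h>; the testsuite uses its own initial value */

    for ( unsigned long i = 0; i < n_repeats; ++i )
    {
        /* ... restore any operands clobbered by the previous repeat ... */

        double time = bli_clock();

        /* ... invoke the operation being timed ... */

        time_min = bli_clock_min_diff( time_min, time );
    }

    /* Performance is then estimated as flops / time_min / FLOPS_PER_UNIT_PERF. */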
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_hemm.h b/testsuite/src/test_hemm.h index 7db76afa1e..b84fdda46a 100644 --- a/testsuite/src/test_hemm.h +++ b/testsuite/src/test_hemm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_hemv.c b/testsuite/src/test_hemv.c index 02e205392b..96475d1d69 100644 --- a/testsuite/src/test_hemv.c +++ b/testsuite/src/test_hemv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_hemv.h b/testsuite/src/test_hemv.h index e522690d1e..48f902400b 100644 --- a/testsuite/src/test_hemv.h +++ b/testsuite/src/test_hemv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_her.c b/testsuite/src/test_her.c index c122f6ce56..d7853913cf 100644 --- a/testsuite/src/test_her.c +++ b/testsuite/src/test_her.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_her.h b/testsuite/src/test_her.h index a6aaa55b47..2033f3dc1f 100644 --- a/testsuite/src/test_her.h +++ b/testsuite/src/test_her.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_her2.c b/testsuite/src/test_her2.c index 1ed6b3bb9e..3fb89b1fd7 100644 --- a/testsuite/src/test_her2.c +++ b/testsuite/src/test_her2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_her2.h b/testsuite/src/test_her2.h index c2711cfb11..9fd037ec62 100644 --- a/testsuite/src/test_her2.h +++ b/testsuite/src/test_her2.h @@ -5,7 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_her2k.c b/testsuite/src/test_her2k.c index 0158e25a25..8c4b69d4cb 100644 --- a/testsuite/src/test_her2k.c +++ b/testsuite/src/test_her2k.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_her2k.h b/testsuite/src/test_her2k.h index a481dac720..4d03afa77e 100644 --- a/testsuite/src/test_her2k.h +++ b/testsuite/src/test_her2k.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_herk.c b/testsuite/src/test_herk.c index abe4e70b10..0c9a8eb437 100644 --- a/testsuite/src/test_herk.c +++ b/testsuite/src/test_herk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_herk.h b/testsuite/src/test_herk.h index 1702bd8b9b..24bcceacae 100644 --- a/testsuite/src/test_herk.h +++ b/testsuite/src/test_herk.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_libblis.c b/testsuite/src/test_libblis.c index b904094267..566701bfcc 100644 --- a/testsuite/src/test_libblis.c +++ b/testsuite/src/test_libblis.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -121,6 +121,8 @@ void* libblis_test_thread_entry( void* tdata_void ) void libblis_test_thread_decorator( test_params_t* params, test_ops_t* ops ) { + err_t r_val; + // Query the total number of threads to simulate. 
size_t nt = ( size_t )params->n_app_threads; @@ -130,12 +132,12 @@ void libblis_test_thread_decorator( test_params_t* params, test_ops_t* ops ) #ifdef BLIS_ENABLE_MEM_TRACING printf( "libblis_test_thread_decorator(): " ); #endif - bli_pthread_t* pthread = bli_malloc_user( sizeof( bli_pthread_t ) * nt ); + bli_pthread_t* pthread = bli_malloc_user( sizeof( bli_pthread_t ) * nt, &r_val ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "libblis_test_thread_decorator(): " ); #endif - thread_data_t* tdata = bli_malloc_user( sizeof( thread_data_t ) * nt ); + thread_data_t* tdata = bli_malloc_user( sizeof( thread_data_t ) * nt, &r_val ); // Allocate a mutex for the threads to share. //bli_pthread_mutex_t* mutex = bli_malloc_user( sizeof( bli_pthread_mutex_t ) ); @@ -145,7 +147,7 @@ void libblis_test_thread_decorator( test_params_t* params, test_ops_t* ops ) #ifdef BLIS_ENABLE_MEM_TRACING printf( "libblis_test_thread_decorator(): " ); #endif - bli_pthread_barrier_t* barrier = bli_malloc_user( sizeof( bli_pthread_barrier_t ) ); + bli_pthread_barrier_t* barrier = bli_malloc_user( sizeof( bli_pthread_barrier_t ), &r_val ); // Initialize the mutex. //bli_pthread_mutex_init( mutex, NULL ); @@ -2037,6 +2039,7 @@ void libblis_test_op_driver bli_abort(); #endif + free(chars_for_dt); } else // ( ( !mixed_domain && !mixed_precision ) || op->opid != BLIS_GEMM ) { diff --git a/testsuite/src/test_libblis.h b/testsuite/src/test_libblis.h index 786f82b308..80ca67c669 100644 --- a/testsuite/src/test_libblis.h +++ b/testsuite/src/test_libblis.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_normfm.c b/testsuite/src/test_normfm.c index c4b9a0105e..c5ebd61200 100644 --- a/testsuite/src/test_normfm.c +++ b/testsuite/src/test_normfm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_normfm.h b/testsuite/src/test_normfm.h index a24b5e5ba2..f0b25876d1 100644 --- a/testsuite/src/test_normfm.h +++ b/testsuite/src/test_normfm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_normfv.c b/testsuite/src/test_normfv.c index 3bcce35af4..8b01e65ea3 100644 --- a/testsuite/src/test_normfv.c +++ b/testsuite/src/test_normfv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
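The test_libblis.c hunk above tracks a signature change: bli_malloc_user() now takes an err_t* out-parameter alongside the allocation size (the driver simply threads &r_val through; the later hunk also frees chars_for_dt to plug a small leak in the mixed-domain path). A short sketch of the updated allocation call, with an error check added purely for illustration:

    err_t r_val;

    bli_pthread_t* pthread = bli_malloc_user( sizeof( bli_pthread_t ) * nt, &r_val );

    if ( r_val != BLIS_SUCCESS || pthread == NULL )
    {
        libblis_test_printf_error( "thread array allocation failed.\n" );
    }

    /* ... use the buffer, then release it with the matching free routine ... */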
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_normfv.h b/testsuite/src/test_normfv.h index afa5350063..f15887558c 100644 --- a/testsuite/src/test_normfv.h +++ b/testsuite/src/test_normfv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_randm.c b/testsuite/src/test_randm.c index 223007dba9..50c93a5e68 100644 --- a/testsuite/src/test_randm.c +++ b/testsuite/src/test_randm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_randm.h b/testsuite/src/test_randm.h index e444649629..97db920602 100644 --- a/testsuite/src/test_randm.h +++ b/testsuite/src/test_randm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_randv.c b/testsuite/src/test_randv.c index 951c8c3eca..a74a565a7f 100644 --- a/testsuite/src/test_randv.c +++ b/testsuite/src/test_randv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_randv.h b/testsuite/src/test_randv.h index bb658dfd7c..202829f3a1 100644 --- a/testsuite/src/test_randv.h +++ b/testsuite/src/test_randv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_scal2m.c b/testsuite/src/test_scal2m.c index e8440fc46d..69a1a0a895 100644 --- a/testsuite/src/test_scal2m.c +++ b/testsuite/src/test_scal2m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_scal2m.h b/testsuite/src/test_scal2m.h index 262723f4e7..9937adac56 100644 --- a/testsuite/src/test_scal2m.h +++ b/testsuite/src/test_scal2m.h @@ -5,7 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_scal2v.c b/testsuite/src/test_scal2v.c index c200e13fcb..77c68ec3fc 100644 --- a/testsuite/src/test_scal2v.c +++ b/testsuite/src/test_scal2v.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_scal2v.h b/testsuite/src/test_scal2v.h index 75b5cfe4a6..a8baca33fd 100644 --- a/testsuite/src/test_scal2v.h +++ b/testsuite/src/test_scal2v.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_scalm.c b/testsuite/src/test_scalm.c index 6219c71df4..2ea9a4028f 100644 --- a/testsuite/src/test_scalm.c +++ b/testsuite/src/test_scalm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_scalm.h b/testsuite/src/test_scalm.h index 3b98617b29..6897ada6be 100644 --- a/testsuite/src/test_scalm.h +++ b/testsuite/src/test_scalm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_scalv.c b/testsuite/src/test_scalv.c index 142b5e410b..ff27f76ebd 100644 --- a/testsuite/src/test_scalv.c +++ b/testsuite/src/test_scalv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_scalv.h b/testsuite/src/test_scalv.h index 144b416759..89f63aace2 100644 --- a/testsuite/src/test_scalv.h +++ b/testsuite/src/test_scalv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_setm.c b/testsuite/src/test_setm.c index 80cebd64e0..4a039f2289 100644 --- a/testsuite/src/test_setm.c +++ b/testsuite/src/test_setm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_setm.h b/testsuite/src/test_setm.h index 0271840312..0b056cac52 100644 --- a/testsuite/src/test_setm.h +++ b/testsuite/src/test_setm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_setv.c b/testsuite/src/test_setv.c index 10f0348c75..a8444b3a5c 100644 --- a/testsuite/src/test_setv.c +++ b/testsuite/src/test_setv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_setv.h b/testsuite/src/test_setv.h index 4e02d489e5..b225d9140a 100644 --- a/testsuite/src/test_setv.h +++ b/testsuite/src/test_setv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_subm.c b/testsuite/src/test_subm.c index 63b48eedcf..35f67c80cd 100644 --- a/testsuite/src/test_subm.c +++ b/testsuite/src/test_subm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_subm.h b/testsuite/src/test_subm.h index e39eff8282..d5074c3c33 100644 --- a/testsuite/src/test_subm.h +++ b/testsuite/src/test_subm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_subv.c b/testsuite/src/test_subv.c index 3a48f02a46..b0b6715637 100644 --- a/testsuite/src/test_subv.c +++ b/testsuite/src/test_subv.c @@ -5,7 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_subv.h b/testsuite/src/test_subv.h index 5dbe465898..f4d83404f4 100644 --- a/testsuite/src/test_subv.h +++ b/testsuite/src/test_subv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_symm.c b/testsuite/src/test_symm.c index 2ac7b41068..50d6315bd2 100644 --- a/testsuite/src/test_symm.c +++ b/testsuite/src/test_symm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_symm.h b/testsuite/src/test_symm.h index bf50bf65d7..c6831ed0ed 100644 --- a/testsuite/src/test_symm.h +++ b/testsuite/src/test_symm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_symv.c b/testsuite/src/test_symv.c index 5ae5f30be0..9c00ec082a 100644 --- a/testsuite/src/test_symv.c +++ b/testsuite/src/test_symv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_symv.h b/testsuite/src/test_symv.h index 5dba0624ca..305ff95469 100644 --- a/testsuite/src/test_symv.h +++ b/testsuite/src/test_symv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_syr.c b/testsuite/src/test_syr.c index 69376b9708..344bf0f159 100644 --- a/testsuite/src/test_syr.c +++ b/testsuite/src/test_syr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_syr.h b/testsuite/src/test_syr.h index 455e18ff1d..8d9433cf28 100644 --- a/testsuite/src/test_syr.h +++ b/testsuite/src/test_syr.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_syr2.c b/testsuite/src/test_syr2.c index 42d65c00e4..3daa7efe06 100644 --- a/testsuite/src/test_syr2.c +++ b/testsuite/src/test_syr2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_syr2.h b/testsuite/src/test_syr2.h index d6c1f3c104..d614366d72 100644 --- a/testsuite/src/test_syr2.h +++ b/testsuite/src/test_syr2.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_syr2k.c b/testsuite/src/test_syr2k.c index 4d83bb88c8..7c86bd46fc 100644 --- a/testsuite/src/test_syr2k.c +++ b/testsuite/src/test_syr2k.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_syr2k.h b/testsuite/src/test_syr2k.h index edf893c291..5514e1c3aa 100644 --- a/testsuite/src/test_syr2k.h +++ b/testsuite/src/test_syr2k.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_syrk.c b/testsuite/src/test_syrk.c index 65d978bb03..e54471edec 100644 --- a/testsuite/src/test_syrk.c +++ b/testsuite/src/test_syrk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_syrk.h b/testsuite/src/test_syrk.h index 8cad724566..a4b0608423 100644 --- a/testsuite/src/test_syrk.h +++ b/testsuite/src/test_syrk.h @@ -5,7 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trmm.c b/testsuite/src/test_trmm.c index a1decd37c9..24f00dc5b2 100644 --- a/testsuite/src/test_trmm.c +++ b/testsuite/src/test_trmm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trmm.h b/testsuite/src/test_trmm.h index a84ca1d296..d89673d0b0 100644 --- a/testsuite/src/test_trmm.h +++ b/testsuite/src/test_trmm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trmm3.c b/testsuite/src/test_trmm3.c index 17ba2190b9..7c1789eb30 100644 --- a/testsuite/src/test_trmm3.c +++ b/testsuite/src/test_trmm3.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trmm3.h b/testsuite/src/test_trmm3.h index ee9490036c..666577a74d 100644 --- a/testsuite/src/test_trmm3.h +++ b/testsuite/src/test_trmm3.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trmv.c b/testsuite/src/test_trmv.c index 71acc90ba0..76412d9d89 100644 --- a/testsuite/src/test_trmv.c +++ b/testsuite/src/test_trmv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trmv.h b/testsuite/src/test_trmv.h index 1fae8331ff..d579be2e33 100644 --- a/testsuite/src/test_trmv.h +++ b/testsuite/src/test_trmv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trsm.c b/testsuite/src/test_trsm.c index fa0d8e7c30..8e9346480f 100644 --- a/testsuite/src/test_trsm.c +++ b/testsuite/src/test_trsm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trsm.h b/testsuite/src/test_trsm.h index ee23b2c7a1..c1fdb31c58 100644 --- a/testsuite/src/test_trsm.h +++ b/testsuite/src/test_trsm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trsm_ukr.c b/testsuite/src/test_trsm_ukr.c index 6366e5fc3c..2060ab1847 100644 --- a/testsuite/src/test_trsm_ukr.c +++ b/testsuite/src/test_trsm_ukr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -285,10 +285,12 @@ void libblis_test_trsm_ukr_experiment // allocated so we can re-store it to the object afterward. void* buf_ap = bli_obj_buffer( &ap ); void* buf_bp = bli_obj_buffer( &bp ); - bli_packm_init_pack( BLIS_INVERT_DIAG, BLIS_PACKED_ROW_PANELS, + // trsm_ukr are derived from gemm kernels therefore packing is done with + // gemm blocksizes + bli_packm_init_pack( BLIS_INVERT_DIAG, BLIS_GEMM, BLIS_PACKED_ROW_PANELS, BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER, BLIS_MR, BLIS_KR, &a, &ap, cntx ); - bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_PACKED_COL_PANELS, + bli_packm_init_pack( BLIS_NO_INVERT_DIAG, BLIS_GEMM, BLIS_PACKED_COL_PANELS, BLIS_PACK_FWD_IF_UPPER, BLIS_PACK_FWD_IF_LOWER, BLIS_KR, BLIS_NR, &b, &bp, cntx ); bli_obj_set_buffer( buf_ap, &ap ); diff --git a/testsuite/src/test_trsm_ukr.h b/testsuite/src/test_trsm_ukr.h index 22c6676368..2efaf2bae6 100644 --- a/testsuite/src/test_trsm_ukr.h +++ b/testsuite/src/test_trsm_ukr.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trsv.c b/testsuite/src/test_trsv.c index 12543cd9a0..f7c3c2272a 100644 --- a/testsuite/src/test_trsv.c +++ b/testsuite/src/test_trsv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_trsv.h b/testsuite/src/test_trsv.h index 5f5fa4eb0f..3dab55de32 100644 --- a/testsuite/src/test_trsv.h +++ b/testsuite/src/test_trsv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_xpbyv.c b/testsuite/src/test_xpbyv.c index 197de86e71..e001672038 100644 --- a/testsuite/src/test_xpbyv.c +++ b/testsuite/src/test_xpbyv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_xpbyv.h b/testsuite/src/test_xpbyv.h index 16eb772164..4474cdb688 100644 --- a/testsuite/src/test_xpbyv.h +++ b/testsuite/src/test_xpbyv.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/travis/cxx/Makefile b/travis/cxx/Makefile new file mode 100644 index 0000000000..0f8da14e3b --- /dev/null +++ b/travis/cxx/Makefile @@ -0,0 +1,38 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2021, Southern Methodist University +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +# + +.PHONY: all cxx-test + +all: cxx-test + $(CXX) -std=c++0x -o $(BUILD_DIR)/cxx-test.x -I$(INCLUDE_DIR) cxx-test.cxx -L$(LIB_DIR) -lblis diff --git a/travis/cxx/cxx-test.cxx b/travis/cxx/cxx-test.cxx new file mode 100644 index 0000000000..bccbd9e430 --- /dev/null +++ b/travis/cxx/cxx-test.cxx @@ -0,0 +1,50 @@ +// +// +// BLIS +// An object-based framework for developing high-performance BLAS-like +// libraries. +// +// Copyright (C) 2021, Southern Methodist University +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// - Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// - Neither the name(s) of the copyright holder(s) nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// + +#include + +#include "blis.h" + +int main() +{ + const int N = 5; + std::vector A(N*N), B(N*N), C(N*N); + scomplex one{1.0, 0.0}; + scomplex zero{0.0, 0.0}; + + bli_cgemm(BLIS_NO_TRANSPOSE, BLIS_NO_TRANSPOSE, N, N, N, + &one, A.data(), 1, N, + B.data(), 1, N, + &zero, C.data(), 1, N); +} diff --git a/travis/cxx/cxx-test.sh b/travis/cxx/cxx-test.sh new file mode 100755 index 0000000000..c0036611f4 --- /dev/null +++ b/travis/cxx/cxx-test.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2021, Southern Methodist University +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
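Note on the cxx-test.cxx hunk above: the angle-bracketed template arguments (the standard header being included and the std::vector element type) appear to have been lost in rendering. A self-contained reading of the same smoke test, assuming the vectors are meant to hold BLIS scomplex values, is sketched below; its only purpose is to verify that C++ code can include blis.h and link against the typed API.

    // Hypothetical reconstruction; <vector> and <scomplex> are assumptions
    // inferred from the surrounding code and the bli_cgemm call.
    #include <vector>

    #include "blis.h"

    int main()
    {
        const int N = 5;
        std::vector<scomplex> A(N*N), B(N*N), C(N*N);
        scomplex one{1.0, 0.0};
        scomplex zero{0.0, 0.0};

        // C := one * A * B + zero * C, column-major with unit row stride.
        bli_cgemm(BLIS_NO_TRANSPOSE, BLIS_NO_TRANSPOSE, N, N, N,
                  &one, A.data(), 1, N,
                        B.data(), 1, N,
                  &zero, C.data(), 1, N);
    }

The accompanying Makefile compiles it with the CXX command shown above, and cxx-test.sh (continued below) supplies the include and library paths from the build tree.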
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# + +SOURCE_DIR=$1 +CONFIG=$2 + +if [ -z $SOURCE_DIR ] || [ -z $CONFIG ]; then + echo "usage: cxx-test.sh " + exit 1 +fi + +BUILD_DIR=$(pwd) +INCLUDE_DIR=$BUILD_DIR/include/$CONFIG +LIB_DIR=$BUILD_DIR/lib/$CONFIG + +if [ ! -e $INCLUDE_DIR/blis.h ]; then + echo "could not find blis.h" + exit 1 +fi + +if [ ! -e $SOURCE_DIR/travis/cxx/Makefile ]; then + echo "could not find cxx-test Makefile" + exit 1 +fi + +make -C $SOURCE_DIR/travis/cxx INCLUDE_DIR=$INCLUDE_DIR LIB_DIR=$LIB_DIR BUILD_DIR=$BUILD_DIR diff --git a/travis/do_testsuite.sh b/travis/do_testsuite.sh index bb176b6819..6778f81d85 100755 --- a/travis/do_testsuite.sh +++ b/travis/do_testsuite.sh @@ -8,19 +8,28 @@ export BLIS_IC_NT=2 export BLIS_JR_NT=1 export BLIS_IR_NT=1 -if [ "$TEST" = "FAST" ]; then +if [ "$TEST" = "FAST" -o "$TEST" = "ALL" ]; then make testblis-fast -elif [ "$TEST" = "MD" ]; then + $DIST_PATH/testsuite/check-blistest.sh ./output.testsuite +fi + +if [ "$TEST" = "MD" -o "$TEST" = "ALL" ]; then make testblis-md -elif [ "$TEST" = "SALT" ]; then + $DIST_PATH/testsuite/check-blistest.sh ./output.testsuite +fi + +if [ "$TEST" = "SALT" -o "$TEST" = "ALL" ]; then # Disable multithreading within BLIS. export BLIS_JC_NT=1 BLIS_IC_NT=1 BLIS_JR_NT=1 BLIS_IR_NT=1 make testblis-salt -else + $DIST_PATH/testsuite/check-blistest.sh ./output.testsuite +fi + +if [ "$TEST" = "1" -o "$TEST" = "ALL" ]; then make testblis + $DIST_PATH/testsuite/check-blistest.sh ./output.testsuite fi -$DIST_PATH/testsuite/check-blistest.sh ./output.testsuite make testblas $DIST_PATH/blastest/check-blastest.sh diff --git a/vendor/testcpp/CMakeLists.txt b/vendor/testcpp/CMakeLists.txt index 54bb8d2cb7..4e29b747ea 100644 --- a/vendor/testcpp/CMakeLists.txt +++ b/vendor/testcpp/CMakeLists.txt @@ -1,124 +1,70 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. 
All rights reserved.## - -include_directories(${CMAKE_CURRENT_SOURCE_DIR}) -include_directories(${CMAKE_SOURCE_DIR}/cpp) - -add_executable(test_asum_blis test_asum.cc) -target_link_libraries(test_asum_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_axpy_blis test_axpy.cc) -target_link_libraries(test_axpy_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_copy_blis test_copy.cc) -target_link_libraries(test_copy_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_dot_blis test_dot.cc) -target_link_libraries(test_dot_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_dotc_blis test_dotc.cc) -target_link_libraries(test_dotc_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_gbmv_blis test_gbmv.cc) -target_link_libraries(test_gbmv_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_gemm_blis test_gemm.cc) -target_link_libraries(test_gemm_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_gemv_blis test_gemv.cc) -target_link_libraries(test_gemv_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_ger_blis test_ger.cc) -target_link_libraries(test_ger_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_gerc_blis test_gerc.cc) -target_link_libraries(test_gerc_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_geru_blis test_geru.cc) -target_link_libraries(test_geru_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_hemm_blis test_hemm.cc) -target_link_libraries(test_hemm_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_hemv_blis test_hemv.cc) -target_link_libraries(test_hemv_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_her2_blis test_her2.cc) -target_link_libraries(test_her2_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_her_blis test_her.cc) -target_link_libraries(test_her_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_herk_blis test_herk.cc) -target_link_libraries(test_herk_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_hpr2_blis test_hpr2.cc) -target_link_libraries(test_hpr2_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_hpr_blis test_hpr.cc) -target_link_libraries(test_hpr_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_nrm2_blis test_nrm2.cc) -target_link_libraries(test_nrm2_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_rot_blis test_rot.cc) -target_link_libraries(test_rot_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_rotg_blis test_rotg.cc) -target_link_libraries(test_rotg_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_rotm_blis test_rotm.cc) -target_link_libraries(test_rotm_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_rotmg_blis test_rotmg.cc) -target_link_libraries(test_rotmg_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_scal_blis test_scal.cc) -target_link_libraries(test_scal_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_sdsdot_blis test_sdsdot.cc) -target_link_libraries(test_sdsdot_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_spr2_blis test_spr2.cc) -target_link_libraries(test_spr2_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_spr_blis test_spr.cc) -target_link_libraries(test_spr_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_swap_blis test_swap.cc) -target_link_libraries(test_swap_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_symm_blis test_symm.cc) -target_link_libraries(test_symm_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_syr2_blis test_syr2.cc) -target_link_libraries(test_syr2_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_syr2k_blis test_syr2k.cc) 
-target_link_libraries(test_syr2k_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_syr_blis test_syr.cc) -target_link_libraries(test_syr_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_syrk_blis test_syrk.cc) -target_link_libraries(test_syrk_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_tbmv_blis test_tbmv.cc) -target_link_libraries(test_tbmv_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_tbsv_blis test_tbsv.cc) -target_link_libraries(test_tbsv_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_tpmv_blis test_tpmv.cc) -target_link_libraries(test_tpmv_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_tpsv_blis test_tpsv.cc) -target_link_libraries(test_tpsv_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_trmm_blis test_trmm.cc) -target_link_libraries(test_trmm_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_trsm_blis test_trsm.cc) -target_link_libraries(test_trsm_blis PRIVATE "${LIB_NAME}.lib" ) - -add_executable(test_trsv_blis test_trsv.cc) -target_link_libraries(test_trsv_blis PRIVATE "${LIB_NAME}.lib" ) +##Copyright (C) 2020 - 2023, Advanced Micro Devices, Inc. All rights reserved.## + +# Comments: +# - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. +# - We must use recursively expanded assignment for LIB_PATH and INC_PATH in +# the second case because CONFIG_NAME is not yet set. +if(NOT DEFINED BLIS_INSTALL_PATH) + set(DIST_PATH ${CMAKE_BINARY_DIR}) + set(LIB_PATH ${DIST_PATH}/lib/${BLIS_CONFIG_FAMILY}) + set(INC_PATH ${DIST_PATH}/include/${BLIS_CONFIG_FAMILY}) +else() + set(LIB_PATH ${BLIS_INSTALL_PATH}/lib) + set(INC_PATH ${BLIS_INSTALL_PATH}/include/blis) +endif() + +# Include the corresponding make_defs.cmake that holds the required compiler options. +include(${CMAKE_SOURCE_DIR}/config/${BLIS_CONFIG_FAMILY}/make_defs.cmake) + +# Gather all local source files. +file(GLOB testcpp_sources LIST_DIRECTORIES false ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) +list(TRANSFORM testcpp_sources REPLACE ${CMAKE_CURRENT_SOURCE_DIR}/ "") + +# Override the value of CINCFLAGS so that the value of CFLAGS returned by +# get-user-cflags-for() is not cluttered up with include paths needed only +# while building BLIS. +set(CINFLAGS ${INC_PATH}) + +# Create one executable for each of the sources. +foreach(source ${testcpp_sources}) + string(REPLACE .cc "" exec_name ${source}) + string(APPEND exec_name "_blis") + add_executable(${exec_name} ${source}) + target_compile_options(${exec_name} + PRIVATE + # load-var-for,COPTFLAGS + ${COPTFLAGS} + # get-noopt-cflags-for + ${CDBGFLAGS} + ${CWARNFLAGS} + ${CPICFLAGS} + ${CMISCFLAGS} + ${CXXLANGFLAGS} + + ) + target_include_directories(${exec_name} + BEFORE + PRIVATE + # in get-noopt-cflags-for + ${CINFLAGS} + # Add local header paths + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/vendor/cpp + ) + target_link_libraries(${exec_name} PRIVATE ${LDFLAGS} libblis) + if(THREADING_MODEL STREQUAL "openmp") + target_link_libraries(${exec_name} PRIVATE OpenMP::OpenMP_C) + endif() + set_target_properties(${exec_name} PROPERTIES CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + # Put all those targets under vendor-testcpp-targets folder name so that they appear all together in IDE. + set_target_properties(${exec_name} PROPERTIES FOLDER vendor-testcpp-targets) + add_custom_target(${exec_name}.x + COMMAND ${exec_name}) + # Put all those targets under vendor-testcpp-targets folder name so that they appear all together in IDE. 
+ set_target_properties(${exec_name}.x PROPERTIES FOLDER vendor-testcpp-targets) + list(APPEND test_executables "${exec_name}.x") +endforeach() + +add_custom_target(checkbliscpp DEPENDS ${test_executables}) +# Put all those targets under vendor-testcpp-targets folder name so that they appear all together in IDE. +set_target_properties(checkbliscpp PROPERTIES FOLDER vendor-testcpp-targets) diff --git a/vendor/testcpp/Makefile b/vendor/testcpp/Makefile index 9a5a466f59..36b2726a2e 100644 --- a/vendor/testcpp/Makefile +++ b/vendor/testcpp/Makefile @@ -3,7 +3,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2017 - 2021, Advanced Micro Devices, Inc. +# Copyright (C) 2017 - 2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -142,8 +142,7 @@ LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) # all: blis - -blis: test_asum_blis.x \ +CPPEXES := test_asum_blis.x \ test_axpy_blis.x \ test_copy_blis.x \ test_dot_blis.x \ @@ -183,8 +182,10 @@ blis: test_asum_blis.x \ test_trmm_blis.x \ test_trsm_blis.x \ test_trsv_blis.x - +CPPEXES := $(addprefix $(MK_USE_LIB)/,$(CPPEXES)) + +blis: $(CPPEXES) # --Object file rules -- @@ -197,7 +198,8 @@ test_%_blis.o: test_%.cc # -- Executable file rules -- -test_%_blis.x: test_%_blis.o $(LIBBLIS_LINK) +$(MK_USE_LIB)/test_%_blis.x: test_%_blis.o $(LIBBLIS_LINK) + @mkdir -p ./$(MK_USE_LIB) @$(LINKER) $^ $(LIBBLIS_LINK) $(LDFLAGS) -o $@ ./$@ @@ -206,5 +208,5 @@ test_%_blis.x: test_%_blis.o $(LIBBLIS_LINK) clean: cleanx cleanx: - - $(RM_F) *.o *.x + - $(RM_F) ./*.o ./{shared,static}/*.x diff --git a/vendor/testcpp/test.hh b/vendor/testcpp/test.hh index b1be412d64..ccd5804332 100644 --- a/vendor/testcpp/test.hh +++ b/vendor/testcpp/test.hh @@ -138,7 +138,7 @@ int computeErrorM( for ( i = 0; i < m; i ++ ) { for ( j = 0; j < n; j ++ ) { if ( (fabs (A( i, j )) - fabs( A_ref( i, j ))) > 0.0000001 ) { - cout << A(i,j) << A_ref(i,j); + cout << A(i,j) << A_ref(i,j)<< "\n"; ret = 1; break; } diff --git a/vendor/testcpp/test_asum.cc b/vendor/testcpp/test_asum.cc index 948f4250fd..90fb98ff54 100644 --- a/vendor/testcpp/test_asum.cc +++ b/vendor/testcpp/test_asum.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP asum routine and reference blis asum routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_axpy.cc b/vendor/testcpp/test_axpy.cc index 45035198c3..a0795a6692 100644 --- a/vendor/testcpp/test_axpy.cc +++ b/vendor/testcpp/test_axpy.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_copy.cc b/vendor/testcpp/test_copy.cc index a1042d1c9b..234fe6ac39 100644 --- a/vendor/testcpp/test_copy.cc +++ b/vendor/testcpp/test_copy.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. 
+ Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_dot.cc b/vendor/testcpp/test_dot.cc index 553287784a..668d0d6861 100644 --- a/vendor/testcpp/test_dot.cc +++ b/vendor/testcpp/test_dot.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_dotc.cc b/vendor/testcpp/test_dotc.cc index 88ffe19c4d..bac72c3026 100644 --- a/vendor/testcpp/test_dotc.cc +++ b/vendor/testcpp/test_dotc.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP dotc routine and reference blis dotc routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_gbmv.cc b/vendor/testcpp/test_gbmv.cc index 6d64f42ee3..bf76f6da60 100644 --- a/vendor/testcpp/test_gbmv.cc +++ b/vendor/testcpp/test_gbmv.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_gemm.cc b/vendor/testcpp/test_gemm.cc index 2fe6e55a7c..b7a3a14a4a 100644 --- a/vendor/testcpp/test_gemm.cc +++ b/vendor/testcpp/test_gemm.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_gemv.cc b/vendor/testcpp/test_gemv.cc index ca36a61d29..36bb8e0111 100644 --- a/vendor/testcpp/test_gemv.cc +++ b/vendor/testcpp/test_gemv.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_ger.cc b/vendor/testcpp/test_ger.cc index 15b018ce60..1512d331e8 100644 --- a/vendor/testcpp/test_ger.cc +++ b/vendor/testcpp/test_ger.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_gerc.cc b/vendor/testcpp/test_gerc.cc index 332405b7c1..f6ac448d59 100644 --- a/vendor/testcpp/test_gerc.cc +++ b/vendor/testcpp/test_gerc.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_geru.cc b/vendor/testcpp/test_geru.cc index 03e3e6a271..a439d6c082 100644 --- a/vendor/testcpp/test_geru.cc +++ b/vendor/testcpp/test_geru.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_hemm.cc b/vendor/testcpp/test_hemm.cc index 8b88bcad35..3c4e54b930 100644 --- a/vendor/testcpp/test_hemm.cc +++ b/vendor/testcpp/test_hemm.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_hemv.cc b/vendor/testcpp/test_hemv.cc index 463fdf557f..804f040428 100644 --- a/vendor/testcpp/test_hemv.cc +++ b/vendor/testcpp/test_hemv.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_her.cc b/vendor/testcpp/test_her.cc index 687d1e90d8..57c3d08215 100644 --- a/vendor/testcpp/test_her.cc +++ b/vendor/testcpp/test_her.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_her2.cc b/vendor/testcpp/test_her2.cc index 2f3ca253ac..02e673c18f 100644 --- a/vendor/testcpp/test_her2.cc +++ b/vendor/testcpp/test_her2.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_herk.cc b/vendor/testcpp/test_herk.cc index 3febf3e6f1..13d2afc28f 100644 --- a/vendor/testcpp/test_herk.cc +++ b/vendor/testcpp/test_herk.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP herk routine and reference blis herk routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_hpr.cc b/vendor/testcpp/test_hpr.cc index dfc7bdd4a9..225a111b97 100644 --- a/vendor/testcpp/test_hpr.cc +++ b/vendor/testcpp/test_hpr.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_hpr2.cc b/vendor/testcpp/test_hpr2.cc index 1b8b9b2b4f..ec05ae3578 100644 --- a/vendor/testcpp/test_hpr2.cc +++ b/vendor/testcpp/test_hpr2.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_nrm2.cc b/vendor/testcpp/test_nrm2.cc index 24b96c94f2..48e46ff2fe 100644 --- a/vendor/testcpp/test_nrm2.cc +++ b/vendor/testcpp/test_nrm2.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP nrm2 routine and reference blis nrm2 routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_rot.cc b/vendor/testcpp/test_rot.cc index 8849dccb11..3e2d14a564 100644 --- a/vendor/testcpp/test_rot.cc +++ b/vendor/testcpp/test_rot.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP rot routine and reference blis rot routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_rotg.cc b/vendor/testcpp/test_rotg.cc index a99ef8c781..898e3f8d7b 100644 --- a/vendor/testcpp/test_rotg.cc +++ b/vendor/testcpp/test_rotg.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP rotg routine and reference blis rotg routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_rotm.cc b/vendor/testcpp/test_rotm.cc index 9ff793e500..2069633855 100644 --- a/vendor/testcpp/test_rotm.cc +++ b/vendor/testcpp/test_rotm.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP rotm routine and reference blis rotm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_rotmg.cc b/vendor/testcpp/test_rotmg.cc index a81119b7dc..f483d670d7 100644 --- a/vendor/testcpp/test_rotmg.cc +++ b/vendor/testcpp/test_rotmg.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP rotmg routine and reference blis rotmg routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_scal.cc b/vendor/testcpp/test_scal.cc index 82b2821a66..aa9d3a1223 100644 --- a/vendor/testcpp/test_scal.cc +++ b/vendor/testcpp/test_scal.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_sdsdot.cc b/vendor/testcpp/test_sdsdot.cc index c903c97d33..295a5d42e2 100644 --- a/vendor/testcpp/test_sdsdot.cc +++ b/vendor/testcpp/test_sdsdot.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP sdsdot routine and reference blis sdsdot routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_spr.cc b/vendor/testcpp/test_spr.cc index edb7aa81a9..4da092f228 100644 --- a/vendor/testcpp/test_spr.cc +++ b/vendor/testcpp/test_spr.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_spr2.cc b/vendor/testcpp/test_spr2.cc index 24f364b8e1..97443a1aae 100644 --- a/vendor/testcpp/test_spr2.cc +++ b/vendor/testcpp/test_spr2.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_swap.cc b/vendor/testcpp/test_swap.cc index 8979d90bdf..3ba9387722 100644 --- a/vendor/testcpp/test_swap.cc +++ b/vendor/testcpp/test_swap.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_symm.cc b/vendor/testcpp/test_symm.cc index b4e10398ff..6584ba714d 100644 --- a/vendor/testcpp/test_symm.cc +++ b/vendor/testcpp/test_symm.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP symm routine and reference blis symm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_syr.cc b/vendor/testcpp/test_syr.cc index 327cd93947..19a46553f6 100644 --- a/vendor/testcpp/test_syr.cc +++ b/vendor/testcpp/test_syr.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_syr2.cc b/vendor/testcpp/test_syr2.cc index 165ca146f6..5de29d3fe7 100644 --- a/vendor/testcpp/test_syr2.cc +++ b/vendor/testcpp/test_syr2.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_syr2k.cc b/vendor/testcpp/test_syr2k.cc index d56ff97a31..168308b4c6 100644 --- a/vendor/testcpp/test_syr2k.cc +++ b/vendor/testcpp/test_syr2k.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP syr2k routine and reference blis syr2k routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_syrk.cc b/vendor/testcpp/test_syrk.cc index 3defc22519..a855816cea 100644 --- a/vendor/testcpp/test_syrk.cc +++ b/vendor/testcpp/test_syrk.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_tbmv.cc b/vendor/testcpp/test_tbmv.cc index ba9d565232..01873656fe 100644 --- a/vendor/testcpp/test_tbmv.cc +++ b/vendor/testcpp/test_tbmv.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_tbsv.cc b/vendor/testcpp/test_tbsv.cc index 85bcdb4ffd..40afe1575b 100644 --- a/vendor/testcpp/test_tbsv.cc +++ b/vendor/testcpp/test_tbsv.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_tpmv.cc b/vendor/testcpp/test_tpmv.cc index e2a41d34aa..7c0efa0e12 100644 --- a/vendor/testcpp/test_tpmv.cc +++ b/vendor/testcpp/test_tpmv.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_tpsv.cc b/vendor/testcpp/test_tpsv.cc index a9c3c2109f..81645cfcbd 100644 --- a/vendor/testcpp/test_tpsv.cc +++ b/vendor/testcpp/test_tpsv.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_trmm.cc b/vendor/testcpp/test_trmm.cc index c6301f0134..a41b384eab 100644 --- a/vendor/testcpp/test_trmm.cc +++ b/vendor/testcpp/test_trmm.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP trmm routine and reference blis trmm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/testcpp/test_trsm.cc b/vendor/testcpp/test_trsm.cc index 4c5ead3bcf..7be61e25ca 100644 --- a/vendor/testcpp/test_trsm.cc +++ b/vendor/testcpp/test_trsm.cc @@ -3,7 +3,7 @@ BLISPP C++ test driver for BLIS CPP trsm routine and reference blis trsm routine. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -102,6 +102,12 @@ void test_trsm( ) allocate_init_buffer(B , m , n); copy_buffer(B, B_ref , m ,n); + // Make A diagonally dominant to guarantee that the system has a solution. + for(int i=0; i