diff --git a/.gitignore b/.gitignore index b3b811654a..539f959076 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,13 @@ out.* GPATH GRTAGS GTAGS + +# Windows Build +build/* +bin/* +*.dll +*.lib +*.pdb +*.exe + +.vscode diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d892463a7..0483435679 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## cmake_minimum_required(VERSION 3.0.0) @@ -10,7 +10,7 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/bin") SET(AOCL_BLIS_FAMILY "zen" CACHE STRING "AOCL BLIS family name") -SET(OPENMP_PATH "C:\\Program Files\\LLVM\\lib" CACHE STRING "openmp library +SET(OpenMP_libomp_LIBRARY "C:/Program Files/LLVM/lib/libomp.lib" CACHE STRING "openmp library path") set(TARGET_ARCH ${AOCL_BLIS_FAMILY}) set(AOCL_BLIS_ZEN TRUE) @@ -90,9 +90,8 @@ option(BLIS_ENABLE_ILP64 "ENABLE BLIS ILP64" OFF) option(ENABLE_INT_TYPE_SIZE " Internal BLIS integers ,used in native BLIS interfaces based on architecture dependent " ON) option(ENABLE_BLASTEST "Enable the blastest" OFF) option(ENABLE_TESTCPP_TESTING "Enabling testcpp" OFF) -option (ENABLE_NO_UNDERSCORE_API "export APIs without underscore" ON) +option (ENABLE_NO_UNDERSCORE_API "export APIs without underscore" OFF) option (ENABLE_UPPERCASE_API "export APIs with uppercase" OFF) -option (ENABLE_API_WRAPPER "Enable wrapper code" OFF) option (ENABLE_COMPLEX_RETURN_INTEL "Enable complex_return_intel" OFF) option (ENABLE_TRSM_PREINVERSION "Enable TRSM preinversion" ON) option (ENABLE_AOCL_DYNAMIC "Enable Dynamic Multi-threading" OFF) @@ -122,10 +121,6 @@ if(ENABLE_UPPERCASE_API) add_definitions(-DBLIS_ENABLE_UPPERCASE_API) endif() -if(ENABLE_API_WRAPPER) - add_definitions(-DBLIS_ENABLE_API_WRAPPER) -endif() - if(ENABLE_AOCL_DYNAMIC) set(AOCL_DYNAMIC TRUE) endif() @@ -260,7 +255,9 @@ if(ENABLE_MULTITHREADING) find_package(OpenMP) if (OPENMP_FOUND) set(BLIS_ENABLE_OPENMP TRUE) - add_compile_options(-Xclang -fopenmp) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") else() message (FATAL_ERROR "Openmp Not Found") endif() @@ -524,20 +521,34 @@ execute_process( OUTPUT_VARIABLE CMD_OUTPUT) message( STATUS "Generating monolithic header file :" ${CMD_OUTPUT}) +# Logic to generate the cblas.h in include folder. 
+set(CBLAS_H "cblas.h") +# Arguements for python script +set(C_COMMENT "-c") +set(VERBOSE "-v1") +set(INPUT "${CMAKE_SOURCE_DIR}/frame/compat/cblas/src/${CBLAS_H}") +set(OUTPUT "${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/${CBLAS_H}") +set(TEMP_DIR "${INCLUDE}") +set(DIR_H_PATH "${HEADER_PATH}") + +# Run python script to generate monolithic header at configuration time +execute_process( + COMMAND ${PYTHON_EXE} ${FLATTEN_PY} "${C_COMMENT}" "${VERBOSE}" "${INPUT}" "${OUTPUT}" "${TEMP_DIR}" "${DIR_H_PATH}" + RESULT_VARIABLE CMD_RESULT + OUTPUT_VARIABLE CMD_OUTPUT) +message( STATUS "Generating monolithic cblas header file :" ${CMD_OUTPUT}) + # setting the blis version string file (STRINGS "version" BLIS_VERSION) set(BLIS_VERSION_STRING ${BLIS_VERSION}) add_definitions(-DBLIS_VERSION_STRING="AOCL BLIS ${BLIS_VERSION_STRING}") -message( STATUS "OPENMP PATH:" ${OPENMP_PATH}) -link_directories("${OPENMP_PATH}") - if(BUILD_SHARED_LIBS) add_library("${PROJECT_NAME}" SHARED ${CMAKE_SOURCE_DIR}/bli_config.h ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h ${headers}) if(ENABLE_OPENMP) - target_link_libraries("${PROJECT_NAME}" PUBLIC "${OPENMP_PATH}/libomp.lib") + target_link_libraries("${PROJECT_NAME}" PRIVATE OpenMP::OpenMP_CXX) endif() target_compile_definitions("${PROJECT_NAME}" PUBLIC -DBLIS_IS_BUILDING_LIBRARY) set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") @@ -547,9 +558,10 @@ if(NOT BUILD_SHARED_LIBS) ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h ${headers}) if(ENABLE_OPENMP) - target_link_libraries("${PROJECT_NAME}" PUBLIC "${OPENMP_PATH}/libomp.lib") + set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}" STATIC_LIBRARY_OPTIONS "${OpenMP_libomp_LIBRARY}") + else() + set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") endif() - set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") endif() link_directories(${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) diff --git a/LICENSE b/LICENSE index 0e7a6071d2..be24a09734 100644 --- a/LICENSE +++ b/LICENSE @@ -15,7 +15,7 @@ copyright info. All parties provide their portions of the code under the Copyright (C) 2018, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP -Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. +Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/Makefile b/Makefile index b248d5781a..1658e16de2 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -212,6 +212,27 @@ MK_REFKERN_OBJS := $(foreach arch, $(CONFIG_LIST), \ # Generate object file paths for all of the portable framework source code. MK_FRAME_OBJS := $(call gen-obj-paths-from-src,$(FRAME_SRC_SUFS),$(MK_FRAME_SRC),$(FRAME_PATH),$(BASE_OBJ_FRAME_PATH)) +# AMD has optimized some of the framework files, these optimizations +# may not be compatible with other platforms. 
+# +# In order to keep main framework code independent of AMD changes, +# AMD has duplicated the files and updated them for example +# frame/compact/bla_gemm.c : generic framework file +# frame/compact/bla_gemm_amd.c : AMD optimized framework file +# Based on the archiecture we choose correct files + +ifeq ($(MK_IS_ARCH_ZEN),yes) +# Build is being done for AMD platforms, remove the objects which +# don't have amd suffix (for which exists AMD specific implementation). +MK_FRAME_AMD_OBJS := $(filter $(BASE_OBJ_FRAME_PATH)/%amd.o, $(MK_FRAME_OBJS)) +FILES_TO_REMOVE := $(subst _amd.o,.o, $(MK_FRAME_AMD_OBJS)) +MK_FRAME_OBJS := $(filter-out $(FILES_TO_REMOVE), $(MK_FRAME_OBJS)) +else +# Build is done for non AMD platforms, remove the amd specific objects +MK_FRAME_AMD_OBJS := $(filter $(BASE_OBJ_FRAME_PATH)/%amd.o, $(MK_FRAME_OBJS)) +MK_FRAME_OBJS := $(filter-out $(MK_FRAME_AMD_OBJS), $(MK_FRAME_OBJS)) +endif + # Generate object file paths for all of the debgu and trace logger. MK_AOCLDTL_OBJS := $(call gen-obj-paths-from-src,$(AOCLDTL_SRC_SUFS),$(MK_AOCLDTL_SRC),$(AOCLDTL_PATH),$(BASE_OBJ_AOCLDTL_PATH)) @@ -1338,4 +1359,3 @@ else @echo "Uninstalling $(@F) from $(@D)/" @- $(RM_F) $@ endif - diff --git a/aocl_dtl/aocldtl.c b/aocl_dtl/aocldtl.c index 6f24788aa0..6e7ee35102 100644 --- a/aocl_dtl/aocldtl.c +++ b/aocl_dtl/aocldtl.c @@ -5,7 +5,7 @@ * These functions are invoked though macros by * end user. * - * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. * *=======================================================================*/ #include "blis.h" @@ -56,6 +56,10 @@ static char *pchDTL_LOG_FILE = AOCL_DTL_LOG_FILE; /* Global file pointer for logging the results */ AOCL_FLIST_Node *gpLogFileList = NULL; + + +/* Global flag to check if logging is enabled or not */ +Bool gbIsLoggingEnabled = TRUE; #endif #if AOCL_DTL_AUTO_TRACE_ENABLE @@ -82,6 +86,23 @@ AOCL_FLIST_Node *gpAutoTraceFileList = NULL; void DTL_Initialize( uint32 ui32CurrentLogLevel) { + /* + * This function can be invoked multiple times either via library + * initialization function (e.g. bli_init()) or when user changes + * logging state using API. However we want it to run only once + * This flag ensure that it is executed only once. + * + * DTL can be used with many libraries hence it needs its own + * method to ensure this. 
+ */ + + static Bool bIsDTLInitDone = FALSE; + + if (bIsDTLInitDone) + { + return; + } + /* If user selects invalid trace log level then the dafault trace log level will be AOCL_DTL_LEVEL_ALL */ if ((ui32CurrentLogLevel < 1) || (ui32CurrentLogLevel > AOCL_DTL_LEVEL_ALL)) @@ -107,15 +128,9 @@ void DTL_Initialize( #endif #if (AOCL_DTL_LOG_ENABLE || AOCL_DTL_DUMP_ENABLE) - /* Create/Open the file to log the log data */ - AOCL_FLIST_AddFile(pchDTL_LOG_FILE, &gpLogFileList, AOCL_gettid()); - - if (NULL == gpLogFileList) - { - /* Unable to open the specified file.*/ - AOCL_DEBUGPRINT("Unable to create the log file %s\n", pchDTL_LOG_FILE); - return; - } + + /* Check if DTL logging is requested via envoronment variable */ + gbIsLoggingEnabled = bli_env_get_var( "AOCL_VERBOSE", TRUE ); #endif #if AOCL_DTL_AUTO_TRACE_ENABLE @@ -133,6 +148,9 @@ void DTL_Initialize( /* Save Id for main thread */ gtidMainThreadID = AOCL_gettid(); + // Ensure that this function is executed only once + bIsDTLInitDone = TRUE; + } /* DTL_Initialize */ #endif @@ -193,6 +211,19 @@ void DTL_Trace( { uint8 i = 0; AOCL_FAL_FILE *pOutFile = NULL; + +#if AOCL_DTL_LOG_ENABLE + /* + * For performance reasons we check the logging state in end user + * macros, this is just an additional check in case the function + * is invoked from any other context. + */ + if (gbIsLoggingEnabled == FALSE && ui8LogType == TRACE_TYPE_LOG) + { + return; + } +#endif + uint64 u64EventTime = AOCL_getTimestamp(); dim_t u64RequestedThreadsCount = AOCL_get_requested_threads_count(); diff --git a/aocl_dtl/aocldtl.h b/aocl_dtl/aocldtl.h index 58c1a56079..f520518e9c 100644 --- a/aocl_dtl/aocldtl.h +++ b/aocl_dtl/aocldtl.h @@ -5,7 +5,7 @@ * It provides defination for all macros to be * used by user to add debug/trace information. * - * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ @@ -15,6 +15,7 @@ #include "aocldtlcf.h" #include "aocltpdef.h" #include "aoclflist.h" +#include "aoclos.h" #define TRACE_TYPE_FENTRY (1) #define TRACE_TYPE_FEXIT (2) @@ -108,6 +109,31 @@ void AOCL_DTL_start_perf_timer(void); uint64 AOCL_DTL_get_time_spent(void); +/* + * Logging of inputs can be enabled by two methods: + * + * 1. Using environment variable AOCL_VERBOSE. + * 2. APIs + * + * The API takes precedence over environment variable. + * + * The global flag is maintain in the code to track the final + * state of the logging feature. + */ +extern Bool gbIsLoggingEnabled; + +/* API to enable logging at runtime */ +#define AOCL_DTL_Enable_Logs() \ + /* Initialize DTL if not alredy done so */ \ + AOCL_DTL_INITIALIZE(AOCL_DTL_TRACE_LEVEL); \ + gbIsLoggingEnabled = TRUE; + +/* API to disable logging at runtime */ +#define AOCL_DTL_Disable_Logs() \ + /* Initialize DTL if not alredy done so */ \ + AOCL_DTL_INITIALIZE(AOCL_DTL_TRACE_LEVEL); \ + gbIsLoggingEnabled = FALSE; + /* Macro to log the Data */ #define AOCL_DTL_START_PERF_TIMER() \ AOCL_DTL_start_perf_timer() diff --git a/aocl_dtl/aocldtl_blis.h b/aocl_dtl/aocldtl_blis.h index a9ea3368f9..7b352f9d43 100755 --- a/aocl_dtl/aocldtl_blis.h +++ b/aocl_dtl/aocldtl_blis.h @@ -3,7 +3,7 @@ * * Description : BLIS library specific debug helpes. * - * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. 
* *==================================================================*/ @@ -385,115 +385,148 @@ void AOCL_DTL_log_trmm_sizes(int8 loglevel, #define AOCL_DTL_LOG_GEMM_INPUTS(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc) \ - AOCL_DTL_log_gemm_sizes(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_gemm_sizes(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc, \ + __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_GEMM_STATS(loglevel, m, n, k) \ - AOCL_DTL_log_gemm_stats(loglevel, m, n, k); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_gemm_stats(loglevel, m, n, k); #define AOCL_DTL_LOG_TRSM_INPUTS(loglevel, dt, side, uploa, transa, diaga, m, n, alpha, lda, ldb) \ - AOCL_DTL_log_trsm_sizes(loglevel, dt, side, uploa, transa, diaga, m, n, alpha, lda, ldb, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_trsm_sizes(loglevel, dt, side, uploa, transa, diaga, m, n, alpha, lda, ldb, \ + __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_GEMMT_INPUTS(loglevel, dt, uplo, transa, transb, n, k, alpha, lda, ldb, beta, ldc) \ - AOCL_DTL_log_gemmt_sizes(loglevel, dt, uplo, transa, transb, n, k, alpha, lda, ldb, beta, ldc, __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_gemmt_sizes(loglevel, dt, uplo, transa, transb, n, k, alpha, lda, ldb, beta, ldc, \ + __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_HEMM_INPUTS(loglevel, dt_type, side, uplo, m, n, alpha, lda, ldb, beta, ldc) \ - AOCL_DTL_log_hemm_sizes(loglevel, dt_type, side, uplo, m, n, alpha, lda, ldb, beta, ldc, \ - __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_hemm_sizes(loglevel, dt_type, side, uplo, m, n, alpha, lda, ldb, beta, ldc, \ + __FILE__, __FUNCTION__, __LINE__); // Level-3 Macros #define AOCL_DTL_LOG_HERK_INPUTS(loglevel, dt_type, uploc, transa, m, k, alpha, lda, beta, ldc)\ - AOCL_DTL_log_herk_sizes(loglevel, dt_type, transa, uploc, m, k, alpha, lda, beta, ldc, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_herk_sizes(loglevel, dt_type, transa, uploc, m, k, alpha, lda, beta, ldc, __FILE__,\ + __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_HER2K_INPUTS(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta, ldc)\ - AOCL_DTL_log_her2k_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta, ldc, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_her2k_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta, ldc, __FILE__,\ + __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_SYMM_INPUTS(loglevel, dt_type, side, uploa, m, n, alpha, lda, ldb, beta, ldc)\ - AOCL_DTL_log_symm_sizes(loglevel, dt_type, side, uploa, m, n, alpha, lda, ldb, beta, ldc, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_symm_sizes(loglevel, dt_type, side, uploa, m, n, alpha, lda, ldb, beta, ldc, __FILE__,\ + __FUNCTION__, __LINE__); // Level-2 Macros #define AOCL_DTL_LOG_GEMV_INPUTS(loglevel, dt_type, transa, m, n, alp, lda, incx, beta, incy) \ - AOCL_DTL_log_gemv_sizes(loglevel, dt_type, transa, m, n, alp, lda, incx, beta, incy, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_gemv_sizes(loglevel, dt_type, transa, m, n, alp, lda, incx, beta, incy, __FILE__,\ + __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_GER_INPUTS(loglevel, dt_type, m, n, alpha, incx, incy, lda) \ - 
AOCL_DTL_log_ger_sizes(loglevel, dt_type, m, n, alpha, incx, incy, lda, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_ger_sizes(loglevel, dt_type, m, n, alpha, incx, incy, lda, __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_HER_INPUTS(loglevel, dt_type, uploa, m, alpha, incx, lda )\ - AOCL_DTL_log_her_sizes(loglevel, dt_type, uploa, m, alpha, incx, lda, __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_her_sizes(loglevel, dt_type, uploa, m, alpha, incx, lda, __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_SYMV_INPUTS(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy)\ - AOCL_DTL_log_symv_sizes(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_symv_sizes(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy, __FILE__,\ + __FUNCTION__, __LINE__); // Level-1 Macros #define AOCL_DTL_LOG_COPY_INPUTS(loglevel, dt_type, n, incx, incy) \ - AOCL_DTL_log_copy_sizes(loglevel, dt_type, n, incx, incy, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_copy_sizes(loglevel, dt_type, n, incx, incy, __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_SCAL_INPUTS(loglevel, dt_type, alpha, n, incx )\ - AOCL_DTL_log_scal_sizes(loglevel, dt_type, alpha, n, incx, __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_scal_sizes(loglevel, dt_type, alpha, n, incx, __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_SWAP_INPUTS(loglevel, dt_type, n, incx, incy)\ - AOCL_DTL_log_swap_sizes(loglevel, dt_type, n, incx, incy, __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_swap_sizes(loglevel, dt_type, n, incx, incy, __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_NRM2_INPUTS(loglevel, dt_type, n, incx)\ - AOCL_DTL_log_nrm2_sizes(loglevel, dt_type, n, incx, __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_nrm2_sizes(loglevel, dt_type, n, incx, __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_HEMV_INPUTS(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy) \ - AOCL_DTL_log_hemv_sizes(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy, \ - __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_hemv_sizes(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy, \ + __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_HER2_INPUTS(loglevel, dt_type, uploa, m, alpha, incx, incy, lda) \ - AOCL_DTL_log_her2_sizes(loglevel, dt_type, uploa, m, alpha, incx, incy, lda, \ - __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_her2_sizes(loglevel, dt_type, uploa, m, alpha, incx, incy, lda, \ + __FILE__, __FUNCTION__, __LINE__); // Level-1 Macros #define AOCL_DTL_LOG_AMAX_INPUTS(loglevel, dt_type, n, incx) \ - AOCL_DTL_log_amax_sizes(loglevel, dt_type, n, incx, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_amax_sizes(loglevel, dt_type, n, incx, __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_ASUM_INPUTS(loglevel, dt_type, n, incx) \ - AOCL_DTL_log_asum_sizes(loglevel, dt_type, n, incx, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_asum_sizes(loglevel, dt_type, n, incx, __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_AXPBY_INPUTS(loglevel, dt_type, n, alpha, incx, beta, incy) \ - AOCL_DTL_log_axpby_sizes(loglevel, dt_type, n, alpha, incx, beta, incy, __FILE__,\ - __FUNCTION__, __LINE__); + if 
(gbIsLoggingEnabled) \ + AOCL_DTL_log_axpby_sizes(loglevel, dt_type, n, alpha, incx, beta, incy, __FILE__,\ + __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_AXPY_INPUTS(loglevel, dt_type, n, alpha, incx, incy) \ - AOCL_DTL_log_axpy_sizes(loglevel, dt_type, n, alpha, incx, incy, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_axpy_sizes(loglevel, dt_type, n, alpha, incx, incy, __FILE__,\ + __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_DOTV_INPUTS(loglevel, dt_type, n, incx, incy) \ - AOCL_DTL_log_dotv_sizes(loglevel, dt_type, n, incx, incy, __FILE__, __FUNCTION__, __LINE__); \ + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_dotv_sizes(loglevel, dt_type, n, incx, incy, __FILE__, __FUNCTION__, __LINE__); \ #define AOCL_DTL_LOG_SYR2_INPUTS(loglevel, dt_type, uploa, m, alpha, incx, incy, lda) \ - AOCL_DTL_log_syr2_sizes(loglevel, dt_type, uploa, m, alpha, incx, incy, lda, __FILE__,\ - __FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_syr2_sizes(loglevel, dt_type, uploa, m, alpha, incx, incy, lda, __FILE__,\ + __FUNCTION__,__LINE__); #define AOCL_DTL_LOG_SYR2K_INPUTS(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta, ldc) \ - AOCL_DTL_log_syr2k_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta,\ - ldc, __FILE__, __FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_syr2k_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta,\ + ldc, __FILE__, __FUNCTION__,__LINE__); #define AOCL_DTL_LOG_SYR_INPUTS(loglevel, dt_type, uploa, m, alpha, incx, lda) \ - AOCL_DTL_log_syr_sizes(loglevel, dt_type, uploa, m, alpha, incx, lda,\ - __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_syr_sizes(loglevel, dt_type, uploa, m, alpha, incx, lda,\ + __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_SYRK_INPUTS(loglevel, dt_type, uploc, transa, m, k, alpha, lda, beta, ldc) \ - AOCL_DTL_log_syrk_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, beta, ldc, __FILE__,\ - __FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_syrk_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, beta, ldc, __FILE__,\ + __FUNCTION__,__LINE__); #define AOCL_DTL_LOG_TRMM_INPUTS(loglevel, dt_type, side, uploa, transa, diaga, m, n, alpha, lda, ldb) \ - AOCL_DTL_log_trmm_sizes(loglevel, dt_type, side, uploa, transa, diaga, m, n, alpha, lda, ldb, __FILE__,\ - __FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_trmm_sizes(loglevel, dt_type, side, uploa, transa, diaga, m, n, alpha, lda, ldb, __FILE__,\ + __FUNCTION__,__LINE__); #define AOCL_DTL_LOG_TRMV_INPUTS(loglevel, dt_type, uploa, transa, diaga, m, lda, incx) \ - AOCL_DTL_log_trmv_sizes(loglevel, dt_type, uploa, transa, diaga, m, lda, incx,\ - __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_trmv_sizes(loglevel, dt_type, uploa, transa, diaga, m, lda, incx,\ + __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_TRSV_INPUTS(loglevel, dt_type, uploa, transa, diaga, m, lda, incx ) \ - AOCL_DTL_log_trsv_sizes(loglevel, dt_type, uploa, transa, diaga, m, lda, incx,\ - __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_trsv_sizes(loglevel, dt_type, uploa, transa, diaga, m, lda, incx,\ + __FILE__,__FUNCTION__,__LINE__); #else #define AOCL_DTL_LOG_GEMM_INPUTS(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc) diff --git a/aocl_dtl/aocldtlcf.h b/aocl_dtl/aocldtlcf.h index 4f1e923a05..1f44f54405 100644 --- a/aocl_dtl/aocldtlcf.h +++ b/aocl_dtl/aocldtlcf.h @@ -5,7 
+5,7 @@ * libaray, all debug features (except auto trace) * can be enabled/disabled in this file. * - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ @@ -20,8 +20,8 @@ enable this macro by making it to 1 else 0 */ #define AOCL_DTL_DUMP_ENABLE 0 -/* Macro for logging the logs If the user wants to enable loging information he - has to enable this macro by making it to 1 else 0 */ +/* Macro for dumping the log If the user wants to enable input logs he has to + enable this macro by making it to 1 else 0 */ #define AOCL_DTL_LOG_ENABLE 0 /* Select the trace level till which you want to log the data */ diff --git a/aocl_dtl/aoclos.c b/aocl_dtl/aoclos.c index 92a489564e..896b1c89b3 100644 --- a/aocl_dtl/aoclos.c +++ b/aocl_dtl/aoclos.c @@ -3,7 +3,7 @@ * * Description : Abstraction for os services used by DTL. * - * Copyright (C) 2020, Advanced Micro Devices, Inc + * Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ #include "aocltpdef.h" @@ -85,8 +85,15 @@ uint64 AOCL_getTimestamp(void) #else /* Non linux support */ AOCL_TID AOCL_gettid(void) { - /* stub for other os's */ - return 0; +#ifdef BLIS_ENABLE_OPENMP + return omp_get_thread_num(); +#else +#ifdef BLIS_ENABLE_PTHREADS + return pthread_self(); +#else + return 0; +#endif +#endif } pid_t AOCL_getpid(void) diff --git a/aocl_dtl/aoclos.h b/aocl_dtl/aoclos.h index 3d8e1cddcc..57e0c24902 100644 --- a/aocl_dtl/aoclos.h +++ b/aocl_dtl/aoclos.h @@ -3,7 +3,7 @@ * * Description : Abstraction for os services used by DTL. * - * Copyright (C) 2020, Advanced Micro Devices, Inc + * Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ @@ -19,7 +19,7 @@ #define AOCL_malloc malloc #define AOCL_free free -uint32 AOCL_gettid(void); +AOCL_TID AOCL_gettid(void); pid_t AOCL_getpid(void); uint64 AOCL_getTimestamp(void); diff --git a/bench/Makefile b/bench/Makefile index 3ee497212d..d47485b2fc 100755 --- a/bench/Makefile +++ b/bench/Makefile @@ -191,7 +191,8 @@ blis: \ bench_trsv_blis.x \ bench_amaxv_blis.x \ bench_copyv_blis.x \ - bench_swapv_blis.x + bench_swapv_blis.x \ + bench_axpbyv_blis.x openblas: \ bench_gemm_openblas.x \ @@ -205,7 +206,8 @@ openblas: \ bench_trsv_openblas.x \ bench_amaxv_openblas.x \ bench_copyv_openblas.x \ - bench_swapv_openblas.x + bench_swapv_openblas.x \ + bench_axpbyv_openblas.x atlas: \ bench_gemm_atlas.x \ @@ -219,7 +221,8 @@ atlas: \ bench_trsv_atlas.x \ bench_amaxv_atlas.x \ bench_copyv_atlas.x \ - bench_swapv_atlas.x + bench_swapv_atlas.x \ + bench_axpbyv_atlax.x mkl: \ bench_gemm_mkl.x \ @@ -233,7 +236,8 @@ mkl: \ bench_trsv_mkl.x \ bench_amaxv_mkl.x \ bench_copyv_mkl.x \ - bench_swapv_mkl.x + bench_swapv_mkl.x \ + bench_axpbyv_mkl.x # --Object file rules -- diff --git a/bench/bench_axpbyv.c b/bench/bench_axpbyv.c new file mode 100644 index 0000000000..36a203f696 --- /dev/null +++ b/bench/bench_axpbyv.c @@ -0,0 +1,265 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + +#ifndef DT +#define DT BLIS_DOUBLE +#endif +#define AOCL_MATRIX_INITIALISATION + +int main( int argc, char** argv ) +{ + obj_t x, y, alpha, beta; // BLIS objects + dim_t p_inc = 0; // To keep track of number of inputs + num_t dt; // BLIS datatype + char dt_ch; // {S, D, Z, C} from input + int r, n_repeats; // repetition counter; number of repeats + + double dtime; + double dtime_save; + double gflops; + + FILE* fin = NULL; // Input FILE* + FILE* fout = NULL; // Output FILE* + + n_repeats = N_REPEAT; // Fetched from Makefile + + dt = DT; // Set datatype as BLIS_DOUBLE + + if ( argc < 3 ) + { + printf( "Usage: ./bench_axpbyv_XX.x input.txt output.txt\n" ); + exit( 1 ); + } + + fin = fopen( argv[1], "r" ); // Open input file in read mode + if ( fin == NULL ) + { + printf( "Error opening input file %s\n", argv[1] ); + exit( 1 ); + } + + fout = fopen( argv[2], "w" ); // Open output file in write mode + if ( fout == NULL ) + { + printf( "Error opening output file %s\n", argv[2] ); + exit( 1 ); + } + +#ifdef DEBUG + fprintf( fout, "gflops\n" ); +#else + fprintf(fout, "Dt\t n\t alpha_r\t alpha_i\t beta_r\t beta_i\t gflops\n" ); +#endif + + dim_t n; // dimension + inc_t incx; // stride x + inc_t incy; // stride y + char tmp[256]; // to store function name, line not present in logs + double alpha_r, alpha_i, beta_r, beta_i; + + // {function name} {S, D, C, Z} {n} + // {alpha_r} {alpha_i} {incx} {beta_r} {beta_i} {incy} + while ( fscanf( fin, "%s %c %ld %lf %lf %ld %lf %lf %ld\n", + tmp, &dt_ch, &n, + &alpha_r, &alpha_i, &incx, &beta_r, &beta_i, &incy ) == 9 ) + { + if ( dt_ch == 'D' || dt_ch == 'd' ) dt = BLIS_DOUBLE; + else if ( dt_ch == 'Z' || dt_ch == 'z' ) dt = BLIS_DCOMPLEX; + else if ( dt_ch == 'S' || dt_ch == 's' ) dt = BLIS_FLOAT; + else if ( dt_ch == 'C' || dt_ch == 'c' ) dt = BLIS_SCOMPLEX; + else + { + printf( "Invalid data type %c\n", dt_ch ); + continue; + } + + // Creating BLIS objects + bli_obj_create( dt, n, 1, incx, 1, &x ); // For input vector x + 
bli_obj_create( dt, n, 1, incy, 1, &y ); // For input vector y + bli_obj_create( dt, 1, 1, 0, 0, &alpha); // For input vector alpha + bli_obj_create( dt, 1, 1, 0, 0, &beta); // For input vector beta + + #ifdef AOCL_MATRIX_INITIALISATION + bli_randm( &x ); + bli_randm( &y ); + #endif + + bli_setsc( alpha_r, alpha_i, &alpha ); + bli_setsc( beta_r, beta_i, &beta ); + + dtime_save = DBL_MAX; + + for ( r = 0; r < n_repeats; ++r ) + { + dtime = bli_clock(); + +#ifdef BLIS + bli_axpbyv( &alpha, &x, &beta, &y ); +#else + f77_int nn = bli_obj_length( &x ); + f77_int blas_incx = bli_obj_vector_inc( &x ); + f77_int blas_incy = bli_obj_vector_inc( &y ); + + if ( bli_is_float( dt ) ) + { + float* alphap = bli_obj_buffer( &alpha ); + float* xp = bli_obj_buffer( &x ); + float* betap = bli_obj_buffer( &beta ); + float* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_saxpby( nn, + *alphap, + xp, + blas_incx, + *betap, + yp, + blas_incy ); +#else + saxpby_( &nn, + alphap, + xp, + &blas_incx, + betap, + yp, + &blas_incy ); +#endif + } + else if ( bli_is_double( dt ) ) + { + double* alphap = bli_obj_buffer( &alpha ); + double* xp = bli_obj_buffer( &x ); + double* betap = bli_obj_buffer( &beta ); + double* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_daxpby( nn, + *alphap, + xp, + blas_incx, + *betap, + yp, + blas_incy ); +#else + daxpby_( &nn, + alphap, + xp, + &blas_incx, + betap, + yp, + &blas_incy ); +#endif + } + else if ( bli_is_scomplex( dt ) ) + { + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* xp = bli_obj_buffer( &x ); + scomplex* betap = bli_obj_buffer( &beta ); + scomplex* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_caxpby( nn, + *alphap, + xp, + blas_incx, + *betap, + yp, + blas_incy ); +#else + caxpby_( &nn, + alphap, + xp, + &blas_incx, + betap, + yp, + &blas_incy ); +#endif + } + else if ( bli_is_dcomplex( dt ) ) + { + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* xp = bli_obj_buffer( &x ); + dcomplex* betap = bli_obj_buffer( &beta ); + dcomplex* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_zaxpby( nn, + *alphap, + xp, + blas_incx, + *betap, + yp, + blas_incy ); +#else + zaxpby_( &nn, + alphap, + xp, + &blas_incx, + betap, + yp, + &blas_incy ); +#endif + } +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + gflops = ( 3.0 * n ) / ( dtime_save * 1.0e9 ); + if ( bli_is_complex( dt ) ) gflops *= 4.0; + + printf( "data_axpbyv_%s", BLAS ); + + p_inc++; + printf( " %4lu [ %4lu %7.2f ];\n", + (unsigned long)(p_inc), + (unsigned long)n, + gflops ); + + fprintf( fout, "%c\t %ld\t %lf\t %lf\t %lf\t %lf\t %6.3f\n", + dt_ch, n, alpha_r, alpha_i, beta_r, beta_i, gflops ); + fflush( fout ); + + bli_obj_free( &x ); + bli_obj_free( &y ); + } + + return 0; +} \ No newline at end of file diff --git a/bench/bench_gemmt.c b/bench/bench_gemmt.c index aef194135b..621c9288c7 100644 --- a/bench/bench_gemmt.c +++ b/bench/bench_gemmt.c @@ -107,7 +107,7 @@ int main( int argc, char** argv ) printf("Error opening output file %s\n", argv[2]); exit(1); } - fprintf(fout, "Dt uplo n\t k\t lda\t ldb\t ldc\t transa transb alphaR\t alphaI\t betaR\t betaI\t gflops\n"); + fprintf(fout, "Dt\t uplo\t n\t k\t lda\t ldb\t ldc\t transa\t transb\t alphaR\t alphaI\t betaR\t betaI\t gflops\n"); inc_t lda; @@ -455,7 +455,7 @@ int main( int argc, char** argv ) if ( bli_is_complex( dt ) ) gflops *= 4.0; - printf("data_gemm_%s", BLAS); + printf("data_gemmt_%s", BLAS); p_inc++; printf( "( %2lu, 1:4 ) = [ %4lu %4lu %7.2f ];\n", @@ -463,7 +463,7 @@ int main( int argc, char** argv ) ( 
unsigned long )n, ( unsigned long )k, gflops ); - fprintf(fout, "%c %c %ld\t %ld\t %ld\t %ld\t %ld\t %c %c %lf\t %lf\t %lf\t %lf\t %6.3f\n", \ + fprintf(fout, "%c\t %c\t %ld\t %ld\t %ld\t %ld\t %ld\t %c\t %c\t %lf\t %lf\t %lf\t %lf\t %6.3f\n", \ dt_ch, uplo_c, n, k, lda, ldb, ldc, transA_c, transB_c, diff --git a/bench/bench_ger.c b/bench/bench_ger.c index f6e5b27f59..fb50c94265 100644 --- a/bench/bench_ger.c +++ b/bench/bench_ger.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -66,7 +66,6 @@ int main( int argc, char** argv ) dim_t p_inc = 0; // to keep track of number of inputs num_t dt; char dt_ch; - char stor_scheme; int r, n_repeats; double dtime; @@ -76,6 +75,10 @@ int main( int argc, char** argv ) FILE* fin = NULL; FILE* fout = NULL; +#ifdef CBLAS + char stor_scheme; +#endif + n_repeats = N_REPEAT; // This macro will get from Makefile. dt = DT; @@ -108,7 +111,9 @@ int main( int argc, char** argv ) inc_t incy; char tmp[256]; // to store function name, line no present in logs. +#ifdef CBLAS stor_scheme = 'C'; +#endif // {S,D,C,Z} {transa m n alpha incx incy lda} diff --git a/bench/inputaxpbyv.txt b/bench/inputaxpbyv.txt new file mode 100644 index 0000000000..3cfc7ae732 --- /dev/null +++ b/bench/inputaxpbyv.txt @@ -0,0 +1,40 @@ +saxpbyv_ S 32 0.900000 0.000000 1 0.900000 0.000000 1 +saxpbyv_ S 64 1.000000 0.000000 1 1.000000 0.000000 1 +saxpbyv_ S 100 -1 0.000000 1 -1 0.000000 1 +saxpbyv_ S 200 -1.100000 0.000000 1 -1.100000 0.000000 1 +saxpbyv_ S 300 1.100000 0.000000 1 1.100000 0.000000 1 +saxpbyv_ S 400 0.900000 0.000000 1 0.900000 0.000000 1 +saxpbyv_ S 500 1.000000 0.000000 1 1.000000 0.000000 1 +saxpbyv_ S 1000 -1 0.000000 1 -1 0.000000 1 +saxpbyv_ S 5000 -1.100000 0.000000 1 -1.100000 0.000000 1 +saxpbyv_ S 10000 1.100000 0.000000 1 1.100000 0.000000 1 +daxpbyv_ D 32 0.900000 0.000000 1 0.900000 0.000000 1 +daxpbyv_ D 64 1.000000 0.000000 1 1.000000 0.000000 1 +daxpbyv_ D 100 -1 0.000000 1 -1 0.000000 1 +daxpbyv_ D 200 -1.100000 0.000000 1 -1.100000 0.000000 1 +daxpbyv_ D 300 1.100000 0.000000 1 1.100000 0.000000 1 +daxpbyv_ D 400 0.900000 0.000000 1 0.900000 0.000000 1 +daxpbyv_ D 500 1.000000 0.000000 1 1.000000 0.000000 1 +daxpbyv_ D 1000 -1 0.000000 1 -1 0.000000 1 +daxpbyv_ D 5000 -1.100000 0.000000 1 -1.100000 0.000000 1 +daxpbyv_ D 10000 1.100000 0.000000 1 1.100000 0.000000 1 +caxpbyv_ C 32 0.900000 -1.100000 1 0.900000 -1.100000 1 +caxpbyv_ C 64 1.000000 1.100000 1 1.000000 1.100000 1 +caxpbyv_ C 100 -1 1.000000 1 -1 1 1 +caxpbyv_ C 200 -1.100000 0.900000 1 -1.100000 0.900000 1 +caxpbyv_ C 300 1.100000 1.000000 1 1.100000 1 1 +caxpbyv_ C 400 0.900000 -1.100000 1 0.900000 -1.100000 1 +caxpbyv_ C 500 1.000000 1.000000 1 1.000000 1 1 +caxpbyv_ C 1000 -1 0.900000 1 -1 0.900000 1 +caxpbyv_ C 5000 -1.100000 -1 1 -1.100000 -1 1 +caxpbyv_ C 10000 1.100000 -1 1 1.100000 -1 1 +zaxpbyv_ Z 32 0.900000 -1.100000 1 0.900000 -1.100000 1 +zaxpbyv_ Z 64 1.000000 1.100000 1 1.000000 1.100000 1 +zaxpbyv_ Z 100 -1 1.000000 1 -1 1 1 +zaxpbyv_ Z 200 -1.100000 0.900000 1 -1.100000 0.900000 1 +zaxpbyv_ Z 300 1.100000 1.000000 1 1.100000 1 1 +zaxpbyv_ Z 400 0.900000 -1.100000 1 0.900000 -1.100000 1 +zaxpbyv_ Z 500 1.000000 1.000000 1 1.000000 1 1 +zaxpbyv_ Z 1000 
-1 0.900000 1 -1 0.900000 1 +zaxpbyv_ Z 5000 -1.100000 -1 1 -1.100000 -1 1 +zaxpbyv_ Z 10000 1.100000 -1 1 1.100000 -1 1 diff --git a/build/blis_ref_kernel_mirror.py b/build/blis_ref_kernel_mirror.py index b756eb30b6..8ef90a12af 100644 --- a/build/blis_ref_kernel_mirror.py +++ b/build/blis_ref_kernel_mirror.py @@ -68,11 +68,13 @@ def remove_lines_in_file(filename): with open(filename, 'r') as fd: file_content = fd.read() file_content = file_content.replace( - 'if(${TARGET_ARCH} STREQUAL amdzen)\nadd_subdirectory(${CMAKE_BINARY_' - 'DIR}/ref_kernels/generic ${CMAKE_BINARY_DIR}/ref_kernels/generic)\n' - 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen ${CMAKE_BINARY_' - 'DIR}/ref_kernels/zen)\nadd_subdirectory(${CMAKE_BINARY_DIR}/' - 'ref_kernels/zen2 ${CMAKE_BINARY_DIR}/ref_kernels/zen2)\n' + 'if(${TARGET_ARCH} STREQUAL amdzen)\n' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/generic ' + '${CMAKE_BINARY_DIR}/ref_kernels/generic)\n' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen ' + '${CMAKE_BINARY_DIR}/ref_kernels/zen)\n' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen2 ' + '${CMAKE_BINARY_DIR}/ref_kernels/zen2)\n' 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen3 ' '${CMAKE_BINARY_DIR}/ref_kernels/zen3)\nelse()', '\n') data = file_content.replace('endif()', '\n') @@ -111,6 +113,7 @@ def add_macro_to_cfiles(cfiles, macro): create_folder(os.path.join(dest_path, 'zen')) create_folder(os.path.join(dest_path, 'zen2')) create_folder(os.path.join(dest_path, 'zen3')) + create_folder(os.path.join(dest_path, 'zen4')) create_folder(os.path.join(dest_path, 'generic')) execute_and_check('XCOPY {} {} /E'.format( temp, os.path.join(dest_path, 'zen'))) @@ -118,6 +121,8 @@ def add_macro_to_cfiles(cfiles, macro): temp, os.path.join(dest_path, 'zen2'))) execute_and_check('XCOPY {} {} /E'.format( temp, os.path.join(dest_path, 'zen3'))) + execute_and_check('XCOPY {} {} /E'.format( + temp, os.path.join(dest_path, 'zen4'))) execute_and_check('XCOPY {} {} /E'.format( temp, os.path.join(dest_path, 'generic'))) remove_folder(temp) @@ -129,6 +134,8 @@ def add_macro_to_cfiles(cfiles, macro): dest_path, 'zen2', 'CMakeLists.txt')) remove_lines_in_file(os.path.join( dest_path, 'zen3', 'CMakeLists.txt')) + remove_lines_in_file(os.path.join( + dest_path, 'zen4', 'CMakeLists.txt')) cfiles_in_generic = execute_and_check('cd {} && dir / s / b / o: gn *.c' .format(os.path.join(dest_path, 'generic'))) @@ -136,20 +143,22 @@ def add_macro_to_cfiles(cfiles, macro): add_macro_to_cfiles(cfiles_in_generic, '\n#define BLIS_CNAME_INFIX _generic\n') cfiles_in_zen = execute_and_check('cd {} && dir / s / b / o: gn *.c' - .format(os.path.join(dest_path, - 'zen'))) + .format(os.path.join(dest_path, 'zen'))) cfiles_in_zen = cfiles_in_zen.split('\r\n') add_macro_to_cfiles(cfiles_in_zen, '\n#define BLIS_CNAME_INFIX _zen\n') cfiles_in_zen2 = execute_and_check('cd {} && dir / s / b / o: gn *.c' - .format(os.path.join(dest_path, - 'zen2'))) + .format(os.path.join(dest_path, 'zen2'))) cfiles_in_zen2 = cfiles_in_zen2.split('\r\n') add_macro_to_cfiles(cfiles_in_zen2, '\n#define BLIS_CNAME_INFIX _zen2\n') cfiles_in_zen3 = execute_and_check('cd {} && dir / s / b / o: gn *.c' - .format(os.path.join(dest_path, - 'zen3'))) + .format(os.path.join(dest_path, 'zen3'))) cfiles_in_zen3 = cfiles_in_zen3.split('\r\n') add_macro_to_cfiles(cfiles_in_zen3, '\n#define BLIS_CNAME_INFIX _zen3\n') + cfiles_in_zen4 = execute_and_check('cd {} && dir / s / b / o: gn *.c' + .format(os.path.join(dest_path, 'zen4'))) + cfiles_in_zen4 = 
cfiles_in_zen4.split('\r\n') + add_macro_to_cfiles(cfiles_in_zen4, + '\n#define BLIS_CNAME_INFIX _zen4\n') diff --git a/build/config.mk.in b/build/config.mk.in index 709e0f543c..a880074e8f 100644 --- a/build/config.mk.in +++ b/build/config.mk.in @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -204,5 +204,7 @@ MK_ENABLE_AOCL_DYNAMIC := @enable_aocl_dynamic@ # BLAS int size MK_BLAS_INT_TYPE_SIZE := @blas_int_type_size@ +MK_IS_ARCH_ZEN := @enable_aocl_zen@ + # end of ifndef CONFIG_MK_INCLUDED conditional block endif diff --git a/config/amdzen/make_defs.mk b/config/amdzen/make_defs.mk index 7697e9ff05..e467461601 100644 --- a/config/amdzen/make_defs.mk +++ b/config/amdzen/make_defs.mk @@ -4,7 +4,7 @@ # An object-based framework for developing high-performance BLAS-like # libraries. # -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -49,16 +49,6 @@ else COPTFLAGS := -O3 endif -# This will add BLIS_CONFIG_EPYC for all framework files -# FIXME: framework files should not have architecture specific -# checks at least at compile time. Once the macro -# is defined it is applicable to every build in the -# Family including any non AMD configuration. -# However, it is still better to define it in makefiles -# instead of headers so we can have slighly more -# control on this. -COPTFLAGS += -DBLIS_CONFIG_EPYC - # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/config/generic/make_defs.mk b/config/generic/make_defs.mk index ee77b6cf0e..4ce2fac758 100644 --- a/config/generic/make_defs.mk +++ b/config/generic/make_defs.mk @@ -79,10 +79,10 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +CRVECFLAGS := $(CKVECFLAGS) else ifeq ($(CC_VENDOR),clang) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +CRVECFLAGS := $(CKVECFLAGS) else CRVECFLAGS := $(CKVECFLAGS) endif diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 7595849866..3fea3ea8f9 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -80,27 +80,41 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1f kernels. 
bli_cntx_set_l1f_kers ( - 6, + 12, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxaxpyf + BLIS_DOTXAXPYF_KER, BLIS_SCOMPLEX, bli_cdotxaxpyf_zen_int_8, + BLIS_DOTXAXPYF_KER, BLIS_DCOMPLEX, bli_zdotxaxpyf_zen_int_8, // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, + BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, + //axpy2v + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, + BLIS_AXPY2V_KER, BLIS_DCOMPLEX, bli_zaxpy2v_zen_int, cntx ); // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 20, + 26, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, #endif + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, + BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, + // axpyv #if 0 BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int, @@ -121,6 +135,8 @@ void bli_cntx_init_zen( cntx_t* cntx ) // dotxv BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, // scalv #if 0 BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int, diff --git a/config/zen/make_defs.mk b/config/zen/make_defs.mk index be1086a1de..b4153fcbfb 100644 --- a/config/zen/make_defs.mk +++ b/config/zen/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -46,25 +46,12 @@ AMD_CONFIG_FILE := amd_config.mk AMD_CONFIG_PATH := $(BASE_SHARE_PATH)/config/zen -include $(AMD_CONFIG_PATH)/$(AMD_CONFIG_FILE) - -# Since we removed BLIS_CONFIG_EPYC from header file, we need to -# add it here at two places, -# CPPROCFLAGS = This will enable it for framework code -# This flag is used when configure is invoked with specific architecture -# CKOPTFLAGS = This will enable it for architecture specific kernels -# This flag is used for kernels assocaited with this architecture -# irrespective of the configuration it is built for. - -CPPROCFLAGS := -DBLIS_CONFIG_EPYC - - ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else COPTFLAGS := -O3 endif - # # --- Enable ETRACE across the library if enabled ETRACE_ENABLE=[0,1] ----------------------- # @@ -81,15 +68,11 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +CRVECFLAGS := $(CKVECFLAGS) else CRVECFLAGS := $(CKVECFLAGS) endif -# Add this after updating variables for reference kernels -# we don't want this defined for them -CKOPTFLAGS += -DBLIS_CONFIG_EPYC - # Store all of the variables here to new variables containing the # configuration name. 
$(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 4f56316a7a..1ecb62ff52 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -3,7 +3,7 @@ An object-based framework for developing high-performance BLAS-like libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -92,28 +92,40 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 6, + 12, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxaxpyf + BLIS_DOTXAXPYF_KER, BLIS_SCOMPLEX, bli_cdotxaxpyf_zen_int_8, + BLIS_DOTXAXPYF_KER, BLIS_DCOMPLEX, bli_zdotxaxpyf_zen_int_8, // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, + BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, + // axpy2v + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, + BLIS_AXPY2V_KER, BLIS_DCOMPLEX, bli_zaxpy2v_zen_int, cntx ); // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 20, + 26, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, #endif - // axpyv + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, + BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, // axpyv BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, @@ -130,6 +142,8 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // dotxv BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, diff --git a/config/zen2/make_defs.mk b/config/zen2/make_defs.mk index ba91f722ab..3b87d35b00 100644 --- a/config/zen2/make_defs.mk +++ b/config/zen2/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -50,15 +50,7 @@ THIS_CONFIG := zen2 # general-purpose/configuration-agnostic flags in common.mk. You # may specify additional flags here as needed. 
-# Since we removed BLIS_CONFIG_EPYC from header file, we need to -# add it here at two places, -# CPPROCFLAGS = This will enable it for framework code -# This flag is used when configure is invoked with specific architecture -# CKOPTFLAGS = This will enable it for architecture specific kernels -# This flag is used for kernels assocaited with this architecture -# irrespective of the configuration it is built for. - -CPPROCFLAGS := -DBLIS_CONFIG_EPYC +CPPROCFLAGS := CMISCFLAGS := CPICFLAGS := CWARNFLAGS := @@ -111,10 +103,6 @@ endif CROPTFLAGS := $(CKOPTFLAGS) CRVECFLAGS := $(CKVECFLAGS) -# Add this after updating variables for reference kernels -# we don't want this defined for them -CKOPTFLAGS += -DBLIS_CONFIG_EPYC - # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index fc7dbcb808..02e264d277 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -92,28 +92,40 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 6, + 12, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxaxpyf + BLIS_DOTXAXPYF_KER, BLIS_SCOMPLEX, bli_cdotxaxpyf_zen_int_8, + BLIS_DOTXAXPYF_KER, BLIS_DCOMPLEX, bli_zdotxaxpyf_zen_int_8, // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, + BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, + // axpy2v + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, + BLIS_AXPY2V_KER, BLIS_DCOMPLEX, bli_zaxpy2v_zen_int, cntx ); // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 20, + 26, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, #endif - // axpyv + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, + BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, // axpyv BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, @@ -130,6 +142,8 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // dotxv BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, diff --git a/config/zen3/make_defs.mk b/config/zen3/make_defs.mk index a479acf8a5..8522a1e956 100644 --- a/config/zen3/make_defs.mk +++ b/config/zen3/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
# # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -50,15 +50,7 @@ THIS_CONFIG := zen3 # general-purpose/configuration-agnostic flags in common.mk. You # may specify additional flags here as needed. -# Since we removed BLIS_CONFIG_EPYC from header file, we need to -# add it here at two places, -# CPPROCFLAGS = This will enable it for framework code -# This flag is used when configure is invoked with specific architecture -# CKOPTFLAGS = This will enable it for architecture specific kernels -# This flag is used for kernels assocaited with this architecture -# irrespective of the configuration it is built for. - -CPPROCFLAGS := -DBLIS_CONFIG_EPYC +CPPROCFLAGS := CMISCFLAGS := CPICFLAGS := CWARNFLAGS := @@ -132,10 +124,6 @@ endif # gcc CROPTFLAGS := $(CKOPTFLAGS) CRVECFLAGS := $(CKVECFLAGS) -# Add this after updating variables for reference kernels -# we don't want this defined for them -CKOPTFLAGS += -DBLIS_CONFIG_EPYC - # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/configure b/configure index bec498d3cf..f49ea19e5e 100755 --- a/configure +++ b/configure @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -3370,6 +3370,7 @@ main() | sed -e "s/@enable_aocl_dynamic@/${enable_aocl_dynamic}/g" \ | sed -e "s/@complex_return@/${complex_return}/g" \ | sed -e "s/@blas_int_type_size@/${blas_int_type_size}/g" \ + | sed -e "s/\@enable_aocl_zen\@/${enable_aocl_zen}/g" \ > "${config_mk_out_path}" diff --git a/frame/1m/packm/bli_packm_blk_var1.c b/frame/1m/packm/bli_packm_blk_var1.c index 87f8df4f7d..c720317b96 100644 --- a/frame/1m/packm/bli_packm_blk_var1.c +++ b/frame/1m/packm/bli_packm_blk_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -159,7 +159,8 @@ void bli_packm_blk_var1 // Treatment of kappa (ie: packing during scaling) depends on // whether we are executing an induced method. - if ( bli_is_nat_packed( schema ) ) + // For dzgemm, scale alpha during packing. + if ( bli_is_nat_packed( schema ) && cntl && bli_cntl_family(cntl) != BLIS_GEMM_MD) { // This branch is for native execution, where we assume that // the micro-kernel will always apply the alpha scalar of the diff --git a/frame/2/gemv/CMakeLists.txt b/frame/2/gemv/CMakeLists.txt index 86be8ddc08..2f75a00f63 100644 --- a/frame/2/gemv/CMakeLists.txt +++ b/frame/2/gemv/CMakeLists.txt @@ -1,11 +1,25 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. 
All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unb_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unb_var2.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unf_var1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unf_var2.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unf_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_var_oapi.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR + ${TARGET_ARCH} STREQUAL zen2 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unf_var1_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unf_var1.c + ) +endif() diff --git a/frame/2/gemv/bli_gemv_unf_var1.c b/frame/2/gemv/bli_gemv_unf_var1.c index 4f0054c1f1..8162613c18 100644 --- a/frame/2/gemv/bli_gemv_unf_var1.c +++ b/frame/2/gemv/bli_gemv_unf_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,7 +34,6 @@ */ #include "blis.h" -#define BLIS_DGEMV_VAR1_FUSE 8 #undef GENTFUNC #define GENTFUNC( ctype, ch, varname ) \ @@ -105,285 +104,5 @@ void PASTEMAC(ch,varname) \ } \ } -#ifdef BLIS_CONFIG_EPYC -void bli_dgemv_unf_var1 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - double* beta, - double* y, inc_t incy, - cntx_t* cntx - ) -{ - - double* A1; - double* y1; - dim_t i; - dim_t f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - //memory pool declarations for packing vector X. - mem_t mem_bufX; - rntm_t rntm; - double *x_buf = x; - inc_t buf_incx = incx; - - bli_init_once(); - - if( cntx == NULL ) cntx = bli_gks_query_cntx(); - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_iter, &n_elem, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(d,type); - double* x1; - double* y1; - PASTECH(d,dotxf_ker_ft) kfp_df; - /* Query the context for the kernel function pointer and fusing factor. 
*/ - kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); - dim_t b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (0 )*incy; - y1 = y + (i )*incy; - - /* y1 = beta * y1 + alpha * A1 * x; */ - kfp_df - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x1, incx, - beta, - y1, incy, - cntx - ); - - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - if (incx > 1) - { - /* - Initialize mem pool buffer to NULL and size to 0 - "buf" and "size" fields are assigned once memory - is allocated from the pool in bli_membrk_acquire_m(). - This will ensure bli_mem_is_alloc() will be passed on - an allocated memory if created or a NULL . - */ - mem_bufX.pblk.buf = NULL; mem_bufX.pblk.block_size = 0; - mem_bufX.buf_type = 0; mem_bufX.size = 0; - mem_bufX.pool = NULL; - - /* In order to get the buffer from pool via rntm access to memory broker - is needed.Following are initializations for rntm */ - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - //calculate the size required for n_elem double elements in vector X. - size_t buffer_size = n_elem * sizeof(double); - - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dgemv_unf_var1(): get mem pool block\n" ); - #endif - - /*acquire a Buffer(n_elem*size(double)) from the memory broker - and save the associated mem_t entry to mem_bufX.*/ - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BUFFER_FOR_B_PANEL, - &mem_bufX); - - /*Continue packing X if buffer memory is allocated*/ - if ((bli_mem_is_alloc( &mem_bufX ))) - { - x_buf = bli_mem_buffer(&mem_bufX); - - //pack X vector with non-unit stride to a temp buffer x_buf with unit stride - for(dim_t x_index = 0 ; x_index < n_elem ; x_index++) - { - *(x_buf + x_index) = *(x + (x_index * incx)) ; - } - // stride of vector x_buf =1 - buf_incx = 1; - } - } - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, BLIS_DGEMV_VAR1_FUSE ); - - A1 = a + (i )*rs_at + (0 )*cs_at; - y1 = y + (i )*incy; - - /* y1 = beta * y1 + alpha * A1 * x; */ - bli_ddotxf_zen_int_8 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x_buf, buf_incx, - beta, - y1, incy, - cntx - ); - - } - if ((incx > 1) && bli_mem_is_alloc( &mem_bufX )) - { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dgemv_unf_var1(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool - bli_membrk_release(&rntm , &mem_bufX); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); -} - -void bli_sgemv_unf_var1 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - float* alpha, - float* a, inc_t rs_a, inc_t cs_a, - float* x, inc_t incx, - float* beta, - float* y, inc_t incy, - cntx_t* cntx - ) -{ - - float* A1; - float* x1; - float* y1; - dim_t i; - dim_t b_fuse, f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - - bli_init_once(); - - if( cntx == NULL ) cntx = bli_gks_query_cntx(); - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_iter, &n_elem, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. 
- // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(s,type); - float* x1 ; - PASTECH(s,dotxf_ker_ft) kfp_df; - /* Query the context for the kernel function pointer and fusing factor. */ - kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (0 )*incy; - y1 = y + (i )*incy; - - /* y1 = beta * y1 + alpha * A1 * x; */ - kfp_df - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x1, incx, - beta, - y1, incy, - cntx - ); - - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - /* Query the context for the kernel function pointer and fusing factor. */ - b_fuse = 8; - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (0 )*incy; - y1 = y + (i )*incy; - - /* y1 = beta * y1 + alpha * A1 * x; */ - bli_sdotxf_zen_int_8 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x1, incx, - beta, - y1, incy, - cntx - ); - - } -} - -INSERT_GENTFUNC_BASIC0_CZ( gemv_unf_var1 ) -#else INSERT_GENTFUNC_BASIC0( gemv_unf_var1 ) -#endif + diff --git a/frame/2/gemv/bli_gemv_unf_var1_amd.c b/frame/2/gemv/bli_gemv_unf_var1_amd.c new file mode 100644 index 0000000000..447f8dbc43 --- /dev/null +++ b/frame/2/gemv/bli_gemv_unf_var1_amd.c @@ -0,0 +1,572 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + trans_t transa, \ + conj_t conjx, \ + dim_t m, \ + dim_t n, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + ctype* beta, \ + ctype* y, inc_t incy, \ + cntx_t* cntx \ + ) \ +{ \ +\ + if(cntx == NULL) cntx = bli_gks_query_cntx(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* A1; \ + ctype* x1; \ + ctype* y1; \ + dim_t i; \ + dim_t b_fuse, f; \ + dim_t n_elem, n_iter; \ + inc_t rs_at, cs_at; \ + conj_t conja; \ +\ + bli_set_dims_incs_with_trans( transa, \ + m, n, rs_a, cs_a, \ + &n_iter, &n_elem, &rs_at, &cs_at ); \ +\ + conja = bli_extract_conj( transa ); \ +\ + PASTECH(ch,dotxf_ker_ft) kfp_df; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); \ +\ + for ( i = 0; i < n_iter; i += f ) \ + { \ + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); \ +\ + A1 = a + (i )*rs_at + (0 )*cs_at; \ + x1 = x + (0 )*incy; \ + y1 = y + (i )*incy; \ +\ + /* y1 = beta * y1 + alpha * A1 * x; */ \ + kfp_df \ + ( \ + conja, \ + conjx, \ + n_elem, \ + f, \ + alpha, \ + A1, cs_at, rs_at, \ + x1, incx, \ + beta, \ + y1, incy, \ + cntx \ + ); \ +\ + } \ +} + +void bli_dgemv_unf_var1 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + + double *A1; + double *y1; + dim_t i; + dim_t f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + //memory pool declarations for packing vector X. + mem_t mem_bufX; + rntm_t rntm; + double *x_buf = x; + inc_t buf_incx = incx; + + bli_init_once(); + + if (cntx == NULL) + cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans(transa, + m, n, rs_a, cs_a, + &n_iter, &n_elem, &rs_at, &cs_at); + + conja = bli_extract_conj(transa); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + const num_t dt = PASTEMAC(d,type); + double* x1; + double* y1; + PASTECH(d,dotxf_ker_ft) kfp_df; + /* Query the context for the kernel function pointer and fusing factor. 
*/ + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + dim_t b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + kfp_df + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x1, incx, + beta, + y1, incy, + cntx + ); + + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + if (incx > 1) + { + /* + Initialize mem pool buffer to NULL and size to 0 + "buf" and "size" fields are assigned once memory + is allocated from the pool in bli_membrk_acquire_m(). + This will ensure bli_mem_is_alloc() will be passed on + an allocated memory if created or a NULL . + */ + + mem_bufX.pblk.buf = NULL; + mem_bufX.pblk.block_size = 0; + mem_bufX.buf_type = 0; + mem_bufX.size = 0; + mem_bufX.pool = NULL; + + /* In order to get the buffer from pool via rntm access to memory broker + is needed.Following are initializations for rntm */ + + bli_rntm_init_from_global(&rntm); + bli_rntm_set_num_threads_only(1, &rntm); + bli_membrk_rntm_set_membrk(&rntm); + + //calculate the size required for n_elem double elements in vector X. + size_t buffer_size = n_elem * sizeof(double); + +#ifdef BLIS_ENABLE_MEM_TRACING + printf("bli_dgemv_unf_var1(): get mem pool block\n"); +#endif + + /*acquire a Buffer(n_elem*size(double)) from the memory broker + and save the associated mem_t entry to mem_bufX.*/ + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BUFFER_FOR_B_PANEL, + &mem_bufX); + + /*Continue packing X if buffer memory is allocated*/ + if ((bli_mem_is_alloc(&mem_bufX))) + { + x_buf = bli_mem_buffer(&mem_bufX); + + //pack X vector with non-unit stride to a temp buffer x_buf with unit stride + for (dim_t x_index = 0; x_index < n_elem; x_index++) + { + *(x_buf + x_index) = *(x + (x_index * incx)); + } + // stride of vector x_buf =1 + buf_incx = 1; + } + } + + dim_t fuse_factor = 8; + dim_t f_temp =0; + + if (n < 4) + { + fuse_factor = 2; + } else if (n < 8) + { + fuse_factor = 4; + } + + for (i = 0; i < n_iter; i += f) + { + f = bli_determine_blocksize_dim_f(i, n_iter, fuse_factor); + + //A = a + i * row_increment + 0 * column_increment + A1 = a + (i)*rs_at; + y1 = y + (i)*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + switch (f) + { + case 8: + + bli_ddotxf_zen_int_8( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x_buf, buf_incx, + beta, + y1, incy, + cntx); + + break; + default: + + if (f < 4) + { + bli_ddotxf_zen_int_2( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x_buf, buf_incx, + beta, + y1, incy, + cntx); + } + else + { + bli_ddotxf_zen_int_4( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x_buf, buf_incx, + beta, + y1, incy, + cntx); + } + } + + f_temp = bli_determine_blocksize_dim_f(i + f, n_iter, fuse_factor); + + if (f_temp < fuse_factor) + { + switch (fuse_factor) + { + case 8: + fuse_factor = 4; + break; + case 4: + fuse_factor = 2; + break; + } + } + } + + if ((incx > 1) && bli_mem_is_alloc(&mem_bufX)) + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf("bli_dgemv_unf_var1(): releasing mem pool block\n"); +#endif + // Return the buffer to pool + bli_membrk_release(&rntm, &mem_bufX); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + +// Returns the optimal number of threads for the given input sizes and fuse factor +void bli_sgemv_var1_smart_threading + ( + dim_t m, dim_t n, + dim_t fuse, 
+ dim_t* nt, dim_t nt_max + ) +{ + // Calculate the amount data processed per iteration + dim_t n_per_loop = n / fuse; + double data_per_iter = n_per_loop* m; + double m_n_ratio = m/n; + + // When the input value is less than the fuse factor + if(n_per_loop < 1) + { + *nt = 1; + return; + } + + // Then there are two cases one + // In m < n the thread spawning is less aggressive when compared to m > n and m = n cases + if(m_n_ratio <= 0.6) + { + // Boundary units is the amount of data processed by each iteration + // This is the variable X in the equation + const double lower_boundary = 50000; + const double higher_boundary = 500000; + + if(data_per_iter < lower_boundary) + { + double coeff_x = 0.9148; + double constant = -1.6252; + // Number of threads = 0.9148 * log(x) - 1.6252 + *nt = ceil(coeff_x * log(data_per_iter) + constant); + } + else if(data_per_iter < higher_boundary) + { + float coeff_x = 10.23; + float constant = -82.332; + // Number of threads = 10.23 * log(x) - 82.332 + *nt = ceil(coeff_x * log(data_per_iter) + constant); + } + else + { + // When the amount of data to be processed is above both of the boundaries + // The number of threads spawned will be equal to the max number of threads set + *nt = nt_max; + } + } + else + { + // Boundary units is the amount of data processed by each iteration + // This is the variable X in the equation + const float lower_boundary = 50000; + const float higher_boundary = 360000; + + if(data_per_iter < lower_boundary) + { + float coeff_x2 = -2E-09; + float coeff_x = 0.0002; + float constant = 1.0234; + // Number of threads = -2E-09*x^2 + 0.0002 * x + 1.0234 + *nt = ceil(coeff_x2 * (data_per_iter * data_per_iter) + coeff_x * data_per_iter + constant); + } + else if(data_per_iter < higher_boundary) + { + float coeff_x = 16.917; + float constant = -164.82; + // Number of threads = 16.917 * log(x) - 164.82 + *nt = ceil(coeff_x * log(data_per_iter) + constant); + } + else + { + // When the amount of data to be processed is above both of the boundaries + // The number of threads spawned will be equal to the max number of threads set + *nt = nt_max; + } + } + + + // When the number of threads calculated is greater than the user provided value + // Choose the user provided value + if(*nt > nt_max ) *nt = nt_max; + if(*nt <=0 ) *nt = 1; +} + +void bli_sgemv_unf_var1 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + float* alpha, + float* a, inc_t rs_a, inc_t cs_a, + float* x, inc_t incx, + float* beta, + float* y, inc_t incy, + cntx_t* cntx + ) +{ + + float* A1; + float* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + bli_init_once(); + + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_iter, &n_elem, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + const num_t dt = PASTEMAC(s,type); + float* x1 ; + PASTECH(s,dotxf_ker_ft) kfp_df; + /* Query the context for the kernel function pointer and fusing factor. 
*/ + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + kfp_df + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x1, incx, + beta, + y1, incy, + cntx + ); + + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + +// If both multithreading and OpenMP are enabled, GEMV will multithread +#if defined(BLIS_ENABLE_MULTITHREADING) && defined(BLIS_ENABLE_OPENMP) + bool is_omp_mt_enabled = TRUE; +#else + bool is_omp_mt_enabled = FALSE; +#endif + + dim_t nt_max; + + rntm_t rnmt_obj; + // Initialize a local runtime with global settings. + bli_rntm_init_from_global( &rnmt_obj ); + + // Query the total number of threads from the rntm_t object. + nt_max = bli_rntm_num_threads( &rnmt_obj ); + + if ( ( nt_max > 1 ) & ( is_omp_mt_enabled == TRUE ) ) + { + b_fuse = 4; + + //Setting the thread count to the maximum number of threads provided + dim_t nt = nt_max; + + // Enable smart threading when AOCL dynamic is enabled + #ifdef AOCL_DYNAMIC + bli_sgemv_var1_smart_threading(n_elem, n_iter, b_fuse, &nt, nt_max); + #endif + + // Pass the input paramaters along with the number of threads to be used + bli_multi_sgemv_4x2 + ( + conja, + conjx, + n_elem, + n_iter, + alpha, + a, cs_at, rs_at, + x, incx, + beta, + y, incy, + cntx, + nt + ); + } + else + { + b_fuse = 8; + + for ( i = 0; i < n_iter; i += f ) + { + float* x1; + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + bli_sdotxf_zen_int_8 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x1, incx, + beta, + y1, incy, + cntx + ); + } + } +} + +INSERT_GENTFUNC_BASIC0_CZ( gemv_unf_var1 ) + diff --git a/frame/2/gemv/bli_gemv_unf_var2.c b/frame/2/gemv/bli_gemv_unf_var2.c index 84a67c3189..d6c21de6df 100644 --- a/frame/2/gemv/bli_gemv_unf_var2.c +++ b/frame/2/gemv/bli_gemv_unf_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -137,752 +137,4 @@ void PASTEMAC(ch,varname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); \ } -#ifdef BLIS_CONFIG_EPYC - -void bli_dgemv_unf_var2 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - double* beta, - double* y, inc_t incy, - cntx_t* cntx - ) -{ - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); - double* A1; - double* x1; - dim_t i; - dim_t f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - //memory pool declarations for packing vector Y. - mem_t mem_bufY; - rntm_t rntm; - double *y_buf = y; - inc_t buf_incy = incy; - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_elem, &n_iter, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. 
- // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(d,type); - double* x1; - double* y1; - /* If beta is zero, use setv. Otherwise, scale by beta. */ - if ( PASTEMAC(d,eq0)( *beta ) ) - { - double* zero = PASTEMAC(d,0); - /* y = 0; */ - PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(d,axpyf_ker_ft) kfp_af; - - /* Query the context for the kernel function pointer and fusing factor. */ - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); - dim_t b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - kfp_af - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - cntx - ); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - /* y = beta * y; */ - /* beta=0 case is hadled by scalv internally */ - - bli_dscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - NULL - ); - - if( bli_deq0( *alpha ) ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) - return; - } - - if (incy > 1) - { - /* - Initialize mem pool buffer to NULL and size to 0 - "buf" and "size" fields are assigned once memory - is allocated from the pool in bli_membrk_acquire_m(). - This will ensure bli_mem_is_alloc() will be passed on - an allocated memory if created or a NULL . - */ - mem_bufY.pblk.buf = NULL; mem_bufY.pblk.block_size = 0; - mem_bufY.buf_type = 0; mem_bufY.size = 0; - mem_bufY.pool = NULL; - - /* In order to get the buffer from pool via rntm access to memory broker - is needed.Following are initializations for rntm */ - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - //calculate the size required for n_elem double elements in vector Y. 
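    // NOTE: the incy > 1 pack/compute/unpack logic being removed here is not
    // dropped from the library; the same scheme reappears in the new
    // frame/2/gemv/bli_gemv_unf_var2_amd.c added later in this patch, where
    // the fuse factor is selected dynamically instead of using the fixed
    // BLIS_DGEMV_VAR2_FUSE.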
- size_t buffer_size = n_elem * sizeof(double); - - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dgemv_unf_var2(): get mem pool block\n" ); - #endif - - /*acquire a Buffer(n_elem*size(double)) from the memory broker - and save the associated mem_t entry to mem_bufY.*/ - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BUFFER_FOR_B_PANEL, - &mem_bufY); - - /*Continue packing Y if buffer memory is allocated*/ - if ((bli_mem_is_alloc( &mem_bufY ))) - { - y_buf = bli_mem_buffer(&mem_bufY); - - //pack Y vector with non-unit stride to a temp buffer y_buf with unit stride - for(dim_t y_index = 0 ; y_index < n_elem ; y_index++) - { - *(y_buf + y_index) = *(y + (y_index * incy)) ; - } - // stride of vector y_buf =1 - buf_incy = 1; - } - } - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, BLIS_DGEMV_VAR2_FUSE ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - - /* y = y + alpha * A1 * x1; */ - bli_daxpyf_zen_int_16x4 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y_buf, buf_incy, - NULL - ); - } - if ((incy > 1) && bli_mem_is_alloc( &mem_bufY )) - { - //store the result from unit strided y_buf to non-unit strided Y - for(dim_t y_index = 0 ; y_index < n_elem ; y_index++) - { - *(y + (y_index * incy)) = *(y_buf + y_index) ; - } - - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dgemv_unf_var2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool - bli_membrk_release(&rntm , &mem_bufY); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); -} - -void bli_sgemv_unf_var2 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - float* alpha, - float* a, inc_t rs_a, inc_t cs_a, - float* x, inc_t incx, - float* beta, - float* y, inc_t incy, - cntx_t* cntx - ) -{ - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); - float* A1; - float* x1; - float* y1; - dim_t i; - dim_t b_fuse, f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_elem, &n_iter, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(s,type); - /* If beta is zero, use setv. Otherwise, scale by beta. */ - if ( PASTEMAC(s,eq0)( *beta ) ) - { - float* zero = PASTEMAC(s,0); - /* y = 0; */ - PASTEMAC2(s,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(s,axpyf_ker_ft) kfp_af; - - /* Query the context for the kernel function pointer and fusing factor. 
*/ - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - kfp_af - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - cntx - ); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - /* y = beta * y; */ - /* beta=0 case is hadled by scalv internally */ - - bli_sscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - NULL - ); - - if( bli_seq0( *alpha ) ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) - return; - } - - /* Query the context for the kernel function pointer and fusing factor. */ - b_fuse = 6; - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - bli_saxpyf_zen_int_6 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - NULL - ); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); -} - - -void bli_zgemv_unf_var2 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - dcomplex* alpha, - dcomplex* a, inc_t rs_a, inc_t cs_a, - dcomplex* x, inc_t incx, - dcomplex* beta, - dcomplex* y, inc_t incy, - cntx_t* cntx - ) -{ - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); - dcomplex* A1; - dcomplex* x1; - dcomplex* y1; - dim_t i; - dim_t b_fuse, f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_elem, &n_iter, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - /* y = beta * y; */ - - /* beta=0 case is hadled by scalv internally */ - /* bli_zscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, - incy, - cntx - );*/ - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(z,type); - /* If beta is zero, use setv. Otherwise, scale by beta. */ - if ( PASTEMAC(z,eq0)( *beta ) ) - { - dcomplex* zero = PASTEMAC(z,0); - /* y = 0; */ - PASTEMAC2(z,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(z,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(z,axpyf_ker_ft) kfp_af; - - /* Query the context for the kernel function pointer and fusing factor. 
*/ - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - kfp_af - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - cntx - ); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - bli_zscalv_ex - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - - if( bli_zeq0( *alpha ) ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - // for non-unit incx, incy and rs_at and conjugate will be added in the next patch - if( (incx == 1 && incy == 1 && rs_at == 1 ) && - !bli_is_conj(conja) && !bli_is_conj(conjx) && !bli_is_trans(transa)) - { - // This gemv code deals with the followint conditions only - // 1. incx, incy, and row stride equal to one - // 2. Non conjugate A matrix and X vector - // 3. No Transpose for A Martix - // Rest is taken care by the else part (axpyf implementation) - bli_zgemv_zen_int_4x4 - ( - conja, - conjx, - m, - n, - alpha, - a, rs_at, cs_at, - x, incx, - beta, - y, incy, - NULL - ); - } - else - { - /* fusing factor */ - b_fuse = 4; - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - bli_zaxpyf_zen_int_4 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - NULL - ); - } - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); -} - -void bli_cgemv_unf_var2 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - scomplex* alpha, - scomplex* a, inc_t rs_a, inc_t cs_a, - scomplex* x, inc_t incx, - scomplex* beta, - scomplex* y, inc_t incy, - cntx_t* cntx - ) -{ - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); - scomplex* A1; - scomplex* x1; - scomplex* y1; - dim_t i; - dim_t b_fuse, f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_elem, &n_iter, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - /* y = beta * y; */ - /* beta=0 case is hadled by scalv internally */ - /*bli_cscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, - incy, - cntx - );*/ - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(c,type); - /* If beta is zero, use setv. Otherwise, scale by beta. 
*/ - if ( PASTEMAC(c,eq0)( *beta ) ) - { - scomplex* zero = PASTEMAC(c,0); - /* y = 0; */ - PASTEMAC2(c,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(c,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(c,axpyf_ker_ft) kfp_af; - - /* Query the context for the kernel function pointer and fusing factor. */ - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - kfp_af - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - cntx - ); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - bli_cscalv_ex - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - - - - if( bli_ceq0( *alpha ) ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) - return; - } - - // for non-unit incx, incy and rs_at and conjugate will be added in the next patch - if( ( (incx == 1) && (incy == 1) && (rs_at == 1) ) && - !bli_is_conj(conja) && !bli_is_conj(conjx) && - !bli_is_trans(transa)) - { - // This gemv code deals with the followint conditions only - // 1. incx, incy, and row stride equal to one - // 2. Non conjugate A matrix and X vector - // 3. No Transpose for A Martix - // Rest is taken care by the else part (axpyf implementation) - bli_cgemv_zen_int_4x4 - ( - conja, - conjx, - m, - n, - alpha, - a, rs_at, cs_at, - x, incx, - beta, - y, incy, - NULL - ); - } - else - { - /* fusing factor. */ - b_fuse = 4; - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - bli_caxpyf_zen_int_4 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - NULL - ); - } - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); -} - - -#else INSERT_GENTFUNC_BASIC0( gemv_unf_var2 ) -#endif diff --git a/frame/2/gemv/bli_gemv_unf_var2_amd.c b/frame/2/gemv/bli_gemv_unf_var2_amd.c new file mode 100644 index 0000000000..831d906ca4 --- /dev/null +++ b/frame/2/gemv/bli_gemv_unf_var2_amd.c @@ -0,0 +1,939 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#define BLIS_DGEMV_VAR2_FUSE 4 + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + trans_t transa, \ + conj_t conjx, \ + dim_t m, \ + dim_t n, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + ctype* beta, \ + ctype* y, inc_t incy, \ + cntx_t* cntx \ + ) \ +{ \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); \ +\ + bli_init_once(); \ +\ + if(cntx == NULL) cntx = bli_gks_query_cntx(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* zero = PASTEMAC(ch,0); \ + ctype* A1; \ + ctype* x1; \ + ctype* y1; \ + dim_t i; \ + dim_t b_fuse, f; \ + dim_t n_elem, n_iter; \ + inc_t rs_at, cs_at; \ + conj_t conja; \ +\ + bli_set_dims_incs_with_trans( transa, \ + m, n, rs_a, cs_a, \ + &n_elem, &n_iter, &rs_at, &cs_at ); \ +\ + conja = bli_extract_conj( transa ); \ +\ + /* If beta is zero, use setv. Otherwise, scale by beta. */ \ + if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + /* y = 0; */ \ + PASTEMAC2(ch,setv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n_elem, \ + zero, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ + else \ + { \ + /* y = beta * y; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n_elem, \ + beta, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ +\ + PASTECH(ch,axpyf_ker_ft) kfp_af; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); \ +\ + for ( i = 0; i < n_iter; i += f ) \ + { \ + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); \ +\ + A1 = a + (0 )*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + y1 = y + (0 )*incy; \ +\ + /* y = y + alpha * A1 * x1; */ \ + kfp_af \ + ( \ + conja, \ + conjx, \ + n_elem, \ + f, \ + alpha, \ + A1, rs_at, cs_at, \ + x1, incx, \ + y1, incy, \ + cntx \ + ); \ + } \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); \ +} + +void bli_dgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + double* A1; + double* x1; + dim_t i; + dim_t f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + //memory pool declarations for packing vector Y. + mem_t mem_bufY; + rntm_t rntm; + double *y_buf = y; + inc_t buf_incy = incy; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here. 
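    // In practice this means the BLAS/CBLAS compatibility layer may reach
    // this expert variant directly, with cntx == NULL and possibly before
    // the library has been initialized; bli_init_once() and the
    // bli_gks_query_cntx() fallback below cover both cases. A minimal
    // direct-call sketch (illustrative only; assumes a column-major A with
    // leading dimension lda and unit-stride x and y):
    //
    //     bli_dgemv_unf_var2( BLIS_NO_TRANSPOSE, BLIS_NO_CONJUGATE,
    //                         m, n, &alpha, a, 1, lda,
    //                         x, 1, &beta, y, 1, NULL );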
+ bli_init_once(); + if(cntx == NULL) cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + const num_t dt = PASTEMAC(d,type); + double* x1; + double* y1; + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + double* zero = PASTEMAC(d,0); + /* y = 0; */ + PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(d,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. */ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + dim_t b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + /* beta=0 case is hadled by scalv internally */ + + bli_dscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx + ); + + if( bli_deq0( *alpha ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + return; + } + + if (incy > 1) + { + /* + Initialize mem pool buffer to NULL and size to 0 + "buf" and "size" fields are assigned once memory + is allocated from the pool in bli_membrk_acquire_m(). + This will ensure bli_mem_is_alloc() will be passed on + an allocated memory if created or a NULL . + */ + mem_bufY.pblk.buf = NULL; mem_bufY.pblk.block_size = 0; + mem_bufY.buf_type = 0; mem_bufY.size = 0; + mem_bufY.pool = NULL; + + /* In order to get the buffer from pool via rntm access to memory broker + is needed.Following are initializations for rntm */ + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + //calculate the size required for n_elem double elements in vector Y. 
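        // The strided-y path below follows a pack/compute/unpack scheme:
        //   1. size and acquire a scratch buffer from the memory pool,
        //   2. copy y into it with unit stride,
        //   3. run the fused axpyf kernels on the contiguous copy,
        //   4. copy the result back to the strided y and release the buffer.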
+ size_t buffer_size = n_elem * sizeof(double); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dgemv_unf_var2(): get mem pool block\n" ); + #endif + + /*acquire a Buffer(n_elem*size(double)) from the memory broker + and save the associated mem_t entry to mem_bufY.*/ + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BUFFER_FOR_B_PANEL, + &mem_bufY); + + /*Continue packing Y if buffer memory is allocated*/ + if ((bli_mem_is_alloc( &mem_bufY ))) + { + y_buf = bli_mem_buffer(&mem_bufY); + + //pack Y vector with non-unit stride to a temp buffer y_buf with unit stride + for(dim_t y_index = 0 ; y_index < n_elem ; y_index++) + { + *(y_buf + y_index) = *(y + (y_index * incy)) ; + } + // stride of vector y_buf =1 + buf_incy = 1; + } + } + + dim_t fuse_factor = 8; + dim_t f_temp = 0; + + // Change the fuse factor based on + // Input size and available kernels + // This ensures that fusing is possible when the number of + // left over colums is less (better problem decomposition) + if (n < 5) fuse_factor = 4; + else if (n < 8) fuse_factor = 5; + + for (i = 0; i < n_iter; i += f) + { + f = bli_determine_blocksize_dim_f(i, n_iter, fuse_factor); + + A1 = a + (i)*cs_at; + x1 = x + (i)*incx; + + // Pick kernel based on problem size + switch (f) + { + case 8: + + bli_daxpyf_zen_int_8( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y_buf, buf_incy, + cntx); + + break; + default: + + if (f < 5) + { + bli_daxpyf_zen_int_16x4( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y_buf, buf_incy, + cntx); + } + else + { + bli_daxpyf_zen_int_5( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y_buf, buf_incy, + cntx); + } + } + + // Calculate the next problem size + f_temp = bli_determine_blocksize_dim_f(i + f, n_iter, fuse_factor); + + // Change fuse factor based on the next problem size + if (f_temp < fuse_factor) + { + if (f_temp < 5) + { + fuse_factor = 4; + } + else + { + fuse_factor = 5; + } + } + } + + if ((incy > 1) && bli_mem_is_alloc( &mem_bufY )) + { + //store the result from unit strided y_buf to non-unit strided Y + for(dim_t y_index = 0 ; y_index < n_elem ; y_index++) + { + *(y + (y_index * incy)) = *(y_buf + y_index) ; + } + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dgemv_unf_var2(): releasing mem pool block\n" ); + #endif + // Return the buffer to pool + bli_membrk_release(&rntm , &mem_bufY); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + +void bli_sgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + float* alpha, + float* a, inc_t rs_a, inc_t cs_a, + float* x, inc_t incx, + float* beta, + float* y, inc_t incy, + cntx_t* cntx + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + float* A1; + float* x1; + float* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here. + bli_init_once(); + if(cntx == NULL) cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + const num_t dt = PASTEMAC(s,type); + /* If beta is zero, use setv. Otherwise, scale by beta. 
*/ + if ( PASTEMAC(s,eq0)( *beta ) ) + { + float* zero = PASTEMAC(s,0); + /* y = 0; */ + PASTEMAC2(s,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(s,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. */ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + /* beta=0 case is hadled by scalv internally */ + bli_sscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx + ); + + if( bli_seq0( *alpha ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + return; + } + + /* Query the context for the kernel function pointer and fusing factor. */ + b_fuse = 6; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + bli_saxpyf_zen_int_6 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + + +void bli_zgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + dcomplex* alpha, + dcomplex* a, inc_t rs_a, inc_t cs_a, + dcomplex* x, inc_t incx, + dcomplex* beta, + dcomplex* y, inc_t incy, + cntx_t* cntx + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + dcomplex* A1; + dcomplex* x1; + dcomplex* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here. + bli_init_once(); + if(cntx == NULL) cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + + /* beta=0 case is hadled by scalv internally */ + /* bli_zscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, + incy, + cntx + );*/ + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + const num_t dt = PASTEMAC(z,type); + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(z,eq0)( *beta ) ) + { + dcomplex* zero = PASTEMAC(z,0); + /* y = 0; */ + PASTEMAC2(z,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(z,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(z,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. 
*/ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + bli_zscalv_ex + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + + if( bli_zeq0( *alpha ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + // for non-unit incx, incy and rs_at and conjugate will be added in the next patch + if( (incx == 1 && incy == 1 && rs_at == 1 ) && + !bli_is_conj(conja) && !bli_is_conj(conjx) && !bli_is_trans(transa)) + { + // This gemv code deals with the followint conditions only + // 1. incx, incy, and row stride equal to one + // 2. Non conjugate A matrix and X vector + // 3. No Transpose for A Martix + // Rest is taken care by the else part (axpyf implementation) + bli_zgemv_zen_int_4x4 + ( + conja, + conjx, + m, + n, + alpha, + a, rs_at, cs_at, + x, incx, + beta, + y, incy, + cntx + ); + } + else + { + /* fusing factor */ + b_fuse = 4; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + bli_zaxpyf_zen_int_4 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + +void bli_cgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + scomplex* alpha, + scomplex* a, inc_t rs_a, inc_t cs_a, + scomplex* x, inc_t incx, + scomplex* beta, + scomplex* y, inc_t incy, + cntx_t* cntx + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + scomplex* A1; + scomplex* x1; + scomplex* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here. + bli_init_once(); + if(cntx == NULL) cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + /* beta=0 case is hadled by scalv internally */ + /*bli_cscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, + incy, + cntx + );*/ + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + const num_t dt = PASTEMAC(c,type); + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(c,eq0)( *beta ) ) + { + scomplex* zero = PASTEMAC(c,0); + /* y = 0; */ + PASTEMAC2(c,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(c,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(c,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. 
*/ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + bli_cscalv_ex + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + + + + if( bli_ceq0( *alpha ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + return; + } + + // for non-unit incx, incy and rs_at and conjugate will be added in the next patch + if( ( (incx == 1) && (incy == 1) && (rs_at == 1) ) && + !bli_is_conj(conja) && !bli_is_conj(conjx) && + !bli_is_trans(transa)) + { + // This gemv code deals with the followint conditions only + // 1. incx, incy, and row stride equal to one + // 2. Non conjugate A matrix and X vector + // 3. No Transpose for A Martix + // Rest is taken care by the else part (axpyf implementation) + bli_cgemv_zen_int_4x4 + ( + conja, + conjx, + m, + n, + alpha, + a, rs_at, cs_at, + x, incx, + beta, + y, incy, + cntx + ); + } + else + { + /* fusing factor. */ + b_fuse = 4; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + bli_caxpyf_zen_int_4 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + + + diff --git a/frame/2/hemv/CMakeLists.txt b/frame/2/hemv/CMakeLists.txt index 677c253271..34820c3762 100644 --- a/frame/2/hemv/CMakeLists.txt +++ b/frame/2/hemv/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -6,10 +6,25 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unb_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unb_var3.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unb_var4.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var1a.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var3.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var3a.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_var_oapi.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR + ${TARGET_ARCH} STREQUAL zen2 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var1_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var3_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var1.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var3.c + ) +endif() \ No newline at end of file diff --git a/frame/2/hemv/bli_hemv_unf_var1.c b/frame/2/hemv/bli_hemv_unf_var1.c index d36dc00988..e3229543c0 100644 --- a/frame/2/hemv/bli_hemv_unf_var1.c +++ b/frame/2/hemv/bli_hemv_unf_var1.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021-22, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/2/hemv/bli_hemv_unf_var1_amd.c b/frame/2/hemv/bli_hemv_unf_var1_amd.c new file mode 100644 index 0000000000..6532323d11 --- /dev/null +++ b/frame/2/hemv/bli_hemv_unf_var1_amd.c @@ -0,0 +1,418 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021-22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conja, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + ctype* beta, \ + ctype* y, inc_t incy, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* one = PASTEMAC(ch,1); \ + ctype* zero = PASTEMAC(ch,0); \ + ctype* A10; \ + ctype* A11; \ + ctype* a10t; \ + ctype* alpha11; \ + ctype* a21; \ + ctype* x0; \ + ctype* x1; \ + ctype* chi11; \ + ctype* y0; \ + ctype* y1; \ + ctype* y01; \ + ctype* psi11; \ + ctype* y21; \ + ctype conjx_chi11; \ + ctype alpha_chi11; \ + ctype alpha11_temp; \ + dim_t i, k, j; \ + dim_t b_fuse, f; \ + dim_t n_behind; \ + dim_t f_ahead, f_behind; \ + inc_t rs_at, cs_at; \ + conj_t conj0, conj1; \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_at = rs_a; \ + cs_at = cs_a; \ +\ + conj0 = conja; \ + conj1 = bli_apply_conj( conjh, conja ); \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_at = cs_a; \ + cs_at = rs_a; \ +\ + conj0 = bli_apply_conj( conjh, conja ); \ + conj1 = conja; \ + } \ +\ + /* If beta is zero, use setv. Otherwise, scale by beta. 
*/ \ + if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + /* y = 0; */ \ + PASTEMAC2(ch,setv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + zero, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ + else \ + { \ + /* y = beta * y; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + beta, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ +\ + PASTECH(ch,dotxaxpyf_ker_ft) kfp_xf; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_xf = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); \ +\ + for ( i = 0; i < m; i += f ) \ + { \ + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); \ + n_behind = i; \ + A10 = a + (i )*rs_at + (0 )*cs_at; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + x0 = x + (0 )*incx; \ + x1 = x + (i )*incx; \ + y0 = y + (0 )*incy; \ + y1 = y + (i )*incy; \ +\ + /* y1 = y1 + alpha * A10 * x0; (dotxf) */ \ + /* y0 = y0 + alpha * A10' * x1; (axpyf) */ \ + kfp_xf \ + ( \ + conj0, \ + conj1, \ + conjx, \ + conjx, \ + n_behind, \ + f, \ + alpha, \ + A10, cs_at, rs_at, \ + x0, incx, \ + x1, incx, \ + one, \ + y1, incy, \ + y0, incy, \ + cntx \ + ); \ +\ + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ \ + for ( k = 0; k < f; ++k ) \ + { \ + f_behind = k; \ + f_ahead = f - k - 1; \ + a10t = A11 + (k )*rs_at + (0 )*cs_at; \ + alpha11 = A11 + (k )*rs_at + (k )*cs_at; \ + a21 = A11 + (k+1)*rs_at + (k )*cs_at; \ + chi11 = x1 + (k )*incx; \ + y01 = y1 + (0 )*incy; \ + psi11 = y1 + (k )*incy; \ + y21 = y1 + (k+1)*incy; \ +\ + /* y01 = y01 + alpha * a10t' * chi11; */ \ + PASTEMAC(ch,copycjs)( conjx, *chi11, conjx_chi11 ); \ + PASTEMAC(ch,scal2s)( *alpha, conjx_chi11, alpha_chi11 ); \ + if ( bli_is_conj( conj1 ) ) \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,axpyjs)( alpha_chi11, *(a10t + j*cs_at), *(y01 + j*incy) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,axpys)( alpha_chi11, *(a10t + j*cs_at), *(y01 + j*incy) ); \ + } \ +\ + /* For hemv, explicitly set the imaginary component of alpha11 to + zero. 
*/ \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_temp ); \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( alpha11_temp ); \ +\ + /* psi11 = psi11 + alpha * alpha11 * chi11; */ \ + PASTEMAC(ch,axpys)( alpha_chi11, alpha11_temp, *psi11 ); \ +\ + /* y21 = y21 + alpha * a21 * chi11; */ \ + if ( bli_is_conj( conj0 ) ) \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpyjs)( alpha_chi11, *(a21 + j*rs_at), *(y21 + j*incy) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpys)( alpha_chi11, *(a21 + j*rs_at), *(y21 + j*incy) ); \ + } \ + } \ + } \ +} + +void bli_post_hemv_8x8 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t cs_a, + dim_t rs_a + ); + +void bli_dhemv_unf_var1 + ( + uplo_t uplo, + conj_t conja, + conj_t conjx, + conj_t conjh, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* one = PASTEMAC(d,1); + double* zero = PASTEMAC(d,0); + double* A10; + double* A11; + double* a10t; + double* alpha11; + double* a21; + double* x0; + double* x1; + double* chi11; + double* y0; + double* y1; + double* y01; + double* psi11; + double* y21; + double conjx_chi11; + double alpha_chi11; + double alpha11_temp; + dim_t i, k, j; + dim_t b_fuse, f; + dim_t n_behind; + dim_t f_ahead, f_behind; + inc_t rs_at, cs_at; + conj_t conj0 = 0, conj1 = 0; + + /* The algorithm will be expressed in terms of the lower triangular + * case;the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. */ + if ( bli_is_lower( uplo ) ) + { + rs_at = rs_a; + cs_at = cs_a; + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + /* y = 0; */ + PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(d,dotxaxpyf_ker_ft) kfp_dotxaxpyf_ker; + + /* Query the context for the kernel function pointer and fusing + * factor. */ + /* Assign kernel function pointer and fusing factor. */ + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. 
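Reviewer note on the dispatch below: whichever kernel is selected (the hand-written AVX kernel bli_ddotxaxpyf_zen_int_8 with a fusing factor of 8, or the kernel queried from the context), it implements the fused dotxaxpyf operation used by these blocked hemv variants. Per the comments in the loop that follows ("y1 = y1 + alpha * A10 * x0; (dotxf)" and "y0 = y0 + alpha * A10' * x1; (axpyf)"), the kernel produces a dotxf-style result and an axpyf-style result in one sweep over a panel of A. A minimal scalar sketch of that operation, assuming unit strides, real double data, and no conjugation (the name ref_dotxaxpyf_d is illustrative only, not a BLIS symbol):

    /* Fused model on an m-by-b column-major panel A (leading dimension lda):
         y := beta*y + alpha * A^T * w   (dotxf half, y has length b)
         z := z      + alpha * A   * x   (axpyf half, z has length m)
       Both results come from a single pass over A. */
    static void ref_dotxaxpyf_d( int m, int b, double alpha,
                                 const double* a, int lda,
                                 const double* w, const double* x,
                                 double beta, double* y, double* z )
    {
        for ( int j = 0; j < b; ++j )
        {
            double rho = 0.0;                          /* dot of column j with w */
            for ( int i = 0; i < m; ++i )
            {
                rho    += a[ i + j*lda ] * w[ i ];            /* dotxf contribution */
                z[ i ] += alpha * a[ i + j*lda ] * x[ j ];    /* axpyf contribution */
            }
            y[ j ] = beta * y[ j ] + alpha * rho;
        }
    }

Fusing the two halves means each element of the A panel is read once rather than twice, which is the motivation for organizing the blocked hemv variants around this kernel.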
+ if (bli_cpuid_is_avx_supported() == TRUE) + { + kfp_dotxaxpyf_ker = bli_ddotxaxpyf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + kfp_dotxaxpyf_ker = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx); + b_fuse = + bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); + } + + for ( i = 0; i < m; i += f ) + { + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); + n_behind = i; + A10 = a + (i )*rs_at + (0 )*cs_at; + A11 = a + (i )*rs_at + (i )*cs_at; + x0 = x + (0 )*incx; + x1 = x + (i )*incx; + y0 = y + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = y1 + alpha * A10 * x0; (dotxf) */ + /* y0 = y0 + alpha * A10' * x1; (axpyf) */ + kfp_dotxaxpyf_ker + ( + conj0, + conj1, + conjx, + conjx, + n_behind, + f, + alpha, + A10, cs_at, rs_at, + x0, incx, + x1, incx, + one, + y1, incy, + y0, incy, + cntx + ); + + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ + if((f == 8) && (incx == 1) && (incy == 1) && (cs_at == 1)) + { + /*this helper function handles unit stride only*/ + bli_post_hemv_8x8(A11, x1, y1, alpha, rs_at, cs_at); + } + else + { + for ( k = 0; k < f; ++k ) + { + f_behind = k; + f_ahead = f - k - 1; + a10t = A11 + (k )*rs_at + (0 )*cs_at; + alpha11 = A11 + (k )*rs_at + (k )*cs_at; + a21 = A11 + (k+1)*rs_at + (k )*cs_at; + chi11 = x1 + (k )*incx; + y01 = y1 + (0 )*incy; + psi11 = y1 + (k )*incy; + y21 = y1 + (k+1)*incy; + + /* y01 = y01 + alpha * a10t' * chi11; */ + PASTEMAC(d,copycjs)( conjx, *chi11, + conjx_chi11 ); + PASTEMAC(d,scal2s)( *alpha, conjx_chi11, + alpha_chi11 ); + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,axpys)( alpha_chi11, + *(a10t + j*cs_at), + *(y01 + j*incy) ); + + PASTEMAC(d,copycjs)( conja, *alpha11, + alpha11_temp ); + + /* psi11 = psi11 + alpha * alpha11 * chi11; */ + PASTEMAC(d,axpys)( alpha_chi11, alpha11_temp, + *psi11 ); + + /* y21 = y21 + alpha * a21 * chi11; */ + for ( j = 0; j < f_ahead; ++j ) + { + PASTEMAC(d,axpys)( alpha_chi11, + *(a21 + j*rs_at), + *(y21 + j*incy) ); + } + } + } + } +} +GENTFUNC(float, s, hemv_unf_var1) +GENTFUNC(scomplex, c, hemv_unf_var1) +GENTFUNC(dcomplex, z, hemv_unf_var1) + + diff --git a/frame/2/hemv/bli_hemv_unf_var3.c b/frame/2/hemv/bli_hemv_unf_var3.c index d8db9bc78a..b8e26cbcb6 100644 --- a/frame/2/hemv/bli_hemv_unf_var3.c +++ b/frame/2/hemv/bli_hemv_unf_var3.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -217,3 +218,4 @@ void PASTEMAC(ch,varname) \ INSERT_GENTFUNC_BASIC0( hemv_unf_var3 ) + diff --git a/frame/2/hemv/bli_hemv_unf_var3_amd.c b/frame/2/hemv/bli_hemv_unf_var3_amd.c new file mode 100644 index 0000000000..34d40cf5cc --- /dev/null +++ b/frame/2/hemv/bli_hemv_unf_var3_amd.c @@ -0,0 +1,420 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conja, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + ctype* beta, \ + ctype* y, inc_t incy, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* one = PASTEMAC(ch,1); \ + ctype* zero = PASTEMAC(ch,0); \ + ctype* A11; \ + ctype* A21; \ + ctype* a10t; \ + ctype* alpha11; \ + ctype* a21; \ + ctype* x1; \ + ctype* x2; \ + ctype* chi11; \ + ctype* y1; \ + ctype* y2; \ + ctype* y01; \ + ctype* psi11; \ + ctype* y21; \ + ctype conjx_chi11; \ + ctype alpha_chi11; \ + ctype alpha11_temp; \ + dim_t i, k, j; \ + dim_t b_fuse, f; \ + dim_t n_ahead; \ + dim_t f_ahead, f_behind; \ + inc_t rs_at, cs_at; \ + conj_t conj0, conj1; \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_at = rs_a; \ + cs_at = cs_a; \ +\ + conj0 = bli_apply_conj( conjh, conja ); \ + conj1 = conja; \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_at = cs_a; \ + cs_at = rs_a; \ +\ + conj0 = conja; \ + conj1 = bli_apply_conj( conjh, conja ); \ + } \ +\ + /* If beta is zero, use setv. Otherwise, scale by beta. */ \ + if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + /* y = 0; */ \ + PASTEMAC2(ch,setv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + zero, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ + else \ + { \ + /* y = beta * y; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + beta, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ +\ + PASTECH(ch,dotxaxpyf_ker_ft) kfp_xf; \ +\ + /* Query the context for the kernel function pointer and fusing factor. 
*/ \ + kfp_xf = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); \ +\ + for ( i = 0; i < m; i += f ) \ + { \ + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); \ + n_ahead = m - i - f; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A21 = a + (i+f)*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + x2 = x + (i+f)*incx; \ + y1 = y + (i )*incy; \ + y2 = y + (i+f)*incy; \ +\ + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ \ + for ( k = 0; k < f; ++k ) \ + { \ + f_behind = k; \ + f_ahead = f - k - 1; \ + a10t = A11 + (k )*rs_at + (0 )*cs_at; \ + alpha11 = A11 + (k )*rs_at + (k )*cs_at; \ + a21 = A11 + (k+1)*rs_at + (k )*cs_at; \ + chi11 = x1 + (k )*incx; \ + y01 = y1 + (0 )*incy; \ + psi11 = y1 + (k )*incy; \ + y21 = y1 + (k+1)*incy; \ +\ + /* y01 = y01 + alpha * a10t' * chi11; */ \ + PASTEMAC(ch,copycjs)( conjx, *chi11, conjx_chi11 ); \ + PASTEMAC(ch,scal2s)( *alpha, conjx_chi11, alpha_chi11 ); \ + if ( bli_is_conj( conj0 ) ) \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,axpyjs)( alpha_chi11, *(a10t + j*cs_at), *(y01 + j*incy) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,axpys)( alpha_chi11, *(a10t + j*cs_at), *(y01 + j*incy) ); \ + } \ +\ + /* For hemv, explicitly set the imaginary component of alpha11 to + zero. */ \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_temp ); \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( alpha11_temp ); \ +\ + /* psi11 = psi11 + alpha * alpha11 * chi11; */ \ + PASTEMAC(ch,axpys)( alpha_chi11, alpha11_temp, *psi11 ); \ +\ + /* y21 = y21 + alpha * a21 * chi11; */ \ + if ( bli_is_conj( conj1 ) ) \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpyjs)( alpha_chi11, *(a21 + j*rs_at), *(y21 + j*incy) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpys)( alpha_chi11, *(a21 + j*rs_at), *(y21 + j*incy) ); \ + } \ + } \ +\ + /* y1 = y1 + alpha * A21' * x2; (dotxf) */ \ + /* y2 = y2 + alpha * A21 * x1; (axpyf) */ \ + kfp_xf \ + ( \ + conj0, \ + conj1, \ + conjx, \ + conjx, \ + n_ahead, \ + f, \ + alpha, \ + A21, rs_at, cs_at, \ + x2, incx, \ + x1, incx, \ + one, \ + y1, incy, \ + y2, incy, \ + cntx \ + ); \ + } \ +} + +void bli_pre_hemv_8x8 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t cs_a, + dim_t rs_a + ); + +void bli_dhemv_unf_var3 + ( + uplo_t uplo, + conj_t conja, + conj_t conjx, + conj_t conjh, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* one = PASTEMAC(d,1); + double* zero = PASTEMAC(d,0); + double* A11; + double* A21; + double* a10t; + double* alpha11; + double* a21; + double* x1; + double* x2; + double* chi11; + double* y1; + double* y2; + double* y01; + double* psi11; + double* y21; + double conjx_chi11; + double alpha_chi11; + double alpha11_temp; + dim_t i, k, j; + dim_t b_fuse, f; + dim_t n_ahead; + dim_t f_ahead, f_behind; + inc_t rs_at, cs_at; + conj_t conj0 = 0, conj1 = 0; + + /* The algorithm will be expressed in terms of the lower triangular + * case; the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. */ + if ( bli_is_lower( uplo ) ) + { + rs_at = rs_a; + cs_at = cs_a; + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. 
*/ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + /* y = 0; */ + PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(d,dotxaxpyf_ker_ft) kfp_dotxaxpyf_ker; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + kfp_dotxaxpyf_ker = bli_ddotxaxpyf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + kfp_dotxaxpyf_ker = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx); + b_fuse = + bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); + } + + for ( i = 0; i < m; i += f ) + { + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); + n_ahead = m - i - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + y1 = y + (i )*incy; + y2 = y + (i+f)*incy; + + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ + if((f == 8) && (incx == 1) && (incy == 1) && (rs_at == 1)) + { + /*this helper function handles unit stride only*/ + bli_pre_hemv_8x8(A11, x1, y1, alpha, cs_at, rs_at); + } + else + { + for ( k = 0; k < f; ++k ) + { + f_behind = k; + f_ahead = f - k - 1; + a10t = A11 + (k )*rs_at + (0 )*cs_at; + alpha11 = A11 + (k )*rs_at + (k )*cs_at; + a21 = A11 + (k+1)*rs_at + (k )*cs_at; + chi11 = x1 + (k )*incx; + y01 = y1 + (0 )*incy; + psi11 = y1 + (k )*incy; + y21 = y1 + (k+1)*incy; + + /* y01 = y01 + alpha * a10t' * chi11; */ + PASTEMAC(d,copycjs)( conjx, + *chi11, conjx_chi11 ); + PASTEMAC(d,scal2s)( *alpha, conjx_chi11, + alpha_chi11 ); + { + for ( j = 0; j < f_behind; ++j ) + { + PASTEMAC(d,axpys) + ( alpha_chi11, + *(a10t + j*cs_at), + *(y01 + j*incy) ); + } + } + + PASTEMAC(d,copycjs)( conja, *alpha11, + alpha11_temp ); + + /* psi11 = psi11 + alpha * alpha11 * chi11; */ + PASTEMAC(d,axpys)( alpha_chi11, alpha11_temp, + *psi11 ); + + /* y21 = y21 + alpha * a21 * chi11; */ + for ( j = 0; j < f_ahead; ++j ) + { + PASTEMAC(d,axpys)( alpha_chi11, + *(a21 + j*rs_at), + *(y21 + j*incy) ); + } + } + } + + /* y1 = y1 + alpha * A21' * x2; (dotxf) */ + /* y2 = y2 + alpha * A21 * x1; (axpyf) */ + kfp_dotxaxpyf_ker + ( + conj0, + conj1, + conjx, + conjx, + n_ahead, + f, + alpha, + A21, rs_at, cs_at, + x2, incx, + x1, incx, + one, + y1, incy, + y2, incy, + cntx + ); + } +} + +GENTFUNC(float, s, hemv_unf_var3) +GENTFUNC(scomplex, c, hemv_unf_var3) +GENTFUNC(dcomplex, z, hemv_unf_var3) + + diff --git a/frame/2/her2/CMakeLists.txt b/frame/2/her2/CMakeLists.txt index 1b4c264443..83629df8f5 100644 --- a/frame/2/her2/CMakeLists.txt +++ b/frame/2/her2/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -6,8 +6,23 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unb_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unb_var3.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unb_var4.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var4.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_var_oapi.c ) +# Select AMD specific sources for AMD configurations. 
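The unf_var1/unf_var4 sources chosen by the check below implement the her2 (Hermitian rank-2) update C := C + alpha*x*y^H + conj(alpha)*y*x^H, touching only the stored triangle. For orientation, a minimal scalar sketch of the real, lower-triangular case (dsyr2-style), assuming column-major storage with unit row stride; the name ref_dsyr2_lower is a hypothetical reference routine, not part of BLIS:

    /* C := C + alpha*x*y^T + alpha*y*x^T, lower triangle only,
       with C stored column-major and leading dimension ldc. */
    static void ref_dsyr2_lower( int m, double alpha,
                                 const double* x, const double* y,
                                 double* c, int ldc )
    {
        for ( int j = 0; j < m; ++j )
            for ( int i = j; i < m; ++i )   /* update the lower triangle only */
                c[ i + j*ldc ] += alpha * ( x[ i ] * y[ j ] + y[ i ] * x[ j ] );
    }

The optimized variants in the AMD-specific files restructure this update into fused axpy2v calls and, where strides permit, 4-column helper kernels, but the arithmetic they perform is the same as in this sketch.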
+if(${TARGET_ARCH} STREQUAL zen OR + ${TARGET_ARCH} STREQUAL zen2 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var1_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var4_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var1.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var4.c + ) +endif() \ No newline at end of file diff --git a/frame/2/her2/bli_her2_unf_var1_amd.c b/frame/2/her2/bli_her2_unf_var1_amd.c new file mode 100644 index 0000000000..31667cc3e4 --- /dev/null +++ b/frame/2/her2/bli_her2_unf_var1_amd.c @@ -0,0 +1,389 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjy, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, \ + ctype* x, inc_t incx, \ + ctype* y, inc_t incy, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* x0; \ + ctype* chi1; \ + ctype* y0; \ + ctype* psi1; \ + ctype* c10t; \ + ctype* gamma11; \ + ctype alpha0; \ + ctype alpha1; \ + ctype alpha0_chi1; \ + ctype alpha1_psi1; \ + ctype alpha0_chi1_psi1; \ + ctype conjx0_chi1; \ + ctype conjy1_psi1; \ + ctype conjy0_psi1; \ + dim_t i; \ + dim_t n_behind; \ + inc_t rs_ct, cs_ct; \ + conj_t conj0, conj1; \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. 
*/ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_ct = rs_c; \ + cs_ct = cs_c; \ +\ + PASTEMAC(ch,copys)( *alpha, alpha0 ); \ + PASTEMAC(ch,copycjs)( conjh, *alpha, alpha1 ); \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_ct = cs_c; \ + cs_ct = rs_c; \ +\ + /* Toggle conjugation of conjx/conjy, but only if we are being invoked + as her2; for syr2, conjx/conjy are unchanged. */ \ + conjx = bli_apply_conj( conjh, conjx ); \ + conjy = bli_apply_conj( conjh, conjy ); \ +\ + PASTEMAC(ch,copycjs)( conjh, *alpha, alpha0 ); \ + PASTEMAC(ch,copys)( *alpha, alpha1 ); \ + } \ +\ + /* Apply conjh (which carries the conjugation component of the Hermitian + transpose, if applicable) to conjx and/or conjy as needed to arrive at + the effective conjugation for the vector subproblems. */ \ + conj0 = bli_apply_conj( conjh, conjy ); \ + conj1 = bli_apply_conj( conjh, conjx ); \ +\ + PASTECH(ch,axpy2v_ker_ft) kfp_2v; \ +\ + /* Query the context for the kernel function pointer. */ \ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + n_behind = i; \ + x0 = x + (0 )*incx; \ + chi1 = x + (i )*incx; \ + y0 = y + (0 )*incy; \ + psi1 = y + (i )*incy; \ + c10t = c + (i )*rs_ct + (0 )*cs_ct; \ + gamma11 = c + (i )*rs_ct + (i )*cs_ct; \ +\ + /* Apply conjx and/or conjy to chi1 and/or psi1. */ \ + PASTEMAC(ch,copycjs)( conjx, *chi1, conjx0_chi1 ); \ + PASTEMAC(ch,copycjs)( conjy, *psi1, conjy1_psi1 ); \ + PASTEMAC(ch,copycjs)( conj0, *psi1, conjy0_psi1 ); \ +\ + /* Compute scalars for vector subproblems. */ \ + PASTEMAC(ch,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); \ + PASTEMAC(ch,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); \ +\ + /* Compute alpha * chi1 * conj(psi1) after both chi1 and psi1 have + already been conjugated, if needed, by conjx and conjy. */ \ + PASTEMAC(ch,scal2s)( alpha0_chi1, conjy0_psi1, alpha0_chi1_psi1 ); \ +\ + /* c10t = c10t + alpha * chi1 * y0'; */ \ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ \ + kfp_2v \ + ( \ + conj0, \ + conj1, \ + n_behind, \ + &alpha0_chi1, \ + &alpha1_psi1, \ + y0, incy, \ + x0, incx, \ + c10t, cs_ct, \ + cntx \ + ); \ +\ + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) \ + + conj(alpha) * psi1 * conj(chi1); */ \ + PASTEMAC(ch,adds)( alpha0_chi1_psi1, *gamma11 ); \ + PASTEMAC(ch,adds)( alpha0_chi1_psi1, *gamma11 ); \ +\ + /* For her2, explicitly set the imaginary component of gamma11 to + zero. */ \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( *gamma11 ); \ + } \ +} + +/** + * Following is function declaration + * that computes her2 for transposed case. + * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. 
+ * a is triangular matrix, x and y are vectors + */ +void bli_dher2_trans_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); + +void bli_dher2_unf_var1 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* x0; + double* chi1; + double* y0; + double* psi1; + double* c10t; + double* gamma11; + double alpha0; + double alpha1; + double alpha0_chi1; + double alpha1_psi1; + double alpha0_chi1_psi1; + double conjx0_chi1; + double conjy1_psi1; + double conjy0_psi1; + dim_t i; + dim_t n_behind; + inc_t rs_ct, cs_ct; + conj_t conj0, conj1; + + /* The algorithm will be expressed in terms of the lower triangular + * case;the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. + */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + PASTEMAC(d,copycjs)( conjh, *alpha, alpha1 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* Toggle conjugation of conjx/conjy, but only if we are being + * invoked as her2; for syr2, conjx/conjy are unchanged. + */ + conjx = bli_apply_conj( conjh, conjx ); + conjy = bli_apply_conj( conjh, conjy ); + + PASTEMAC(d,copycjs)( conjh, *alpha, alpha0 ); + PASTEMAC(d,copys)( *alpha, alpha1 ); + } + + /* Apply conjh (which carries the conjugation component of the + * Hermitian transpose, if applicable) to conjx and/or conjy as + * needed to arrive at the effective conjugation for the vector + * subproblems. + */ + conj0 = bli_apply_conj( conjh, conjy ); + conj1 = bli_apply_conj( conjh, conjx ); + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the kernel function pointer. */ + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if ( (bli_cpuid_is_avx_supported() == TRUE) + && (incx == 1) + && (incy == 1) + && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + if((n_behind >= 3)) + { + bli_dher2_trans_zen_int_4(c10t, x0, y0, + &alpha0, + n_behind + 1, + cs_ct); + i+=4; + } + else + { + /* Apply conjx and/or conjy to chi1 + * and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, + conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, + conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, + conjy0_psi1 ); + + /* Compute scalars for vector + * subproblems. */ + PASTEMAC(d,scal2s)( alpha0, + conjx0_chi1, + alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, + conjy1_psi1, + alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) + * after both chi1 and psi1 have + * already been conjugated, if needed + * by conjx and conjy. 
+ */ + PASTEMAC(d,scal2s)( alpha0_chi1, + conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0';*/ + /* c10t = c10t + conj(alpha) * psi1 * x0';*/ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 *conj(psi1) + * + conj(alpha) * psi1 * conj(chi1);*/ + PASTEMAC(d,adds)( alpha0_chi1_psi1, + *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, + *gamma11 ); + + i+=1; + } + } + } + else + { + for ( i = 0; i < m; ++i ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + /* Apply conjx and/or conjy to chi1 and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have already been conjugated, if needed, + * by conjx and conjy. + */ + PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0'; */ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) + + conj(alpha) * psi1 * conj(chi1); */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + + } + } +} + +GENTFUNC(float, s, her2_unf_var1) +GENTFUNC(scomplex, c, her2_unf_var1) +GENTFUNC(dcomplex, z,her2_unf_var1) + + diff --git a/frame/2/her2/bli_her2_unf_var4_amd.c b/frame/2/her2/bli_her2_unf_var4_amd.c new file mode 100644 index 0000000000..6e999be7d1 --- /dev/null +++ b/frame/2/her2/bli_her2_unf_var4_amd.c @@ -0,0 +1,365 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjy, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, \ + ctype* x, inc_t incx, \ + ctype* y, inc_t incy, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* chi1; \ + ctype* x2; \ + ctype* psi1; \ + ctype* y2; \ + ctype* gamma11; \ + ctype* c21; \ + ctype alpha0; \ + ctype alpha1; \ + ctype alpha0_psi1; \ + ctype alpha1_chi1; \ + ctype alpha0_chi1_psi1; \ + ctype conjy0_psi1; \ + ctype conjx1_chi1; \ + ctype conjx0_chi1; \ + dim_t i; \ + dim_t n_ahead; \ + inc_t rs_ct, cs_ct; \ + conj_t conj0, conj1; \ + conj_t conjh_conjx; \ + conj_t conjh_conjy; \ +\ + /* Eliminate unused variable warnings. */ \ + ( void )conjh_conjx; \ + ( void )conjh_conjy; \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_ct = rs_c; \ + cs_ct = cs_c; \ +\ + PASTEMAC(ch,copys)( *alpha, alpha0 ); \ + PASTEMAC(ch,copycjs)( conjh, *alpha, alpha1 ); \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_ct = cs_c; \ + cs_ct = rs_c; \ +\ + /* Toggle conjugation of conjx/conjy, but only if we are being invoked + as her2; for syr2, conjx/conjy are unchanged. */ \ + conjx = bli_apply_conj( conjh, conjx ); \ + conjy = bli_apply_conj( conjh, conjy ); \ +\ + PASTEMAC(ch,copycjs)( conjh, *alpha, alpha0 ); \ + PASTEMAC(ch,copys)( *alpha, alpha1 ); \ + } \ +\ + /* Apply conjh (which carries the conjugation component of the Hermitian + transpose, if applicable) to conjx and/or conjy as needed to arrive at + the effective conjugation for the vector subproblems. */ \ + conj0 = conjx; \ + conj1 = conjy; \ + conjh_conjx = bli_apply_conj( conjh, conjx ); \ + conjh_conjy = bli_apply_conj( conjh, conjy ); \ +\ + PASTECH(ch,axpy2v_ker_ft) kfp_2v; \ +\ + /* Query the context for the kernel function pointer. */ \ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + n_ahead = m - i - 1; \ + chi1 = x + (i )*incx; \ + x2 = x + (i+1)*incx; \ + psi1 = y + (i )*incy; \ + y2 = y + (i+1)*incy; \ + gamma11 = c + (i )*rs_ct + (i )*cs_ct; \ + c21 = c + (i+1)*rs_ct + (i )*cs_ct; \ +\ + /* Apply conjx and/or conjy to chi1 and/or psi1. */ \ + PASTEMAC(ch,copycjs)( conjh_conjy, *psi1, conjy0_psi1 ); \ + PASTEMAC(ch,copycjs)( conjh_conjx, *chi1, conjx1_chi1 ); \ + PASTEMAC(ch,copycjs)( conj0, *chi1, conjx0_chi1 ); \ +\ + /* Compute scalars for vector subproblems. */ \ + PASTEMAC(ch,scal2s)( alpha0, conjy0_psi1, alpha0_psi1 ); \ + PASTEMAC(ch,scal2s)( alpha1, conjx1_chi1, alpha1_chi1 ); \ +\ + /* Compute alpha * chi1 * conj(psi1) after both chi1 and psi1 have + already been conjugated, if needed, by conjx and conjy. 
*/ \ + PASTEMAC(ch,scal2s)( alpha0_psi1, conjx0_chi1, alpha0_chi1_psi1 ); \ +\ + /* c21 = c21 + alpha * x2 * conj(psi1); */ \ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ \ + kfp_2v \ + ( \ + conj0, \ + conj1, \ + n_ahead, \ + &alpha0_psi1, \ + &alpha1_chi1, \ + x2, incx, \ + y2, incy, \ + c21, rs_ct, \ + cntx \ + ); \ +\ + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) \ + + conj(alpha) * psi1 * conj(chi1); */ \ + PASTEMAC(ch,adds)( alpha0_chi1_psi1, *gamma11 ); \ + PASTEMAC(ch,adds)( alpha0_chi1_psi1, *gamma11 ); \ +\ + /* For her2, explicitly set the imaginary component of gamma11 to + zero. */ \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( *gamma11 ); \ + } \ +} + +/** + * Following is function declaration + * that computes her2 for transposed case. + * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. + * a is triangular matrix, x and y are vectors + */ +void bli_dher2_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); + +void bli_dher2_unf_var4 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + + double* chi1; + double* x2; + double* psi1; + double* y2; + double* gamma11; + double* c21; + double alpha0; + double alpha0_psi1; + double alpha1_chi1; + double alpha0_chi1_psi1; + dim_t i; + dim_t n_ahead; + inc_t rs_ct, cs_ct; + + const num_t dt = PASTEMAC(d,type); + + /* The algorithm will be expressed in terms of the lower triangular + * case; the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. + */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* Toggle conjugation of conjx/conjy, but only if we are being + * invoked as her2; for syr2, conjx/conjy are unchanged. + */ + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + /* Apply conjh (which carries the conjugation component of the + * Hermitian transpose, if applicable) to conjx and/or conjy as + * needed to arrive at the effective conjugation for the vector + * subproblems. + */ + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the kernel function pointer. */ + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if ( (bli_cpuid_is_avx_supported() == TRUE) + && (incx == 1) + && (incy == 1) + && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i ) + (i )*cs_ct; + c21 = c + (i+1) + (i )*cs_ct; + + if((n_ahead >= 3)) + { + bli_dher2_zen_int_4(gamma11, chi1, + psi1, &alpha0, + n_ahead + 1, cs_ct); + i+= 4; + } + else + { + /* Compute scalars for vector + * subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, + alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, + alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) + * after both chi1 and psi1 have + * already been conjugated, if needed, + * by conjx and conjy. 
*/ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1)*/ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + + PASTEMAC(d,adds)( alpha0_chi1_psi1, + *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, + *gamma11 ); + i+=1; + } + } + } + else + { + for ( i = 0; i < m; ++i) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i ) + (i )*cs_ct; + c21 = c + (i+1) + (i )*cs_ct; + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have + already been conjugated, if needed, by conjx and + conjy. */ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + } + } +} + +GENTFUNC(float, s, her2_unf_var4) +GENTFUNC(scomplex, c, her2_unf_var4) +GENTFUNC(dcomplex, z,her2_unf_var4) + + diff --git a/frame/2/trsv/CMakeLists.txt b/frame/2/trsv/CMakeLists.txt index 1d16769d32..b07389340e 100644 --- a/frame/2/trsv/CMakeLists.txt +++ b/frame/2/trsv/CMakeLists.txt @@ -1,11 +1,26 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unb_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unb_var2.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_var_oapi.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR + ${TARGET_ARCH} STREQUAL zen2 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var1_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var2_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var1.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var2.c + ) +endif() diff --git a/frame/2/trsv/bli_trsv_unf_var1.c b/frame/2/trsv/bli_trsv_unf_var1.c index 4f19e1ac5e..55e28a4417 100644 --- a/frame/2/trsv/bli_trsv_unf_var1.c +++ b/frame/2/trsv/bli_trsv_unf_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -231,413 +231,4 @@ void PASTEMAC(ch,varname) \ } \ } -#ifdef BLIS_CONFIG_EPYC -void bli_dtrsv_unf_var1 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - cntx_t* cntx - ) -{ - - double* one = PASTEMAC(d,1); - double* minus_one = PASTEMAC(d,m1); - double* A10; - double* A11; - double* A12; - double* a10t; - double* alpha11; - double* a12t; - double* x0; - double* x1; - double* x2; - double* x01; - double* chi11; - double* x21; - double alpha11_conj; - double rho1; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_behind, f_behind; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(d,dotxf_ker_ft) kfp_df; - - /* Assign kernel function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - kfp_df = bli_ddotxf_zen_int_8; - b_fuse = 8; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - num_t dt = PASTEMAC(d,type); - kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); - } - - /* We reduce all of the possible cases down to just lower/upper. 
*/ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_behind = iter; - A11 = a + (i )*rs_at + (i )*cs_at; - A12 = a + (i )*rs_at + (i+f)*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 - A12 * x2; */ - kfp_df - ( - conja, - BLIS_NO_CONJUGATE, - n_behind, - f, - minus_one, - A12, cs_at, rs_at, - x2, incx, - one, - x1, incx, - cntx - ); - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_behind = k; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a12t = A11 + (l )*rs_at + (l+1)*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 - a12t * x21; */ - PASTEMAC(d,set0s)( rho1 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(d,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); - } - else - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(d,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); - } - PASTEMAC(d,subs)( rho1, *chi11 ); - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); - } - } - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_behind = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A10 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 - A10 * x0; */ - kfp_df - ( - conja, - BLIS_NO_CONJUGATE, - n_behind, - f, - minus_one, - A10, cs_at, rs_at, - x0, incx, - one, - x1, incx, - cntx - ); - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_behind = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a10t = A11 + (l )*rs_at + (0 )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 - a10t * x01; */ - PASTEMAC(d,set0s)( rho1 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(d,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); - } - else - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(d,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); - } - PASTEMAC(d,subs)( rho1, *chi11 ); - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); - } - } - } - } -} - -void bli_strsv_unf_var1 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - float* alpha, - float* a, inc_t rs_a, inc_t cs_a, - float* x, inc_t incx, - cntx_t* cntx - ) -{ - - float* one = PASTEMAC(s,1); - float* minus_one = PASTEMAC(s,m1); - float* A10; - float* A11; - float* A12; - float* a10t; - float* alpha11; - float* a12t; - float* x0; - float* x1; - float* x2; - float* x01; - float* chi11; - float* x21; - float alpha11_conj; - float rho1; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_behind, f_behind; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(s,dotxf_ker_ft) kfp_df; - - /* Assign 
kernel function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - kfp_df = bli_sdotxf_zen_int_8; - b_fuse = 8; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - num_t dt = PASTEMAC(s,type); - kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); - - } - - /* We reduce all of the possible cases down to just lower/upper. */ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_behind = iter; - A11 = a + (i )*rs_at + (i )*cs_at; - A12 = a + (i )*rs_at + (i+f)*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 - A12 * x2; */ - kfp_df - ( - conja, - BLIS_NO_CONJUGATE, - n_behind, - f, - minus_one, - A12, cs_at, rs_at, - x2, incx, - one, - x1, incx, - cntx - ); - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_behind = k; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a12t = A11 + (l )*rs_at + (l+1)*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 - a12t * x21; */ - PASTEMAC(s,set0s)( rho1 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(s,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); - } - else - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(s,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); - } - PASTEMAC(s,subs)( rho1, *chi11 ); - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(s,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(s,invscals)( alpha11_conj, *chi11 ); - } - } - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_behind = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A10 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 - A10 * x0; */ - kfp_df - ( - conja, - BLIS_NO_CONJUGATE, - n_behind, - f, - minus_one, - A10, cs_at, rs_at, - x0, incx, - one, - x1, incx, - cntx - ); - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_behind = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a10t = A11 + (l )*rs_at + (0 )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 - a10t * x01; */ - PASTEMAC(s,set0s)( rho1 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(s,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); - } - else - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(s,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); - } - PASTEMAC(s,subs)( rho1, *chi11 ); - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(s,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(s,invscals)( alpha11_conj, *chi11 ); - } - } - } - } -} - -INSERT_GENTFUNC_BASIC0_CZ( trsv_unf_var1 ) -#else INSERT_GENTFUNC_BASIC0( trsv_unf_var1 ) -#endif diff --git a/frame/2/trsv/bli_trsv_unf_var1_amd.c b/frame/2/trsv/bli_trsv_unf_var1_amd.c new file mode 100644 index 0000000000..4f026f2c6a --- /dev/null +++ b/frame/2/trsv/bli_trsv_unf_var1_amd.c @@ -0,0 +1,638 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uploa, \ + trans_t transa, \ + diag_t diaga, \ + dim_t m, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + cntx_t* cntx \ + ) \ +{ \ + if(cntx == NULL) cntx = bli_gks_query_cntx(); \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* one = PASTEMAC(ch,1); \ + ctype* minus_one = PASTEMAC(ch,m1); \ + ctype* A10; \ + ctype* A11; \ + ctype* A12; \ + ctype* a10t; \ + ctype* alpha11; \ + ctype* a12t; \ + ctype* x0; \ + ctype* x1; \ + ctype* x2; \ + ctype* x01; \ + ctype* chi11; \ + ctype* x21; \ + ctype alpha11_conj; \ + ctype rho1; \ + dim_t iter, i, k, j, l; \ + dim_t b_fuse, f; \ + dim_t n_behind, f_behind; \ + inc_t rs_at, cs_at; \ + uplo_t uploa_trans; \ + conj_t conja; \ +\ + /* x = alpha * x; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + alpha, \ + x, incx, \ + cntx, \ + NULL \ + ); \ +\ + if ( bli_does_notrans( transa ) ) \ + { \ + rs_at = rs_a; \ + cs_at = cs_a; \ + uploa_trans = uploa; \ + } \ + else /* if ( bli_does_trans( transa ) ) */ \ + { \ + rs_at = cs_a; \ + cs_at = rs_a; \ + uploa_trans = bli_uplo_toggled( uploa ); \ + } \ +\ + conja = bli_extract_conj( transa ); \ +\ + PASTECH(ch,dotxf_ker_ft) kfp_df; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); \ +\ + /* We reduce all of the possible cases down to just lower/upper. 
*/ \ + if ( bli_is_upper( uploa_trans ) ) \ + { \ + for ( iter = 0; iter < m; iter += f ) \ + { \ + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); \ + i = m - iter - f; \ + n_behind = iter; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A12 = a + (i )*rs_at + (i+f)*cs_at; \ + x1 = x + (i )*incx; \ + x2 = x + (i+f)*incx; \ +\ + /* x1 = x1 - A12 * x2; */ \ + kfp_df \ + ( \ + conja, \ + BLIS_NO_CONJUGATE, \ + n_behind, \ + f, \ + minus_one, \ + A12, cs_at, rs_at, \ + x2, incx, \ + one, \ + x1, incx, \ + cntx \ + ); \ +\ + /* x1 = x1 / triu( A11 ); */ \ + for ( k = 0; k < f; ++k ) \ + { \ + l = f - k - 1; \ + f_behind = k; \ + alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ + a12t = A11 + (l )*rs_at + (l+1)*cs_at; \ + chi11 = x1 + (l )*incx; \ + x21 = x1 + (l+1)*incx; \ +\ + /* chi11 = chi11 - a12t * x21; */ \ + PASTEMAC(ch,set0s)( rho1 ); \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); \ + } \ + else \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); \ + } \ + PASTEMAC(ch,subs)( rho1, *chi11 ); \ +\ + /* chi11 = chi11 / alpha11; */ \ + if ( bli_is_nonunit_diag( diaga ) ) \ + { \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ + PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ + } \ + } \ + } \ + } \ + else /* if ( bli_is_lower( uploa_trans ) ) */ \ + { \ + for ( iter = 0; iter < m; iter += f ) \ + { \ + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); \ + i = iter; \ + n_behind = i; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A10 = a + (i )*rs_at + (0 )*cs_at; \ + x1 = x + (i )*incx; \ + x0 = x + (0 )*incx; \ +\ + /* x1 = x1 - A10 * x0; */ \ + kfp_df \ + ( \ + conja, \ + BLIS_NO_CONJUGATE, \ + n_behind, \ + f, \ + minus_one, \ + A10, cs_at, rs_at, \ + x0, incx, \ + one, \ + x1, incx, \ + cntx \ + ); \ +\ + /* x1 = x1 / tril( A11 ); */ \ + for ( k = 0; k < f; ++k ) \ + { \ + l = k; \ + f_behind = l; \ + alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ + a10t = A11 + (l )*rs_at + (0 )*cs_at; \ + chi11 = x1 + (l )*incx; \ + x01 = x1 + (0 )*incx; \ +\ + /* chi11 = chi11 - a10t * x01; */ \ + PASTEMAC(ch,set0s)( rho1 ); \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); \ + } \ + else \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); \ + } \ + PASTEMAC(ch,subs)( rho1, *chi11 ); \ +\ + /* chi11 = chi11 / alpha11; */ \ + if ( bli_is_nonunit_diag( diaga ) ) \ + { \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ + PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ + } \ + } \ + } \ + } \ +} + +void bli_dtrsv_unf_var1 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + cntx_t* cntx + ) +{ + + double* one = PASTEMAC(d,1); + double* minus_one = PASTEMAC(d,m1); + double* A10; + double* A11; + double* A12; + double* a10t; + double* alpha11; + double* a12t; + double* x0; + double* x1; + double* x2; + double* x01; + double* chi11; + double* x21; + double alpha11_conj; + double rho1; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_behind, f_behind; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + /* x = alpha * x; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if( bli_does_notrans( transa ) ) + { + rs_at = rs_a; 
+ cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(d,dotxf_ker_ft) kfp_df; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_df = bli_ddotxf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + num_t dt = PASTEMAC(d,type); + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + } + + /* We reduce all of the possible cases down to just lower/upper. */ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_behind = iter; + A11 = a + (i )*rs_at + (i )*cs_at; + A12 = a + (i )*rs_at + (i+f)*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 - A12 * x2; */ + kfp_df + ( + conja, + BLIS_NO_CONJUGATE, + n_behind, + f, + minus_one, + A12, cs_at, rs_at, + x2, incx, + one, + x1, incx, + cntx + ); + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_behind = k; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a12t = A11 + (l )*rs_at + (l+1)*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 - a12t * x21; */ + PASTEMAC(d,set0s)( rho1 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); + } + else + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); + } + PASTEMAC(d,subs)( rho1, *chi11 ); + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); + } + } + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_behind = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A10 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 - A10 * x0; */ + kfp_df + ( + conja, + BLIS_NO_CONJUGATE, + n_behind, + f, + minus_one, + A10, cs_at, rs_at, + x0, incx, + one, + x1, incx, + cntx + ); + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_behind = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a10t = A11 + (l )*rs_at + (0 )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 - a10t * x01; */ + PASTEMAC(d,set0s)( rho1 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); + } + else + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); + } + PASTEMAC(d,subs)( rho1, *chi11 ); + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); + } + } + } + } +} + +void bli_strsv_unf_var1 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + float* alpha, + float* a, inc_t rs_a, inc_t cs_a, + float* x, inc_t incx, + cntx_t* cntx + ) +{ + + float* one = PASTEMAC(s,1); + float* minus_one = PASTEMAC(s,m1); + float* A10; + 
float* A11; + float* A12; + float* a10t; + float* alpha11; + float* a12t; + float* x0; + float* x1; + float* x2; + float* x01; + float* chi11; + float* x21; + float alpha11_conj; + float rho1; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_behind, f_behind; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + /* x = alpha * x; */ + PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(s,dotxf_ker_ft) kfp_df; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_df = bli_sdotxf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + num_t dt = PASTEMAC(s,type); + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + + } + + /* We reduce all of the possible cases down to just lower/upper. */ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_behind = iter; + A11 = a + (i )*rs_at + (i )*cs_at; + A12 = a + (i )*rs_at + (i+f)*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 - A12 * x2; */ + kfp_df + ( + conja, + BLIS_NO_CONJUGATE, + n_behind, + f, + minus_one, + A12, cs_at, rs_at, + x2, incx, + one, + x1, incx, + cntx + ); + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_behind = k; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a12t = A11 + (l )*rs_at + (l+1)*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 - a12t * x21; */ + PASTEMAC(s,set0s)( rho1 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(s,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); + } + else + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(s,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); + } + PASTEMAC(s,subs)( rho1, *chi11 ); + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(s,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(s,invscals)( alpha11_conj, *chi11 ); + } + } + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_behind = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A10 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 - A10 * x0; */ + kfp_df + ( + conja, + BLIS_NO_CONJUGATE, + n_behind, + f, + minus_one, + A10, cs_at, rs_at, + x0, incx, + one, + x1, incx, + cntx + ); + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_behind = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a10t = A11 + (l )*rs_at + (0 )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 - a10t * x01; */ + PASTEMAC(s,set0s)( rho1 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(s,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); + } + else + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(s,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); + 
} + PASTEMAC(s,subs)( rho1, *chi11 ); + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(s,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(s,invscals)( alpha11_conj, *chi11 ); + } + } + } + } +} + +INSERT_GENTFUNC_BASIC0_CZ( trsv_unf_var1 ) + diff --git a/frame/2/trsv/bli_trsv_unf_var2.c b/frame/2/trsv/bli_trsv_unf_var2.c index 7ece8f8470..9eb02781a4 100644 --- a/frame/2/trsv/bli_trsv_unf_var2.c +++ b/frame/2/trsv/bli_trsv_unf_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -228,789 +228,5 @@ void PASTEMAC(ch,varname) \ } \ } \ } -#ifdef BLIS_CONFIG_EPYC -void bli_dtrsv_unf_var2 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - cntx_t* cntx - ) -{ - double* minus_one = PASTEMAC(d,m1); - double* A01; - double* A11; - double* A21; - double* a01; - double* alpha11; - double* a21; - double* x0; - double* x1; - double* x2; - double* x01; - double* chi11; - double* x21; - double alpha11_conj; - double minus_chi11; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_ahead, f_ahead; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if ( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(d,axpyf_ker_ft) kfp_af; - - /* Assign kernel function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - kfp_af = bli_daxpyf_zen_int_16x4; - b_fuse = 4; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DOUBLE, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DOUBLE, BLIS_AF, cntx ); - } - - /* We reduce all of the possible cases down to just lower/upper. 
*/ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_ahead = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A01 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_ahead = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a01 = A11 + (0 )*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); - } - - /* x01 = x01 - chi11 * a01; */ - PASTEMAC(d,neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(d,axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(d,axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - } - - /* x0 = x0 - A01 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A01, rs_at, cs_at, - x1, incx, - x0, incx, - cntx - ); - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_ahead = m - iter - f; - A11 = a + (i )*rs_at + (i )*cs_at; - A21 = a + (i+f)*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_ahead = f - k - 1; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a21 = A11 + (l+1)*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); - } - - /* x21 = x21 - chi11 * a21; */ - PASTEMAC(d,neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(d,axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(d,axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - } - - /* x2 = x2 - A21 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A21, rs_at, cs_at, - x1, incx, - x2, incx, - cntx - ); - } - } -} - -void bli_strsv_unf_var2 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - float* alpha, - float* a, inc_t rs_a, inc_t cs_a, - float* x, inc_t incx, - cntx_t* cntx - ) -{ - - float* minus_one = PASTEMAC(s, m1); - float* A01; - float* A11; - float* A21; - float* a01; - float* alpha11; - float* a21; - float* x0; - float* x1; - float* x2; - float* x01; - float* chi11; - float* x21; - float alpha11_conj; - float minus_chi11; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_ahead, f_ahead; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(s, scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(s, axpyf_ker_ft) kfp_af; - - /* Assign function pointer and fusing factor. 
*/ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - kfp_af = bli_saxpyf_zen_int_5; - b_fuse = 5; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_FLOAT, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( BLIS_FLOAT, BLIS_AF, cntx ); - } - - /* We reduce all of the possible cases down to just lower/upper. */ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_ahead = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A01 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_ahead = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a01 = A11 + (0 )*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(s, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(s, invscals)( alpha11_conj, *chi11 ); - } - - /* x01 = x01 - chi11 * a01; */ - PASTEMAC(s, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(s, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(s, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - } - - /* x0 = x0 - A01 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A01, rs_at, cs_at, - x1, incx, - x0, incx, - cntx - ); - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_ahead = m - iter - f; - A11 = a + (i )*rs_at + (i )*cs_at; - A21 = a + (i+f)*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_ahead = f - k - 1; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a21 = A11 + (l+1)*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(s, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(s, invscals)( alpha11_conj, *chi11 ); - } - - /* x21 = x21 - chi11 * a21; */ - PASTEMAC(s, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(s, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(s, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - } - - /* x2 = x2 - A21 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A21, rs_at, cs_at, - x1, incx, - x2, incx, - cntx - ); - } - } -} - -void bli_ztrsv_unf_var2 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - dcomplex* alpha, - dcomplex* a, inc_t rs_a, inc_t cs_a, - dcomplex* x, inc_t incx, - cntx_t* cntx - ) -{ - - dcomplex* minus_one = PASTEMAC(z, m1); - dcomplex* A01; - dcomplex* A11; - dcomplex* A21; - dcomplex* a01; - dcomplex* alpha11; - dcomplex* a21; - dcomplex* x0; - dcomplex* x1; - dcomplex* x2; - dcomplex* x01; - dcomplex* chi11; - dcomplex* x21; - dcomplex alpha11_conj; - dcomplex minus_chi11; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_ahead, f_ahead; - inc_t rs_at, cs_at; - uplo_t 
uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(z, scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(z, axpyf_ker_ft) kfp_af; - - /* Assign function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - kfp_af = bli_zaxpyf_zen_int_5; - b_fuse = 5; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DCOMPLEX, BLIS_AF, cntx ); - } - /* We reduce all of the possible cases down to just lower/upper. */ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_ahead = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A01 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_ahead = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a01 = A11 + (0 )*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(z, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(z, invscals)( alpha11_conj, *chi11 ); - } - - /* x01 = x01 - chi11 * a01; */ - PASTEMAC(z, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(z, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(z, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - } - - /* x0 = x0 - A01 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A01, rs_at, cs_at, - x1, incx, - x0, incx, - cntx - ); - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_ahead = m - iter - f; - A11 = a + (i )*rs_at + (i )*cs_at; - A21 = a + (i+f)*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_ahead = f - k - 1; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a21 = A11 + (l+1)*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(z, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(z, invscals)( alpha11_conj, *chi11 ); - } - - /* x21 = x21 - chi11 * a21; */ - PASTEMAC(z, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(z, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(z, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - } - - /* x2 = x2 - A21 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A21, rs_at, cs_at, - x1, incx, - x2, incx, - cntx - ); - } - } -} - -void bli_ctrsv_unf_var2 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - 
dim_t m, - scomplex* alpha, - scomplex* a, inc_t rs_a, inc_t cs_a, - scomplex* x, inc_t incx, - cntx_t* cntx - ) -{ - - scomplex* minus_one = PASTEMAC(c, m1); - scomplex* A01; - scomplex* A11; - scomplex* A21; - scomplex* a01; - scomplex* alpha11; - scomplex* a21; - scomplex* x0; - scomplex* x1; - scomplex* x2; - scomplex* x01; - scomplex* chi11; - scomplex* x21; - scomplex alpha11_conj; - scomplex minus_chi11; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_ahead, f_ahead; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(c, scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(c, axpyf_ker_ft) kfp_af; - - /* Assign function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - kfp_af = bli_caxpyf_zen_int_5; - b_fuse = 5; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( BLIS_SCOMPLEX, BLIS_AF, cntx ); - } - /* We reduce all of the possible cases down to just lower/upper. */ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_ahead = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A01 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_ahead = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a01 = A11 + (0 )*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(c, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(c, invscals)( alpha11_conj, *chi11 ); - } - - /* x01 = x01 - chi11 * a01; */ - PASTEMAC(c, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(c, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(c, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - } - - /* x0 = x0 - A01 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A01, rs_at, cs_at, - x1, incx, - x0, incx, - cntx - ); - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_ahead = m - iter - f; - A11 = a + (i )*rs_at + (i )*cs_at; - A21 = a + (i+f)*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_ahead = f - k - 1; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a21 = A11 + (l+1)*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(c, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(c, invscals)( alpha11_conj, *chi11 ); - } - - /* x21 = x21 - chi11 * a21; */ - PASTEMAC(c, neg2s)( *chi11, minus_chi11 ); - if 
( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(c, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(c, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - } - - /* x2 = x2 - A21 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A21, rs_at, cs_at, - x1, incx, - x2, incx, - cntx - ); - } - } -} - -#else INSERT_GENTFUNC_BASIC0( trsv_unf_var2 ) -#endif diff --git a/frame/2/trsv/bli_trsv_unf_var2_amd.c b/frame/2/trsv/bli_trsv_unf_var2_amd.c new file mode 100644 index 0000000000..51bbcabab7 --- /dev/null +++ b/frame/2/trsv/bli_trsv_unf_var2_amd.c @@ -0,0 +1,1024 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uploa, \ + trans_t transa, \ + diag_t diaga, \ + dim_t m, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + bli_init_once(); \ +\ + if( cntx == NULL ) cntx = bli_gks_query_cntx(); \ +\ + ctype* minus_one = PASTEMAC(ch,m1); \ + ctype* A01; \ + ctype* A11; \ + ctype* A21; \ + ctype* a01; \ + ctype* alpha11; \ + ctype* a21; \ + ctype* x0; \ + ctype* x1; \ + ctype* x2; \ + ctype* x01; \ + ctype* chi11; \ + ctype* x21; \ + ctype alpha11_conj; \ + ctype minus_chi11; \ + dim_t iter, i, k, j, l; \ + dim_t b_fuse, f; \ + dim_t n_ahead, f_ahead; \ + inc_t rs_at, cs_at; \ + uplo_t uploa_trans; \ + conj_t conja; \ +\ + /* x = alpha * x; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + alpha, \ + x, incx, \ + cntx, \ + NULL \ + ); \ +\ + if ( bli_does_notrans( transa ) ) \ + { \ + rs_at = rs_a; \ + cs_at = cs_a; \ + uploa_trans = uploa; \ + } \ + else /* if ( bli_does_trans( transa ) ) */ \ + { \ + rs_at = cs_a; \ + cs_at = rs_a; \ + uploa_trans = bli_uplo_toggled( uploa ); \ + } \ +\ + conja = bli_extract_conj( transa ); \ +\ + PASTECH(ch,axpyf_ker_ft) kfp_af; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); \ +\ + /* We reduce all of the possible cases down to just lower/upper. */ \ + if ( bli_is_upper( uploa_trans ) ) \ + { \ + for ( iter = 0; iter < m; iter += f ) \ + { \ + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); \ + i = m - iter - f; \ + n_ahead = i; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A01 = a + (0 )*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + x0 = x + (0 )*incx; \ +\ + /* x1 = x1 / triu( A11 ); */ \ + for ( k = 0; k < f; ++k ) \ + { \ + l = f - k - 1; \ + f_ahead = l; \ + alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ + a01 = A11 + (0 )*rs_at + (l )*cs_at; \ + chi11 = x1 + (l )*incx; \ + x01 = x1 + (0 )*incx; \ +\ + /* chi11 = chi11 / alpha11; */ \ + if ( bli_is_nonunit_diag( diaga ) ) \ + { \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ + PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ + } \ +\ + /* x01 = x01 - chi11 * a01; */ \ + PASTEMAC(ch,neg2s)( *chi11, minus_chi11 ); \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); \ + } \ + } \ +\ + /* x0 = x0 - A01 * x1; */ \ + kfp_af \ + ( \ + conja, \ + BLIS_NO_CONJUGATE, \ + n_ahead, \ + f, \ + minus_one, \ + A01, rs_at, cs_at, \ + x1, incx, \ + x0, incx, \ + cntx \ + ); \ + } \ + } \ + else /* if ( bli_is_lower( uploa_trans ) ) */ \ + { \ + for ( iter = 0; iter < m; iter += f ) \ + { \ + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); \ + i = iter; \ + n_ahead = m - iter - f; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A21 = a + (i+f)*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + x2 = x + (i+f)*incx; \ +\ + /* x1 = x1 / tril( A11 ); */ \ + for ( k = 0; k < f; ++k ) \ + { \ + l = k; \ + f_ahead = f - k - 1; \ + alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ + a21 = A11 + (l+1)*rs_at + (l )*cs_at; \ + chi11 = x1 + (l )*incx; \ + x21 = x1 + (l+1)*incx; \ +\ + 
/* chi11 = chi11 / alpha11; */ \ + if ( bli_is_nonunit_diag( diaga ) ) \ + { \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ + PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ + } \ +\ + /* x21 = x21 - chi11 * a21; */ \ + PASTEMAC(ch,neg2s)( *chi11, minus_chi11 ); \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); \ + } \ + } \ +\ + /* x2 = x2 - A21 * x1; */ \ + kfp_af \ + ( \ + conja, \ + BLIS_NO_CONJUGATE, \ + n_ahead, \ + f, \ + minus_one, \ + A21, rs_at, cs_at, \ + x1, incx, \ + x2, incx, \ + cntx \ + ); \ + } \ + } \ +} + +void bli_dtrsv_unf_var2 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + cntx_t* cntx + ) +{ + + double* minus_one = PASTEMAC(d,m1); + double* A01; + double* A11; + double* A21; + double* a01; + double* alpha11; + double* a21; + double* x0; + double* x1; + double* x2; + double* x01; + double* chi11; + double* x21; + double alpha11_conj; + double minus_chi11; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_ahead, f_ahead; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here + bli_init_once(); + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + /* x = alpha * x; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if ( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(d,axpyf_ker_ft) kfp_af; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_af = bli_daxpyf_zen_int_16x4; + b_fuse = 4; + } + else + { + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DOUBLE, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DOUBLE, BLIS_AF, cntx ); + } + + /* We reduce all of the possible cases down to just lower/upper. 
*/ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_ahead = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A01 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_ahead = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a01 = A11 + (0 )*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); + } + + /* x01 = x01 - chi11 * a01; */ + PASTEMAC(d,neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(d,axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(d,axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + } + + /* x0 = x0 - A01 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A01, rs_at, cs_at, + x1, incx, + x0, incx, + cntx + ); + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_ahead = m - iter - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_ahead = f - k - 1; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a21 = A11 + (l+1)*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); + } + + /* x21 = x21 - chi11 * a21; */ + PASTEMAC(d,neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(d,axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(d,axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + } + + /* x2 = x2 - A21 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A21, rs_at, cs_at, + x1, incx, + x2, incx, + cntx + ); + } + } +} + +void bli_strsv_unf_var2 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + float* alpha, + float* a, inc_t rs_a, inc_t cs_a, + float* x, inc_t incx, + cntx_t* cntx + ) +{ + + float* minus_one = PASTEMAC(s, m1); + float* A01; + float* A11; + float* A21; + float* a01; + float* alpha11; + float* a21; + float* x0; + float* x1; + float* x2; + float* x01; + float* chi11; + float* x21; + float alpha11_conj; + float minus_chi11; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_ahead, f_ahead; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here + bli_init_once(); + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + /* x = alpha * x; */ + PASTEMAC2(s, scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; 
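+        // For op(A) = A^T the solve references the opposite triangle of A
+        // (with the strides swapped above), hence the uplo toggle below.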
+ uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(s, axpyf_ker_ft) kfp_af; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_af = bli_saxpyf_zen_int_5; + b_fuse = 5; + } + else + { + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_FLOAT, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_FLOAT, BLIS_AF, cntx ); + } + + /* We reduce all of the possible cases down to just lower/upper. */ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_ahead = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A01 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_ahead = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a01 = A11 + (0 )*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(s, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(s, invscals)( alpha11_conj, *chi11 ); + } + + /* x01 = x01 - chi11 * a01; */ + PASTEMAC(s, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(s, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(s, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + } + + /* x0 = x0 - A01 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A01, rs_at, cs_at, + x1, incx, + x0, incx, + cntx + ); + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_ahead = m - iter - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_ahead = f - k - 1; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a21 = A11 + (l+1)*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(s, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(s, invscals)( alpha11_conj, *chi11 ); + } + + /* x21 = x21 - chi11 * a21; */ + PASTEMAC(s, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(s, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(s, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + } + + /* x2 = x2 - A21 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A21, rs_at, cs_at, + x1, incx, + x2, incx, + cntx + ); + } + } +} + +void bli_ztrsv_unf_var2 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + dcomplex* alpha, + dcomplex* a, inc_t rs_a, inc_t cs_a, + dcomplex* x, inc_t incx, + cntx_t* cntx + ) +{ + + dcomplex* minus_one = PASTEMAC(z, m1); + dcomplex* A01; + dcomplex* A11; + dcomplex* A21; + dcomplex* a01; + dcomplex* alpha11; + dcomplex* a21; + dcomplex* x0; + dcomplex* x1; + dcomplex* x2; + dcomplex* x01; + dcomplex* chi11; + dcomplex* x21; + dcomplex alpha11_conj; + dcomplex 
minus_chi11; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_ahead, f_ahead; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here + bli_init_once(); + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + /* x = alpha * x; */ + PASTEMAC2(z, scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(z, axpyf_ker_ft) kfp_af; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_af = bli_zaxpyf_zen_int_5; + b_fuse = 5; + } + else + { + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DCOMPLEX, BLIS_AF, cntx ); + } + /* We reduce all of the possible cases down to just lower/upper. */ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_ahead = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A01 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_ahead = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a01 = A11 + (0 )*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(z, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(z, invscals)( alpha11_conj, *chi11 ); + } + + /* x01 = x01 - chi11 * a01; */ + PASTEMAC(z, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(z, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(z, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + } + + /* x0 = x0 - A01 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A01, rs_at, cs_at, + x1, incx, + x0, incx, + cntx + ); + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_ahead = m - iter - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_ahead = f - k - 1; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a21 = A11 + (l+1)*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(z, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(z, invscals)( alpha11_conj, *chi11 ); + } + + /* x21 = x21 - chi11 * a21; */ + PASTEMAC(z, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(z, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(z, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + 
} + + /* x2 = x2 - A21 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A21, rs_at, cs_at, + x1, incx, + x2, incx, + cntx + ); + } + } +} + +void bli_ctrsv_unf_var2 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + scomplex* alpha, + scomplex* a, inc_t rs_a, inc_t cs_a, + scomplex* x, inc_t incx, + cntx_t* cntx + ) +{ + + scomplex* minus_one = PASTEMAC(c, m1); + scomplex* A01; + scomplex* A11; + scomplex* A21; + scomplex* a01; + scomplex* alpha11; + scomplex* a21; + scomplex* x0; + scomplex* x1; + scomplex* x2; + scomplex* x01; + scomplex* chi11; + scomplex* x21; + scomplex alpha11_conj; + scomplex minus_chi11; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_ahead, f_ahead; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here + bli_init_once(); + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + /* x = alpha * x; */ + PASTEMAC2(c, scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(c, axpyf_ker_ft) kfp_af; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_af = bli_caxpyf_zen_int_5; + b_fuse = 5; + } + else + { + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_SCOMPLEX, BLIS_AF, cntx ); + } + /* We reduce all of the possible cases down to just lower/upper. 
*/ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_ahead = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A01 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_ahead = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a01 = A11 + (0 )*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(c, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(c, invscals)( alpha11_conj, *chi11 ); + } + + /* x01 = x01 - chi11 * a01; */ + PASTEMAC(c, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(c, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(c, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + } + + /* x0 = x0 - A01 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A01, rs_at, cs_at, + x1, incx, + x0, incx, + cntx + ); + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_ahead = m - iter - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_ahead = f - k - 1; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a21 = A11 + (l+1)*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(c, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(c, invscals)( alpha11_conj, *chi11 ); + } + + /* x21 = x21 - chi11 * a21; */ + PASTEMAC(c, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(c, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(c, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + } + + /* x2 = x2 - A21 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A21, rs_at, cs_at, + x1, incx, + x2, incx, + cntx + ); + } + } +} diff --git a/frame/3/CMakeLists.txt b/frame/3/CMakeLists.txt index 4b7711ed4e..e9d7da7b8e 100644 --- a/frame/3/CMakeLists.txt +++ b/frame/3/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -12,7 +12,6 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_packm.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_prune.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_int.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_packm_a.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_packm_b.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_packm_var.c @@ -26,7 +25,23 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_ukr_fpa.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_ukr_oapi.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_ukr_tapi.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_smart_threading.c + ) +# Select AMD specific sources for AMD configurations. 
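+# Note: the 'amdzen' family is the combined build of the zen, zen2 and zen3
+# sub-configurations, so it is routed to the AMD-specific variant as well.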
+if(${TARGET_ARCH} STREQUAL zen OR + ${TARGET_ARCH} STREQUAL zen2 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_int_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_int.c ) +endif() set(SUBDIRECTORIES "gemm" "hemm" "her2k" "herk" "symm" "syr2k" "syrk" "trmm" "trmm3" "trsm" "gemmt") diff --git a/frame/3/bli_l3.h b/frame/3/bli_l3.h index b64da054c9..b65edfcaac 100644 --- a/frame/3/bli_l3.h +++ b/frame/3/bli_l3.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -98,3 +98,6 @@ #include "bli_trmm3.h" #include "bli_trsm.h" #include "bli_gemmt.h" + +// Smart Threading API's. +#include "bli_l3_smart_threading.h" diff --git a/frame/3/bli_l3_check.c b/frame/3/bli_l3_check.c index 945b267fda..43ba867283 100644 --- a/frame/3/bli_l3_check.c +++ b/frame/3/bli_l3_check.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -323,8 +324,10 @@ void bli_gemm_basic_check // When mixing datatypes, make sure that alpha does not have a non-zero // imaginary component. - if ( bli_obj_dt( c ) != bli_obj_dt( a ) || - bli_obj_dt( c ) != bli_obj_dt( b ) || + // To support dzgemm, we continue execution when datatypes of C and A + // do not match instead of aborting with an error message. + // Non-zero imaginary component of alpha is handled while packing B. + if ( bli_obj_dt( c ) != bli_obj_dt( b ) || bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) if ( !bli_obj_imag_is_zero( alpha ) ) { diff --git a/frame/3/bli_l3_smart_threading.c b/frame/3/bli_l3_smart_threading.c new file mode 100644 index 0000000000..e4b9b43e24 --- /dev/null +++ b/frame/3/bli_l3_smart_threading.c @@ -0,0 +1,557 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "bli_l3_smart_threading.h" + +#ifdef AOCL_DYNAMIC + +// Utility functions. +static inline dim_t next_factor + ( + const dim_t nt, + const dim_t part_nt + ) +{ + if ( part_nt == nt) + { + return part_nt; + } + + dim_t nt_temp = part_nt + 1; + while ( ( nt_temp <= nt ) && ( ( nt % nt_temp ) != 0 ) ) + { + nt_temp++; + } + return nt_temp; +} + +static inline dim_t prev_factor + ( + const dim_t nt, + const dim_t part_nt + ) +{ + if ( part_nt == 1) + { + return part_nt; + } + + dim_t nt_temp = part_nt - 1; + while ((nt_temp >= 1) && ((nt % nt_temp) != 0)) + { + nt_temp--; + } + return nt_temp; +} +// End utility functions. + +static err_t bli_gemm_ic_jc_optimum_sup_arch_dispatcher + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ); + +static err_t bli_gemm_ic_jc_optimum_sup_zen3 + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ); + +static void bli_gemm_cache_heur_adjust_ic_jc_sup_zen3 + ( + const dim_t m, + const dim_t n, + const dim_t k, + dim_t nt, + dim_t* ic, + dim_t* jc, + const dim_t MR, + const dim_t NR, + const dim_t MC, + const dim_t KC + ); + +err_t bli_check_and_transform_native_to_SUP + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + dim_t ic, + dim_t jc, + const dim_t NR, + const dim_t MC, + const dim_t KC, + cntx_t* cntx, + rntm_t* rntm + ); + +err_t bli_gemm_smart_threading_sup + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t ret_val = BLIS_FAILURE; + + // Sanity check, max available threads should be atleast 4 for the + // smart threading/factorization to be meaningful. For nt < 4 the + // default ic,jc factorization holds good. + if ( ( m <= 1 ) || ( n <= 1 ) || ( k <= 1 ) || ( max_available_nt < 4 ) ) + { + return ret_val; + } + + if ( bli_is_float( dt ) ) + { + ret_val = bli_gemm_ic_jc_optimum_sup_arch_dispatcher + ( + dt, elem_size, is_rrr_rrc_rcr_crr, m, n, k, + max_available_nt, cntx, rntm + ); + } + else + { + // Other data types not supported for now. + } + + if ( ret_val == BLIS_SUCCESS ) + { + // This is a workaround to ensure that auto_factor attribute of rntm_t + // is not set to TRUE inside bli_rntm_set_ways_from_rntm_sup. Also + // the nt value will be properly set to ic*jc towards the end of + // bli_rntm_set_ways_from_rntm_sup. 
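+		// Passing -1 marks the thread count as unset, so the ic/jc ways
+		// assigned by the architecture-specific routine above remain
+		// authoritative rather than being re-factorized automatically.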
+ bli_rntm_set_num_threads_only( -1, rntm ); + } + + return ret_val; +} + +static err_t bli_gemm_ic_jc_optimum_sup_arch_dispatcher + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t ret_val = BLIS_FAILURE; + + arch_t id = bli_arch_query_id(); + if ( id == BLIS_ARCH_ZEN3 ) + { + ret_val = bli_gemm_ic_jc_optimum_sup_zen3 + ( + dt, elem_size, is_rrr_rrc_rcr_crr, m, n, k, + max_available_nt, cntx, rntm + ); + } + else + { + // Other architectures not supported for now. + } + + return ret_val; +} + +// open zen3 region. +#define NUM_CORES_PER_CCD_ZEN3 8 + +// Determines the optimal number of threads (nt) and corresponding work split +// (ic,jc factorization of nt) for gemm on zen3 machines. +static err_t bli_gemm_ic_jc_optimum_sup_zen3 + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t ret_val = BLIS_SUCCESS; + + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); + + dim_t ic = -1; + dim_t jc = -1; + + bli_thread_partition_2x2( max_available_nt, m, n, &ic, &jc ); + + dim_t jc_per_ccd = ( NUM_CORES_PER_CCD_ZEN3 + ic - 1 ) / ic ; + dim_t b_mat_data_per_ccd = jc_per_ccd * ( n / jc ); + + // All the cores (8) on a CCD share a L3 cache and hence total data + // loaded by the cores on a CCD should be < NC to avoid L3 contention. + // In cases where it is violated, it is better to increase ic and + // reduce B data per CCD, using micro panels mu, nu for thread + // partitioning can help achieve this. Avoiding further ic,jc + // adjustment in this case. + if ( b_mat_data_per_ccd > NC ) + { + const dim_t mu = m / MR; + const dim_t nu = n / NR; + bli_thread_partition_2x2( max_available_nt, mu, nu, &ic, &jc ); + } + else + { + // Adjust the ic,jc in the best match so that m_ic and n_jc + // turns out to be more cache friendly. + bli_gemm_cache_heur_adjust_ic_jc_sup_zen3 + ( + m, n, k, max_available_nt, &ic, &jc, MR, NR, MC, KC + ); + } + + ret_val = bli_check_and_transform_native_to_SUP + ( + dt, elem_size, is_rrr_rrc_rcr_crr, m, n, k, + ic, jc, NR, MC, KC, cntx, rntm + ); + + if ( ret_val == BLIS_SUCCESS ) + { + bli_rntm_set_ic_ways_only( ic, rntm ); + bli_rntm_set_jc_ways_only( jc, rntm ); + } + + return ret_val; +} + +// The factorization of nt into ic,jc is based on m and n values (for simplicity +// it can be assumed to be based on m:n ratio). It does not take into account +// how the matrices are loaded into cache or which matrix goes to the larger +// cache. Depending on the matrix dimensions, increasing the ic can result in +// reduced loads from main memory to L2 cache for A matrix without any impact on +// B matrix load (since B is streamed into L3, which is larger). Similary +// adjusting jc can result in B matrix panels fitting perfectly within the L1 +// cache.This function makes these adjustments on ic,jc. 
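+// Illustrative example (numbers chosen only for exposition): with nt = 16 and
+// an initial ic x jc = 4 x 4 split, the neighbouring factorizations examined
+// below are next_ic x prev_jc = 8 x 2 and prev_ic x next_jc = 2 x 8; the
+// heuristics then move toward whichever neighbour keeps m/ic within MC and
+// n/jc near a multiple of NR (a rough summary, not an exact rule).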
+static void bli_gemm_cache_heur_adjust_ic_jc_sup_zen3 + ( + const dim_t m, + const dim_t n, + const dim_t k, + dim_t nt, + dim_t* ic, + dim_t* jc, + const dim_t MR, + const dim_t NR, + const dim_t MC, + const dim_t KC + ) +{ + const dim_t m_ic = m / ( *ic ); + const dim_t n_jc = n / ( *jc ); + const int64_t cur_work_per_thread = m_ic + n_jc; + + // The next and prev factors are caluclated with respect to the current + // factor part of nt. In effect + // 1. next ic * prev jc = nt + // 2. prev ic * next jc = nt + // 3. ic * jc = nt + const dim_t next_ic = next_factor( nt, ( *ic ) ); + const dim_t prev_ic = prev_factor( nt, ( *ic ) ); + const dim_t next_jc = next_factor( nt, ( *jc ) ); + const dim_t prev_jc = prev_factor( nt, ( *jc ) ); + + const dim_t m_next_ic = m / next_ic; + const dim_t m_prev_ic = m / prev_ic; + const dim_t n_next_jc = n / next_jc; + const dim_t n_prev_jc = n / prev_jc; + const dim_t n_jc_modulo_NR = n_jc % NR; + const dim_t n_prev_jc_modulo_NR = n_prev_jc % NR; + + const int64_t next_jc_work_per_thread = n_next_jc + m_prev_ic; + const int64_t next_ic_work_per_thread = m_next_ic + n_prev_jc; + + const dim_t MCx2 = MC * 2; + const dim_t NRx4 = NR * 4; + const dim_t NRx8 = NR * 8; + + // MC will be reduced if the following mods are zero. Incrementing jc + // helps in this case. + const dim_t n_mod_256 = n % 256; + const dim_t k_mod_256 = k % 256; + + const dim_t k_factor = k / KC; + + bool can_increase_jc = FALSE; + bool can_increase_ic = FALSE; + + // jc adjustment towards next highest factor if it results in n_jc*KC + // fittting completely within l1d cache. Only done if ic prev factor + // does not move m_prev_ic out of good l2 load zone (MC). + // Performance improvement also observed when n_jc is a multiple of NR. + if ( ( ( *ic ) > 1 ) && ( ( *jc ) < nt ) ) + { + // Check whether m_prev_ic remains in good l2 load zone. + if ( ( ( ( m_ic <= MC ) && ( m_prev_ic <= MC ) ) || + ( m_ic > MC ) ) && + ( ( n_jc > NR ) && ( n_next_jc == NR ) ) ) + { + can_increase_jc = TRUE; + } + // 2x2 factorization doesnt always give equal sum partition. + else if ( next_jc_work_per_thread < cur_work_per_thread ) + { + can_increase_jc = TRUE; + } + } + + // Favor jc if both n and k are multiples of 256 ( high cache line + // replacement ). + if ( ( ( *ic ) < nt ) && ( ( *jc ) > 1) ) + { + // ic adjustment towards next highest factor if it results in + // m_next_ic <= MC. This helps in reducing number of A matrix + // loads per thread to l2 from main memory. + if ( ( m_ic > MC ) && ( m_next_ic <= MC ) && + ( m_next_ic >= MR ) && ( k_factor > 4 ) ) + { + can_increase_ic = TRUE; + } + // ic adjustment towards next highest factor resulted in better + // performance when m is sufficiently larger than n and jc prev + // factor did not result in n_prev_jc moving out of good l2 + // load zone (n_jc < 64). + else if ( ( m > ( 5 * n ) ) && ( m_ic >= MCx2 ) && ( k_factor > 4 ) && + ( ( n_jc > NRx4 ) || + ( ( n_jc <= NRx4 ) && ( n_prev_jc <= NRx4 ) ) ) ) + { + can_increase_ic = TRUE; + } + // Performance improvement also observed when n_jc is a multiple + // of NR. + else if ( ( n_jc_modulo_NR != 0 ) && ( n_prev_jc_modulo_NR == 0 ) && + ( k_factor > 4 ) ) + { + can_increase_ic = TRUE; + } + // 2x2 factorization doesnt always give equal sum partition. + else if ( next_ic_work_per_thread <= cur_work_per_thread ) + { + can_increase_ic = TRUE; + } + } + + // Favor jc if both n and k are multiples of 256 ( high cache line + // replacement ). 
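+	// (A likely explanation: dimensions that are multiples of 256 elements
+	// imply power-of-two strides, so successive panel lines map to the same
+	// cache sets and evict each other; favoring jc reduces the number of such
+	// lines touched per thread.)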
+ if ( ( n_mod_256 == 0 ) && ( k_mod_256 == 0 ) && ( k > KC ) ) + { + if ( can_increase_ic == TRUE ) + { + can_increase_ic = FALSE; + } + else if ( can_increase_jc == FALSE ) + { + can_increase_jc = TRUE; + } + } + // If only one of either n or k is a multiple of 256, favour jc if n per + // thread is within a heuristic factor of NR. + else if ( ( ( n_mod_256 == 0 ) || ( k_mod_256 == 0 ) ) && ( k > KC ) ) + { + if ( ( can_increase_ic == TRUE ) && ( n_jc <= NRx8 ) ) + { + can_increase_ic = FALSE; + } + else if ( ( can_increase_jc == FALSE ) && ( n_next_jc <= NRx8 ) ) + { + can_increase_jc = TRUE; + } + } + + // Increasing ic factor is given a higher priority compared to jc + // since it was observed that the A matrix loads (main memory -> l2) had + // more impact on perf compared to B matrix (main memory -> l3 -> l1) + // for the sizes considered. + if ( can_increase_ic ) + { + // It is expected that the larger dimension (m or n) will be + // allocated a larger share of the thread factorization. + if ( ( ( m >= n ) && ( next_ic >= prev_jc ) ) || + ( ( m <= n ) && ( next_ic <= prev_jc ) ) ) + { + *ic = next_ic; + *jc = prev_jc; + } + } + else if ( can_increase_jc ) + { + // It is expected that the larger dimension (m or n) will be + // allocated a larger share of the thread factorization. + if ( ( ( m >= n ) && ( prev_ic >= next_jc ) ) || + ( ( m <= n ) && ( prev_ic <= next_jc ) ) ) + { + *ic = prev_ic; + *jc = next_jc; + } + } +} + +// It was observed that the SUP thresholds can be lowered and applied on a +// per thread basis in multi threaded scenarios. +err_t bli_check_and_transform_native_to_SUP + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + dim_t ic, + dim_t jc, + const dim_t NR, + const dim_t MC, + const dim_t KC, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t ret_val = BLIS_FAILURE; + dim_t m_ic; + dim_t n_jc; + + const dim_t MT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ); + const dim_t NT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ); + const dim_t KT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ); + + const dim_t MT_2 = MT / 2; + const dim_t NTx4 = NT * 4; + const dim_t NRx8 = NR * 8; + + const dim_t page_size = bli_info_get_page_size(); + const dim_t page_size_b_float = page_size / ( dim_t ) elem_size; + const dim_t page_size_b_floatx2 = page_size_b_float * 2; + + // Default SUP check without considering per thread dimensions. + if ( ( k < KT ) || ( m < MT ) || ( n < NT ) ) + { + ret_val = BLIS_SUCCESS; + } + // Per thread SUP limit checking. It was observed that when k is large, + // (twice page size) moving native to SUP did not help even if m_ic or + // n_jc were within SUP limits. + else if ( ( m >= MT ) && ( n >= NT ) && ( k < page_size_b_floatx2 ) ) + { + m_ic = m / ic; + n_jc = n / jc; + // In multi-threaded scenario, it was observed that if the per + // thread m dimension(A matrix) and n dimension(B matrix) is + // within a factor of SUP limits, SUP path without packing + // resulted in gains. Along similar lines, if the B matrix is + // large enough and reuse is good, packing B matrix alone in SUP + // resulted in perf gains. 
+ if ( ( m_ic <= MT_2 ) && ( n_jc < NTx4 ) ) + { + if ( ( k > KC ) && + ( m_ic >= MC ) && ( n_jc >= NT ) ) + { + if ( is_rrr_rrc_rcr_crr ) + { + bli_rntm_set_pack_b( 1, rntm ); + } + else + { + bli_rntm_set_pack_a( 1, rntm ); + } + } + ret_val = BLIS_SUCCESS; + } + else if ( ( n_jc < NT ) && ( m_ic <= MT ) ) + { + if ( ( k > KC ) && ( m_ic >= MC ) && ( n_jc >= NRx8 ) ) + { + if ( is_rrr_rrc_rcr_crr ) + { + bli_rntm_set_pack_b( 1, rntm ); + } + else + { + bli_rntm_set_pack_a( 1, rntm ); + } + } + ret_val = BLIS_SUCCESS; + } + else + { + ret_val = BLIS_FAILURE; + } + } + else + { + ret_val = BLIS_FAILURE; + } + + return ret_val; +} +// close zen3 region. + +#endif diff --git a/frame/3/bli_l3_smart_threading.h b/frame/3/bli_l3_smart_threading.h new file mode 100644 index 0000000000..48a0a17bb2 --- /dev/null +++ b/frame/3/bli_l3_smart_threading.h @@ -0,0 +1,68 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef AOCL_DYNAMIC + +#ifndef BLIS_L3_SMART_THREADING_H +#define BLIS_L3_SMART_THREADING_H + +// Smart threading encompasses the following multi-threading related +// optimizations: +// 1. Selection of optimal number of threads (BLIS_NUM_THREADS) based +// on matrix dimensions. +// 2. Factorization of threads along m and n dimensions (BLIS_IC_NT, +// BLIS_JC_NT) based on matrix dimensions and cache friendliness. +// 3. Transformation of native to SUP path based on the per thread matrix +// dimensions after thread factorization, given that per thread dimensions +// are within SUP limits. +// 4. Enabling packing of B alone in SUP path if native -> SUP path +// transformation happened and depending on per thread matrix dimensions. +// This function captures smart threading logic fine tuned for gemm SUP path. +// Optimal thread selection is not enabled now. 
+err_t bli_gemm_smart_threading_sup + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ); + +#endif //BLIS_L3_SMART_THREADING_H + +#endif diff --git a/frame/3/bli_l3_sup.c b/frame/3/bli_l3_sup.c index a7d7a7874a..d23df8c1e5 100644 --- a/frame/3/bli_l3_sup.c +++ b/frame/3/bli_l3_sup.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2019-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -101,6 +101,34 @@ err_t bli_gemmsup // that function assumes the context pointer is valid. if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + +#ifdef AOCL_DYNAMIC + // Calculating optimal nt and corresponding factorization (ic,jc) here, so + // as to determine the matrix dimensions (A - m, B - n) per thread. This + // can be used to check if dimensions per thread falls under the SUP + // threshold and potentially move some of the native path gemm to SUP path + // in multi-threaded scenario. + err_t smart_threading = bli_smart_threading_sup( a, b, c, BLIS_GEMM, rntm, cntx ); + + if ( smart_threading != BLIS_SUCCESS ) + { + thresh_func_ft func_fp; + func_fp = bli_cntx_get_l3_thresh_func(BLIS_GEMM, cntx); + + // Return early if the sizes are beyond SUP thresholds + if ( !func_fp( a, b, c, cntx ) ) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, + "SUP - Sizes are beyond SUP thresholds."); + return BLIS_FAILURE; + } + } +#else thresh_func_ft func_fp; func_fp = bli_cntx_get_l3_thresh_func(BLIS_GEMM, cntx); @@ -110,26 +138,7 @@ err_t bli_gemmsup AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - Sizes are beyond SUP thresholds."); return BLIS_FAILURE; } - - // Initialize a local runtime with global settings if necessary. Note - // that in the case that a runtime is passed in, we make a local copy. - rntm_t rntm_l; - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } - else { rntm_l = *rntm; rntm = &rntm_l; } - -#if 0 -const num_t dt = bli_obj_dt( c ); -const dim_t m = bli_obj_length( c ); -const dim_t n = bli_obj_width( c ); -const dim_t k = bli_obj_width_after_trans( a ); -const dim_t tm = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ); -const dim_t tn = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ); -const dim_t tk = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ); - -printf( "dims: %d %d %d (threshs: %d %d %d)\n", - (int)m, (int)n, (int)k, (int)tm, (int)tn, (int)tk ); #endif - // We've now ruled out the following two possibilities: // - the ukernel prefers the operation as-is, and the sup thresholds are // unsatisfied. 
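The per-thread SUP decision that the AOCL_DYNAMIC hook above feeds into can be summarized with a small standalone sketch. The helper name, the fixed ic x jc split, and the threshold values below are illustrative assumptions for this sketch, not the library's actual API or tuned constants:

    #include <stdbool.h>
    #include <stdio.h>

    // Hypothetical per-thread SUP check: once the number of threads nt has
    // been factored into ic x jc ways, compare the per-thread share of A
    // (m/ic) and B (n/jc) against the SUP thresholds instead of the full
    // problem size.
    static bool use_sup_per_thread( long m, long n, long k,
                                    long ic, long jc,
                                    long MT, long NT, long KT )
    {
        // Whole-problem check: small problems always take the SUP path.
        if ( m < MT || n < NT || k < KT ) return true;

        // Per-thread check: a large problem may still be SUP-sized per thread.
        return ( m / ic ) < MT && ( n / jc ) < NT;
    }

    int main( void )
    {
        // A 2000x2000x400 gemm split 4x4 over 16 threads gives each thread a
        // 500x500 tile, which falls under these (made-up) thresholds.
        printf( "%d\n", use_sup_per_thread( 2000, 2000, 400, 4, 4,
                                            600, 600, 300 ) );
        return 0;
    }

In the actual patch this decision is made by bli_gemm_smart_threading_sup() and bli_check_and_transform_native_to_SUP(), which additionally decide whether to enable packing of only one operand based on k, MC, and NR multiples.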
diff --git a/frame/3/bli_l3_sup_int.c b/frame/3/bli_l3_sup_int.c index 7ef4bdd49f..909f480599 100644 --- a/frame/3/bli_l3_sup_int.c +++ b/frame/3/bli_l3_sup_int.c @@ -48,120 +48,6 @@ err_t bli_gemmsup_int { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4); -#ifdef BLIS_CONFIG_EPYC - const num_t dt = bli_obj_dt( c ); - const dim_t m = bli_obj_length( c ); - const dim_t n = bli_obj_width( c ); - const dim_t k = bli_obj_width( a ); - const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); - const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); - const bool auto_factor = bli_rntm_auto_factor( rntm ); - const dim_t n_threads = bli_rntm_num_threads( rntm ); - - dim_t jc_new; - dim_t ic_new; - - - //bli_gemmsup_ref_var2 - //bli_gemmsup_ref_var1 - #if 0 - bli_gemmsup_ref_var1n - #else - #endif - const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); - const bool is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || - stor_id == BLIS_RRC || - stor_id == BLIS_RCR || - stor_id == BLIS_CRR ); - #ifdef TRACEVAR - if ( bli_thread_am_ochief( thread ) ) - printf( "bli_l3_sup_int(): var2m primary\n" ); - #endif - - // Don't use the small/unpacked implementation if one of the matrices - // uses general stride. - if ( stor_id == BLIS_XXX ) { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_4, "SUP doesn't support general stide."); - return BLIS_FAILURE; - } - - if ( is_rrr_rrc_rcr_crr ) - { - // This branch handles: - // - rrr rrc rcr crr for row-preferential kernels - // - rcc crc ccr ccc for column-preferential kernels - // - Currently only row-preferential kernels are only supported. - - // calculate number of micropanels in m and n dimensions and - // recalculate the automatic thread factorization based on these number of micropanels - const dim_t mu = m / MR; - const dim_t nu = n / NR; - - // If the parallel thread factorization was automatic, we update it - // with a new factorization based on the matrix dimensions in units - // of micropanels. - if ( auto_factor ) - { - // In the block-panel algorithm, the m dimension is parallelized - // with ic_nt and the n dimension is parallelized with jc_nt. - bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); - - // Update the ways of parallelism for the jc and ic loops, and then - // update the current thread's root thrinfo_t node according to the - // new ways of parallelism value for the jc loop. - bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); - bli_l3_sup_thrinfo_update_root( rntm, thread ); - } - - /*Enable packing for B matrix for higher sizes*/ - if(bli_is_float(dt) && (n_threads==1)) { - if((m > 240) && (k > 240) && (n > 240)) - bli_rntm_set_pack_b( 1, rntm ); - } - - bli_gemmsup_ref_var2m( BLIS_NO_TRANSPOSE, - alpha, a, b, beta, c, - stor_id, cntx, rntm, thread ); - } - else - { - // This branch handles: - // - rrr rrc rcr crr for column-preferential kernels - // - rcc crc ccr ccc for row-preferential kernels - // - Currently only row-preferential kernels are only supported. - const dim_t mu = n / MR; // the n becomes m after a transposition - const dim_t nu = m / NR; // the m becomes n after a transposition - - if ( auto_factor ) - { - // In the block-panel algorithm, the m dimension is parallelized - // with ic_nt and the n dimension is parallelized with jc_nt. 
- bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); - - // Update the ways of parallelism for the jc and ic loops, and then - // update the current thread's root thrinfo_t node according to the - // new ways of parallelism value for the jc loop. - bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); - bli_l3_sup_thrinfo_update_root( rntm, thread ); - } - - /* Enable packing for B matrix for higher sizes. Note that pack A - * becomes pack B inside var2m because this is transpose case*/ - if(bli_is_float(dt) && (n_threads==1)) { - if((m > 240) && (k > 240) && (n > 240)) - bli_rntm_set_pack_a( 1, rntm ); - } - - bli_gemmsup_ref_var2m( BLIS_TRANSPOSE, - alpha, a, b, beta, c, - stor_id, cntx, rntm, thread ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4); - return BLIS_SUCCESS; - -#else // #ifdef BLIS_CONFIG_EPYC - const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); // Don't use the small/unpacked implementation if one of the matrices @@ -335,8 +221,6 @@ err_t bli_gemmsup_int // Return success so that the caller knows that we computed the solution. AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) return BLIS_SUCCESS; - -#endif } // ----------------------------------------------------------------------------- @@ -401,15 +285,9 @@ err_t bli_gemmtsup_int // Decide which algorithm to use (block-panel var2m or panel-block // var1n) based on the number of micropanels in the m and n dimensions. // Also, recalculate the automatic thread factorization. -#ifdef BLIS_CONFIG_EPYC - if ( mu >= nu ) use_bp = TRUE; - else /* if ( mu < nu ) */ use_bp = TRUE;// var1n is not implemented for GEMMT - -#else if ( mu >= nu ) use_bp = TRUE; else /* if ( mu < nu ) */ use_bp = FALSE; -#endif // If the parallel thread factorization was automatic, we update it // with a new factorization based on the matrix dimensions in units // of micropanels. @@ -472,14 +350,10 @@ err_t bli_gemmtsup_int // Decide which algorithm to use (block-panel var2m or panel-block // var1n) based on the number of micropanels in the m and n dimensions. // Also, recalculate the automatic thread factorization. -#ifdef BLIS_CONFIG_EPYC - if ( mu >= nu ) use_bp = TRUE; - else /* if ( mu < nu ) */ use_bp = TRUE; //var1n is not implemented for gemmt -#else + if ( mu >= nu ) use_bp = TRUE; else /* if ( mu < nu ) */ use_bp = FALSE; -#endif // If the parallel thread factorization was automatic, we update it // with a new factorization based on the matrix dimensions in units // of micropanels. diff --git a/frame/3/bli_l3_sup_int_amd.c b/frame/3/bli_l3_sup_int_amd.c new file mode 100644 index 0000000000..e00cc54ad0 --- /dev/null +++ b/frame/3/bli_l3_sup_int_amd.c @@ -0,0 +1,370 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019-22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +err_t bli_gemmsup_int + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4); + + const num_t dt = bli_obj_dt( c ); + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); + const bool auto_factor = bli_rntm_auto_factor( rntm ); + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + dim_t jc_new; + dim_t ic_new; + + + //bli_gemmsup_ref_var2 + //bli_gemmsup_ref_var1 + #if 0 + bli_gemmsup_ref_var1n + #else + #endif + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + const bool is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || + stor_id == BLIS_RRC || + stor_id == BLIS_RCR || + stor_id == BLIS_CRR ); + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var2m primary\n" ); + #endif + + // Don't use the small/unpacked implementation if one of the matrices + // uses general stride. + if ( stor_id == BLIS_XXX ) { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_4, "SUP doesn't support general stide."); + return BLIS_FAILURE; + } + + if ( is_rrr_rrc_rcr_crr ) + { + // This branch handles: + // - rrr rrc rcr crr for row-preferential kernels + // - rcc crc ccr ccc for column-preferential kernels + // - Currently only row-preferential kernels are only supported. + + // calculate number of micropanels in m and n dimensions and + // recalculate the automatic thread factorization based on these number of micropanels + const dim_t mu = m / MR; + const dim_t nu = n / NR; + + // If the parallel thread factorization was automatic, we update it + // with a new factorization based on the matrix dimensions in units + // of micropanels. However in case smart threading is enabled, + // auto_factor will be false. + if ( auto_factor ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. 
+ bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + /*Enable packing for B matrix for higher sizes*/ + if(bli_is_float(dt) && (n_threads==1)) { + if((m > 240) && (k > 240) && (n > 240)) + bli_rntm_set_pack_b( 1, rntm ); + } + + bli_gemmsup_ref_var2m( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + else + { + // This branch handles: + // - rrr rrc rcr crr for column-preferential kernels + // - rcc crc ccr ccc for row-preferential kernels + // - Currently only row-preferential kernels are only supported. + const dim_t mu = n / MR; // the n becomes m after a transposition + const dim_t nu = m / NR; // the m becomes n after a transposition + + if ( auto_factor ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. + bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + /* Enable packing for B matrix for higher sizes. Note that pack A + * becomes pack B inside var2m because this is transpose case*/ + if(bli_is_float(dt) && (n_threads==1)) { + if((m > 240) && (k > 240) && (n > 240)) + bli_rntm_set_pack_a( 1, rntm ); + } + + bli_gemmsup_ref_var2m( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4); + return BLIS_SUCCESS; + + +} + +// ----------------------------------------------------------------------------- + +err_t bli_gemmtsup_int + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4); +// AOCL_DTL_LOG_GEMMT_INPUTS(AOCL_DTL_LEVEL_TRACE_4, alpha, a, b, beta, c); + + + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + + // Don't use the small/unpacked implementation if one of the matrices + // uses general stride. + if ( stor_id == BLIS_XXX ) { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_4, "SUP doesn't support general stide."); + return BLIS_FAILURE; + } + + const bool is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || + stor_id == BLIS_RRC || + stor_id == BLIS_RCR || + stor_id == BLIS_CRR ); + const bool is_rcc_crc_ccr_ccc = !is_rrr_rrc_rcr_crr; + + const num_t dt = bli_obj_dt( c ); + const bool row_pref = bli_cntx_l3_sup_ker_prefers_rows_dt( dt, stor_id, cntx ); + + const bool is_primary = ( row_pref ? 
is_rrr_rrc_rcr_crr + : is_rcc_crc_ccr_ccc ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = m; + const dim_t k = bli_obj_width( a ); + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); + const bool auto_factor = bli_rntm_auto_factor( rntm ); + const dim_t n_threads = bli_rntm_num_threads( rntm ); + bool use_bp = TRUE; + dim_t jc_new; + dim_t ic_new; + + + if ( is_primary ) + { + // This branch handles: + // - rrr rrc rcr crr for row-preferential kernels + // - rcc crc ccr ccc for column-preferential kernels + + const dim_t mu = m / MR; + const dim_t nu = n / NR; + + // Decide which algorithm to use (block-panel var2m or panel-block + // var1n) based on the number of micropanels in the m and n dimensions. + // Also, recalculate the automatic thread factorization. + + if ( mu >= nu ) use_bp = TRUE; + else /* if ( mu < nu ) */ use_bp = TRUE;// var1n is not implemented for GEMMT + + // If the parallel thread factorization was automatic, we update it + // with a new factorization based on the matrix dimensions in units + // of micropanels. + if ( auto_factor ) + { + if ( use_bp ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + } + else // if ( !use_bp ) + { + // In the panel-block algorithm, the m dimension is parallelized + // with jc_nt and the n dimension is parallelized with ic_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &jc_new, &ic_new ); + } + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. + bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + /* Enable packing for B matrix for higher sizes. Note that pack B + * * becomes pack A inside var2m because this is transpose case*/ + if(bli_is_double(dt) && ((n_threads==1))) + { + if((m > 320) && (k > 50)) + bli_rntm_set_pack_b( 1, rntm ); + } + + } + + + if ( use_bp ) + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var2m primary\n" ); + #endif + // block-panel macrokernel; m -> mc, mr; n -> nc, nr: var2() + bli_gemmtsup_ref_var2m( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + else // use_pb + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var1n primary\n" ); + #endif + // panel-block macrokernel; m -> nc*,mr; n -> mc*,nr: var1() + bli_gemmtsup_ref_var1n( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + // *requires nudging of nc up to be a multiple of mr. + } + } + else + { + // This branch handles: + // - rrr rrc rcr crr for column-preferential kernels + // - rcc crc ccr ccc for row-preferential kernels + + const dim_t mu = n / MR; // the n becomes m after a transposition + const dim_t nu = m / NR; // the m becomes n after a transposition + + // Decide which algorithm to use (block-panel var2m or panel-block + // var1n) based on the number of micropanels in the m and n dimensions. + // Also, recalculate the automatic thread factorization. 
+ + if ( mu >= nu ) use_bp = TRUE; + else /* if ( mu < nu ) */ use_bp = TRUE; //var1n is not implemented for gemmt + + // If the parallel thread factorization was automatic, we update it + // with a new factorization based on the matrix dimensions in units + // of micropanels. + if ( auto_factor ) + { + if ( use_bp ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + } + else // if ( !use_bp ) + { + // In the panel-block algorithm, the m dimension is parallelized + // with jc_nt and the n dimension is parallelized with ic_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &jc_new, &ic_new ); + } + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. + bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + + /* Enable packing for A matrix for higher sizes. Note that pack A + * * becomes pack B inside var2m because this is transpose case*/ + if(bli_is_double(dt) && (n_threads==1)) + { + if((m > 320) && (k > 50)) + bli_rntm_set_pack_a( 1, rntm ); + } + } + + + if ( use_bp ) + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var2m non-primary\n" ); + #endif + // panel-block macrokernel; m -> nc, nr; n -> mc, mr: var2() + trans + bli_gemmtsup_ref_var2m( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + else // use_pb + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var1n non-primary\n" ); + #endif + // block-panel macrokernel; m -> mc*,nr; n -> nc*,mr: var1() + trans + bli_gemmtsup_ref_var1n( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + // *requires nudging of mc up to be a multiple of nr. + } + } + + // Return success so that the caller knows that we computed the solution. + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return BLIS_SUCCESS; +} + diff --git a/frame/3/gemm/CMakeLists.txt b/frame/3/gemm/CMakeLists.txt index 8eb115d1f0..825dd745ca 100644 --- a/frame/3/gemm/CMakeLists.txt +++ b/frame/3/gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -6,7 +6,6 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_blk_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_blk_var3.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_cntl.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_front.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_int.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_ker_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_ker_var2.c @@ -16,4 +15,20 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_packab.c ) +# Select AMD specific sources for AMD configurations. 
+if(${TARGET_ARCH} STREQUAL zen OR +${TARGET_ARCH} STREQUAL zen2 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_front_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_front.c + ) +endif() + add_subdirectory(ind) diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index a065156bbf..a9bada995d 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -74,22 +74,6 @@ void bli_gemm_front return; } -#ifdef BLIS_ENABLE_SMALL_MATRIX - // Only handle small problems separately for homogeneous datatypes. - if ( bli_obj_dt( a ) == bli_obj_dt( b ) && - bli_obj_dt( a ) == bli_obj_dt( c ) && - bli_obj_comp_prec( c ) == bli_obj_prec( c ) ) - { - err_t status = bli_gemm_small( alpha, a, b, beta, c, cntx, cntl ); - - if ( status == BLIS_SUCCESS ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - } -#endif - // Alias A, B, and C in case we need to apply transformations. bli_obj_alias_to( a, &a_local ); bli_obj_alias_to( b, &b_local ); @@ -174,23 +158,6 @@ void bli_gemm_front bli_obj_swap_pack_schemas( &a_local, &b_local ); } - dim_t m_dim_local = bli_obj_length( &c_local ); - dim_t n_dim_local = bli_obj_width( &c_local ); - dim_t k_dim_local = bli_obj_width( &a_local ); -#ifdef BLIS_CONFIG_EPYC - // Regression observed in sgemm native path in cases where m >= 4 * n - // after BLIS_THREAD_RATIO_M updated from 2 to 1 as part of commit - // 11dfc176a3c422729f453f6c23204cf023e9954d. Temporary workaround for - // the issue. - if( bli_obj_is_float( &c_local ) && - ( n_dim_local >= 1024 ) && - ( k_dim_local >= 1024 ) && - ( m_dim_local >= ( 4 * n_dim_local ) ) ) - { - m_dim_local *= 2; - } -#endif - // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any // additional modifications necessary for the current operation. @@ -198,9 +165,9 @@ void bli_gemm_front ( BLIS_GEMM, BLIS_LEFT, // ignored for gemm/hemm/symm - m_dim_local, - n_dim_local, - k_dim_local, + bli_obj_length( &c_local ), + bli_obj_width( &c_local ), + bli_obj_width_after_trans( &a_local ), rntm ); diff --git a/frame/3/gemm/bli_gemm_front_amd.c b/frame/3/gemm/bli_gemm_front_amd.c new file mode 100644 index 0000000000..34b41f0568 --- /dev/null +++ b/frame/3/gemm/bli_gemm_front_amd.c @@ -0,0 +1,407 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_gemm_front + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + bli_init_once(); + + #ifdef AOCL_DYNAMIC + // If dynamic-threading is enabled, calculate optimum number + // of threads. + // rntm will be updated with optimum number of threads. + if( bli_obj_is_dcomplex(c))// This will enable for ZGEMM + { + bli_nthreads_optimum(a, b, c, BLIS_GEMM, rntm); + } + #endif + + obj_t a_local; + obj_t b_local; + obj_t c_local; + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bli_gemm_check( alpha, a, b, beta, c, cntx ); + + // If C has a zero dimension, return early. + if ( bli_obj_has_zero_dim( c ) ) + { + return; + } + + // If alpha is zero, or if A or B has a zero dimension, scale C by beta + // and return early. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) || + bli_obj_has_zero_dim( a ) || + bli_obj_has_zero_dim( b ) ) + { + bli_scalm( beta, c ); + return; + } + + // Alias A, B, and C in case we need to apply transformations. + bli_obj_alias_to( a, &a_local ); + bli_obj_alias_to( b, &b_local ); + bli_obj_alias_to( c, &c_local ); + +#ifdef BLIS_ENABLE_GEMM_MD + cntx_t cntx_local; + + // If any of the storage datatypes differ, or if the computation precision + // differs from the storage precision of C, utilize the mixed datatype + // code path. + // NOTE: If we ever want to support the caller setting the computation + // domain explicitly, we will need to check the computation dt against the + // storage dt of C (instead of the computation precision against the + // storage precision of C). + if ( bli_obj_dt( &c_local ) != bli_obj_dt( &a_local ) || + bli_obj_dt( &c_local ) != bli_obj_dt( &b_local ) || + bli_obj_comp_prec( &c_local ) != bli_obj_prec( &c_local ) ) + { + // Handle mixed datatype cases in bli_gemm_md(), which may modify + // the objects or the context. (If the context is modified, cntx + // is adjusted to point to cntx_local.) + bli_gemm_md( &a_local, &b_local, beta, &c_local, &cntx_local, &cntx ); + } + //else // homogeneous datatypes +#endif + + // Load the pack schemas from the context and embed them into the objects + // for A and B. 
(Native contexts are initialized with the correct pack + // schemas, as are contexts for 1m, and if necessary bli_gemm_md() would + // have made a copy and modified the schemas, so reading them from the + // context should be a safe bet at this point.) This is a sort of hack for + // communicating the desired pack schemas to bli_gemm_cntl_create() (via + // bli_l3_thread_decorator() and bli_l3_cntl_create_if()). This allows us + // to subsequently access the schemas from the control tree, which + // hopefully reduces some confusion, particularly in bli_packm_init(). + const pack_t schema_a = bli_cntx_schema_a_block( cntx ); + const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); + + bli_obj_set_pack_schema( schema_a, &a_local ); + bli_obj_set_pack_schema( schema_b, &b_local ); + + // Next, we handle the possibility of needing to typecast alpha to the + // computation datatype and/or beta to the storage datatype of C. + + // Attach alpha to B, and in the process typecast alpha to the target + // datatype of the matrix (which in this case is equal to the computation + // datatype). + bli_obj_scalar_attach( BLIS_NO_CONJUGATE, alpha, &b_local ); + + // Attach beta to C, and in the process typecast beta to the target + // datatype of the matrix (which in this case is equal to the storage + // datatype of C). + bli_obj_scalar_attach( BLIS_NO_CONJUGATE, beta, &c_local ); + + // Change the alpha and beta pointers to BLIS_ONE since the values have + // now been typecast and attached to the matrices above. + alpha = &BLIS_ONE; + beta = &BLIS_ONE; + +#ifdef BLIS_ENABLE_GEMM_MD + // Don't perform the following optimization for ccr or crc cases, as + // those cases are sensitive to the ukernel storage preference (ie: + // transposing the operation would break them). + if ( !bli_gemm_md_is_ccr( &a_local, &b_local, &c_local ) && + !bli_gemm_md_is_crc( &a_local, &b_local, &c_local ) ) +#endif + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_obj_swap( &a_local, &b_local ); + + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + + // We must also swap the pack schemas, which were set by bli_gemm_md() + // or the inlined code above. + bli_obj_swap_pack_schemas( &a_local, &b_local ); + } + + dim_t m_dim_local = bli_obj_length( &c_local ); + dim_t n_dim_local = bli_obj_width( &c_local ); + dim_t k_dim_local = bli_obj_width_after_trans( &a_local ); + + // Regression observed in sgemm native path in cases where m >= 4 * n + // after BLIS_THREAD_RATIO_M updated from 2 to 1 as part of commit + // 11dfc176a3c422729f453f6c23204cf023e9954d. Temporary workaround for + // the issue. + if( bli_obj_is_float( &c_local ) && + ( n_dim_local >= 1024 ) && + ( k_dim_local >= 1024 ) && + ( m_dim_local >= ( 4 * n_dim_local ) ) ) + { + m_dim_local *= 2; + } + + // Parse and interpret the contents of the rntm_t object to properly + // set the ways of parallelism for each loop, and then make any + // additional modifications necessary for the current operation. 
+ bli_rntm_set_ways_for_op + ( + BLIS_GEMM, + BLIS_LEFT, // ignored for gemm/hemm/symm + m_dim_local, + n_dim_local, + k_dim_local, + rntm + ); + + obj_t* cp = &c_local; + obj_t* betap = beta; + +#ifdef BLIS_ENABLE_GEMM_MD +#ifdef BLIS_ENABLE_GEMM_MD_EXTRA_MEM + // If any of the following conditions are met, create a temporary matrix + // conformal to C into which we will accumulate the matrix product: + // - the storage precision of C differs from the computation precision; + // - the domains are mixed as crr; + // - the storage format of C does not match the preferred orientation + // of the ccr or crc cases. + // Then, after the computation is complete, this matrix will be copied + // or accumulated back to C. + const bool is_ccr_mismatch = + ( bli_gemm_md_is_ccr( &a_local, &b_local, &c_local ) && + !bli_obj_is_col_stored( &c_local ) ); + const bool is_crc_mismatch = + ( bli_gemm_md_is_crc( &a_local, &b_local, &c_local ) && + !bli_obj_is_row_stored( &c_local ) ); + + obj_t ct; + bool use_ct = FALSE; + + // FGVZ: Consider adding another guard here that only creates and uses a + // temporary matrix for accumulation if k < c * kc, where c is some small + // constant like 2. And don't forget to use the same conditional for the + // castm() and free() at the end. + if ( + bli_obj_prec( &c_local ) != bli_obj_comp_prec( &c_local ) || + bli_gemm_md_is_crr( &a_local, &b_local, &c_local ) || + is_ccr_mismatch || + is_crc_mismatch + ) + { + use_ct = TRUE; + } + + // If we need a temporary matrix conformal to C for whatever reason, + // we create it and prepare to use it now. + if ( use_ct ) + { + const dim_t m = bli_obj_length( &c_local ); + const dim_t n = bli_obj_width( &c_local ); + inc_t rs = bli_obj_row_stride( &c_local ); + inc_t cs = bli_obj_col_stride( &c_local ); + + num_t dt_ct = bli_obj_domain( &c_local ) | + bli_obj_comp_prec( &c_local ); + + // When performing the crr case, accumulate to a contiguously-stored + // real matrix so we do not have to repeatedly update C with general + // stride. + if ( bli_gemm_md_is_crr( &a_local, &b_local, &c_local ) ) + dt_ct = BLIS_REAL | bli_obj_comp_prec( &c_local ); + + // When performing the mismatched ccr or crc cases, now is the time + // to specify the appropriate storage so the gemm_md_c2r_ref() virtual + // microkernel can output directly to C (instead of using a temporary + // microtile). + if ( is_ccr_mismatch ) { rs = 1; cs = m; } + else if ( is_crc_mismatch ) { rs = n; cs = 1; } + + bli_obj_create( dt_ct, m, n, rs, cs, &ct ); + + const num_t dt_exec = bli_obj_exec_dt( &c_local ); + const num_t dt_comp = bli_obj_comp_dt( &c_local ); + + bli_obj_set_target_dt( dt_ct, &ct ); + bli_obj_set_exec_dt( dt_exec, &ct ); + bli_obj_set_comp_dt( dt_comp, &ct ); + + // A naive approach would cast C to the comptuation datatype, + // compute with beta, and then cast the result back to the + // user-provided output matrix. However, we employ a different + // approach that halves the number of memops on C (or its + // typecast temporary) by writing the A*B product directly to + // temporary storage, and then using xpbym to scale the + // output matrix by beta and accumulate/cast the A*B product. + //bli_castm( &c_local, &ct ); + betap = &BLIS_ZERO; + + cp = &ct; + } +#endif +#endif + + // Invoke the internal back-end via the thread handler. 
+ bli_l3_thread_decorator + ( + bli_gemm_int, + BLIS_GEMM, // operation family id + alpha, + &a_local, + &b_local, + betap, + cp, + cntx, + rntm, + cntl + ); + +#ifdef BLIS_ENABLE_GEMM_MD +#ifdef BLIS_ENABLE_GEMM_MD_EXTRA_MEM + // If we created a temporary matrix conformal to C for whatever reason, + // we copy/accumulate the result back to C and then release the object. + if ( use_ct ) + { + obj_t beta_local; + + bli_obj_scalar_detach( &c_local, &beta_local ); + + //bli_castnzm( &ct, &c_local ); + bli_xpbym( &ct, &beta_local, &c_local ); + + bli_obj_free( &ct ); + } +#endif +#endif + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + +// ----------------------------------------------------------------------------- + +#if 0 + if ( bli_obj_dt( a ) != bli_obj_dt( b ) || + bli_obj_dt( a ) != bli_obj_dt( c ) || + bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) + { + const bool a_is_real = bli_obj_is_real( a ); + const bool a_is_comp = bli_obj_is_complex( a ); + const bool b_is_real = bli_obj_is_real( b ); + const bool b_is_comp = bli_obj_is_complex( b ); + const bool c_is_real = bli_obj_is_real( c ); + const bool c_is_comp = bli_obj_is_complex( c ); + + const bool a_is_single = bli_obj_is_single_prec( a ); + const bool a_is_double = bli_obj_is_double_prec( a ); + const bool b_is_single = bli_obj_is_single_prec( b ); + const bool b_is_double = bli_obj_is_double_prec( b ); + const bool c_is_single = bli_obj_is_single_prec( c ); + const bool c_is_double = bli_obj_is_double_prec( c ); + + const bool comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC; + const bool comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC; + + const bool mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) || + bli_obj_domain( c ) != bli_obj_domain( b ); + + ( void )a_is_real; ( void )a_is_comp; + ( void )b_is_real; ( void )b_is_comp; + ( void )c_is_real; ( void )c_is_comp; + ( void )a_is_single; ( void )a_is_double; + ( void )b_is_single; ( void )b_is_double; + ( void )c_is_single; ( void )c_is_double; + ( void )comp_single; ( void )comp_double; + + if ( + //( c_is_comp && a_is_comp && b_is_real ) || + //( c_is_comp && a_is_real && b_is_comp ) || + //( c_is_real && a_is_comp && b_is_comp ) || + //( c_is_comp && a_is_real && b_is_real ) || + //( c_is_real && a_is_comp && b_is_real ) || + //( c_is_real && a_is_real && b_is_comp ) || + //FALSE + TRUE + ) + { + if ( + ( c_is_single && a_is_single && b_is_single && mixeddomain ) || + ( c_is_single && a_is_single && b_is_single && comp_single ) || + ( c_is_single && a_is_single && b_is_single && comp_double ) || + ( c_is_single && a_is_single && b_is_double ) || + ( c_is_single && a_is_double && b_is_single ) || + ( c_is_double && a_is_single && b_is_single ) || + ( c_is_single && a_is_double && b_is_double ) || + ( c_is_double && a_is_single && b_is_double ) || + ( c_is_double && a_is_double && b_is_single ) || + ( c_is_double && a_is_double && b_is_double && comp_single ) || + ( c_is_double && a_is_double && b_is_double && comp_double ) || + ( c_is_double && a_is_double && b_is_double && mixeddomain ) || + FALSE + ) + bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); + else + bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); + } + else + bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); + return; + } +#else +#if 0 + // If any of the storage datatypes differ, or if the execution precision + // differs from the storage precision of C, utilize the mixed datatype + // code path. 
+ // NOTE: We could check the exec dt against the storage dt of C, but for + // now we don't support the caller setting the execution domain + // explicitly. + if ( bli_obj_dt( a ) != bli_obj_dt( b ) || + bli_obj_dt( a ) != bli_obj_dt( c ) || + bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) + { + bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); + return; + } +#endif +#endif + diff --git a/frame/3/gemm/bli_gemm_ker_var2.c b/frame/3/gemm/bli_gemm_ker_var2.c index 5e0a4ddb70..dc1c3d14dc 100644 --- a/frame/3/gemm/bli_gemm_ker_var2.c +++ b/frame/3/gemm/bli_gemm_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -407,6 +407,22 @@ void PASTEMAC(ch,varname) \ } \ } \ \ +/* Send progress update if the user has enabled it */ \ +if(AOCL_progress_ptr) { \ + /* Running total for current thread */ \ + tls_aoclprogress_counter += m * n * k; \ + /* Send the update only if enough number of elements are processes */ \ + if ((tls_aoclprogress_counter - tls_aoclprogress_last_update) >= AOCL_PROGRESS_FREQUENCY) \ + { \ + tls_aoclprogress_last_update = tls_aoclprogress_counter; \ + AOCL_PROGRESS_DT(*MKSTR(ch), \ + "gemm", \ + tls_aoclprogress_counter, \ + AOCL_gettid(), \ + bli_rntm_num_threads(rntm)); \ + }\ +} \ + \ /* PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \ diff --git a/frame/3/gemm/bli_gemm_packab.c b/frame/3/gemm/bli_gemm_packab.c index 3dfed88478..6828725546 100644 --- a/frame/3/gemm/bli_gemm_packab.c +++ b/frame/3/gemm/bli_gemm_packab.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -90,9 +91,14 @@ void bli_gemm_packb ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_5); - + obj_t b_pack; + // BY setting family id to BLIS_GEMM_MD, we indicate packing kernels + // to scale alpha while packing. + if(bli_obj_dt(c) != bli_obj_dt(a)) + bli_cntl_set_family(BLIS_GEMM_MD, cntl); + // Pack matrix B according to the control tree node. bli_l3_packm ( @@ -103,6 +109,10 @@ void bli_gemm_packb cntl, thread ); + // Once packing of B matrix is done, fall back to GEMM execution. + if(bli_obj_dt(c) != bli_obj_dt(a)) + bli_cntl_set_family(BLIS_GEMM, cntl); + // Proceed with execution using packed matrix B. bli_gemm_int diff --git a/frame/3/trmm/CMakeLists.txt b/frame/3/trmm/CMakeLists.txt index 076d7d4a6b..a3845f3858 100644 --- a/frame/3/trmm/CMakeLists.txt +++ b/frame/3/trmm/CMakeLists.txt @@ -1,12 +1,26 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2020-22, Advanced Micro Devices, Inc. 
All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_front.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_ll_ker_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_lu_ker_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_rl_ker_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_ru_ker_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_xx_ker_var2.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR +${TARGET_ARCH} STREQUAL zen2 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_front_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_front.c + ) +endif() diff --git a/frame/3/trmm/bli_trmm_front_amd.c b/frame/3/trmm/bli_trmm_front_amd.c new file mode 100644 index 0000000000..2301b323a7 --- /dev/null +++ b/frame/3/trmm/bli_trmm_front_amd.c @@ -0,0 +1,206 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_trmm_front + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl + ) +{ + bli_init_once(); + + obj_t a_local; + obj_t b_local; + obj_t c_local; + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bli_trmm_check( side, alpha, a, b, &BLIS_ZERO, b, cntx ); + + // If alpha is zero, scale by beta and return. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) + { + bli_scalm( alpha, b ); + return; + } + + // Alias A and B so we can tweak the objects if necessary. + bli_obj_alias_to( a, &a_local ); + bli_obj_alias_to( b, &b_local ); + bli_obj_alias_to( b, &c_local ); + + // We do not explicitly implement the cases where A is transposed. + // However, we can still handle them. Specifically, if A is marked as + // needing a transposition, we simply induce a transposition. 
This + // allows us to only explicitly implement the no-transpose cases. Once + // the transposition is induced, the correct algorithm will be called, + // since, for example, an algorithm over a transposed lower triangular + // matrix A moves in the same direction (forwards) as a non-transposed + // upper triangular matrix. And with the transposition induced, the + // matrix now appears to be upper triangular, so the upper triangular + // algorithm will grab the correct partitions, as if it were upper + // triangular (with no transpose) all along. + if ( bli_obj_has_trans( &a_local ) ) + { + bli_obj_induce_trans( &a_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &a_local ); + } + +#ifdef BLIS_DISABLE_TRMM_RIGHT + // NOTE: This case casts right-side trmm in terms of left side. This is + // necessary when the current subconfiguration uses a gemm microkernel + // that assumes that the packing kernel will have already duplicated + // (broadcast) element of B in the packed copy of B. Supporting + // duplication within the logic that packs micropanels from triangular + // matrices would be ugly, and so we simply don't support it. As a + // consequence, those subconfigurations need a way to force the triangular + // matrix to be on the left (and thus the general matrix to the on the + // right). So our solution is that in those cases, the subconfigurations + // simply #define BLIS_DISABLE_TRMM_RIGHT. + + // NOTE: This case casts right-side trmm in terms of left side. This can + // lead to the microkernel being executed on an output matrix with the + // microkernel's general stride IO case (unless the microkernel supports + // both both row and column IO cases as well). + + // NOTE: Casting right-side trmm in terms of left side reduces the number + // of macrokernels exercised to two (trmm_ll and trmm_lu). + + // If A is being multiplied from the right, transpose all operands + // so that we can perform the computation as if A were being multiplied + // from the left. + if ( bli_is_right( side ) ) + { + bli_toggle_side( &side ); + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + } + +#else + // NOTE: This case computes right-side trmm natively with trmm_rl and + // trmm_ru macrokernels. This code path always gives us the opportunity + // to transpose the entire operation so that the effective storage format + // of the output matrix matches the microkernel's output preference. + // Thus, from a performance perspective, this case is preferred. + + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. + // NOTE: We disable the optimization for 1x1 matrices since the concept + // of row- vs. column storage breaks down. + //if ( !bli_obj_is_1x1( &c_local ) ) // NOTE: This conditional should NOT + // be enabled. See issue #342 comments. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_toggle_side( &side ); + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + } + + // If A is being multiplied from the right, swap A and B so that + // the matrix will actually be on the right. + if ( bli_is_right( side ) ) + { + bli_obj_swap( &a_local, &b_local ); + } + +#endif + + // Set each alias as the root object. 
+ // NOTE: We MUST wait until we are done potentially swapping the objects + // before setting the root fields! + bli_obj_set_as_root( &a_local ); + bli_obj_set_as_root( &b_local ); + bli_obj_set_as_root( &c_local ); + +#ifdef AOCL_DYNAMIC + // If dynamic-threading is enabled, calculate optimum number + // of threads and update in rntm + if(bli_obj_is_double(b)) + { + bli_nthreads_optimum(a, b, b, BLIS_TRMM, rntm ); + } +#endif + + // Parse and interpret the contents of the rntm_t object to properly + // set the ways of parallelism for each loop, and then make any + // additional modifications necessary for the current operation. + bli_rntm_set_ways_for_op + ( + BLIS_TRMM, + side, + bli_obj_length( &c_local ), + bli_obj_width( &c_local ), + bli_obj_width( &a_local ), + rntm + ); + + // A sort of hack for communicating the desired pach schemas for A and B + // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and + // bli_l3_cntl_create_if()). This allows us to access the schemas from + // the control tree, which hopefully reduces some confusion, particularly + // in bli_packm_init(). + pack_t schema_a = bli_cntx_schema_a_block( cntx ); + pack_t schema_b = bli_cntx_schema_b_panel( cntx ); + + bli_obj_set_pack_schema( schema_a, &a_local ); + bli_obj_set_pack_schema( schema_b, &b_local ); + + // Invoke the internal back-end. + bli_l3_thread_decorator + ( + bli_gemm_int, + BLIS_TRMM, // operation family id + alpha, + &a_local, + &b_local, + &BLIS_ZERO, + &c_local, + cntx, + rntm, + cntl + ); +} + diff --git a/frame/3/trsm/bli_trsm_xx_ker_var2.c b/frame/3/trsm/bli_trsm_xx_ker_var2.c index de8cad065a..8d2f8689a9 100644 --- a/frame/3/trsm/bli_trsm_xx_ker_var2.c +++ b/frame/3/trsm/bli_trsm_xx_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -87,6 +87,59 @@ void bli_trsm_xx_ker_var2 cntl, thread ); + + // Send progress update if enabled + if (AOCL_progress_ptr) + { + + // Get the size of block processed in + // this iteration, add it to the accumulated + // total and send the update. + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + dim_t k = bli_obj_width(a); + + num_t dt = bli_obj_dt(c); + char dt_c; + + // Running total for current thread. + tls_aoclprogress_counter += m * n * k; + + // Send the update only if number of elements processes so far + // has exceeded the freqency of reporting. + if ((tls_aoclprogress_counter - tls_aoclprogress_last_update) >= + AOCL_PROGRESS_FREQUENCY) + { + + // reset the last update counter for next iteration. + tls_aoclprogress_last_update = tls_aoclprogress_counter; + + switch (dt) + { + case BLIS_FLOAT: + dt_c = 's'; + break; + case BLIS_DOUBLE: + dt_c = 'd'; + break; + case BLIS_SCOMPLEX: + dt_c = 'c'; + break; + case BLIS_DCOMPLEX: + dt_c = 'z'; + break; + default: + dt_c = ' '; + } + + AOCL_PROGRESS_DT(dt_c, + "trsm", + tls_aoclprogress_counter, + AOCL_gettid(), + bli_rntm_num_threads(rntm)); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_6); } diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index 4b3837544f..d10ea1039a 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -5,7 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2019, Dave Love, University of Manchester Redistribution and use in source and binary forms, with or without @@ -286,8 +286,13 @@ bool bli_cpuid_is_zen3 // we check for all of them. const bool is_arch = - (( model <= 0x0f ) || - (0x30 <= model && model <= 0x3f )); + ( + ( model <= 0x0f ) || // EPYC and ThreadRipper + ( 0x20 <= model && model <= 0x2f ) || // Ryzen 5000 Desktop + ( 0x30 <= model && model <= 0x3f ) || // Trento + ( 0x40 <= model && model <= 0x4f ) || // RMB + ( 0x50 <= model && model <= 0x5f ) // Ryzen 5000 APU + ); if ( !is_arch ) return FALSE; @@ -459,6 +464,58 @@ bool bli_cpuid_is_bulldozer return TRUE; } +// Check (at runtime) if AVX is supported on the current platform, this is to +// ensure that AVX kernels are not used on legacy platforms which results in crash + +// The support for AVX is checked only once (when this API is called first time) +// On subsequent calls the cached value is returned. This is achieved using +// pthread_once mechanism since this information does not change once the library +// is loaded. +static bool is_avx_supported = FALSE; + + +// Determine if the CPU has support for AVX. +void bli_cpuid_check_avx_support( void ) +{ + uint32_t family, model, features; + + // Call the CPUID instruction and parse its results into a family id, + // model id, and a feature bit field. + bli_cpuid_query( &family, &model, &features ); + + // Check for expected CPU features. + const uint32_t expected = FEATURE_AVX | + FEATURE_FMA3 | + FEATURE_AVX2; + + if ( !bli_cpuid_has_features( features, expected ) ) + { + is_avx_supported = FALSE; + } + else + { + is_avx_supported = TRUE; + } +} + +static bli_pthread_once_t once_check_avx_support = BLIS_PTHREAD_ONCE_INIT; + +// Ensure that actual support determincation happens only once +void bli_cpuid_check_avx_support_once( void ) +{ +#ifndef BLIS_CONFIGURETIME_CPUID + bli_pthread_once( &once_check_avx_support, bli_cpuid_check_avx_support ); +#endif +} + +// API to check if AVX is supported or not on the current platform. +bool bli_cpuid_is_avx_supported( void ) +{ + bli_cpuid_check_avx_support_once(); + + return is_avx_supported; +} + #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) arch_t bli_cpuid_query_id( void ) diff --git a/frame/base/bli_cpuid.h b/frame/base/bli_cpuid.h index 62c05ad5ca..47b584c883 100644 --- a/frame/base/bli_cpuid.h +++ b/frame/base/bli_cpuid.h @@ -132,7 +132,7 @@ BLIS_INLINE bool bli_cpuid_has_features( uint32_t have, uint32_t want ) void get_cpu_name( char *cpu_name ); int vpu_count( void ); - +bool bli_cpuid_is_avx_supported(void); enum { @@ -159,6 +159,8 @@ enum FEATURE_AVX512VL = 0x4000 }; + + #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) char* find_string_in( char* target, char* buffer, size_t buf_len, char* filepath ); diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 6a100bbe8e..fbf5654b7a 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
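(Aside: the bli_cpuid_is_avx_supported() helper added above caches a one-time feature query behind bli_pthread_once() so the BLAS wrappers can cheaply ask whether AVX/FMA kernels are safe to call. A minimal sketch of the same caching idea using plain pthreads and the GCC/Clang __builtin_cpu_supports() builtin follows; BLIS itself uses its own bli_cpuid_query() and FEATURE_* flags rather than this builtin.)

#include <pthread.h>
#include <stdbool.h>
#include <stdio.h>

static bool           avx_ok   = false;
static pthread_once_t avx_once = PTHREAD_ONCE_INIT;

/* Runs exactly once per process, no matter how many threads race to call it. */
static void check_avx( void )
{
    /* Require the same feature set the patch checks: AVX, AVX2 and FMA. */
    avx_ok = __builtin_cpu_supports( "avx" )  &&
             __builtin_cpu_supports( "avx2" ) &&
             __builtin_cpu_supports( "fma" );
}

static bool cpu_has_avx( void )
{
    pthread_once( &avx_once, check_avx );   /* subsequent calls return the cached value */
    return avx_ok;
}

int main( void )
{
    printf( "AVX/AVX2/FMA available: %s\n", cpu_has_avx() ? "yes" : "no" );
    return 0;
}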
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -49,9 +49,23 @@ void bli_rntm_init_from_global( rntm_t* rntm ) // We must ensure that global_rntm has been initialized. bli_init_once(); + // Fetch the number of threads based on the order of precedence, + // or the latest value of number of threads, + // if set by the Application using omp_set_num_threads(nt) API. +#ifdef BLIS_ENABLE_OPENMP + dim_t n_threads = omp_get_max_threads(); +#endif + // Acquire the mutex protecting global_rntm. bli_pthread_mutex_lock( &global_rntm_mutex ); + // Update the latest value of number of threads into global rntm structure, + // before copying into local rntm structure. This updated value will be + // used in the subsequent parallel regions. +#ifdef BLIS_ENABLE_OPENMP + global_rntm.num_threads = n_threads; +#endif + *rntm = global_rntm; // Release the mutex protecting global_rntm. @@ -61,14 +75,14 @@ void bli_rntm_init_from_global( rntm_t* rntm ) // ----------------------------------------------------------------------------- void bli_rntm_set_ways_for_op - ( - opid_t l3_op, - side_t side, - dim_t m, - dim_t n, - dim_t k, - rntm_t* rntm - ) + ( + opid_t l3_op, + side_t side, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm + ) { // Set the number of ways for each loop, if needed, depending on what // kind of information is already stored in the rntm_t object. @@ -81,7 +95,7 @@ bli_rntm_print( rntm ); // Now modify the number of ways, if necessary, based on the operation. if ( l3_op == BLIS_TRMM || - l3_op == BLIS_TRSM ) + l3_op == BLIS_TRSM ) { dim_t jc = bli_rntm_jc_ways( rntm ); dim_t pc = bli_rntm_pc_ways( rntm ); @@ -155,12 +169,12 @@ bli_rntm_print( rntm ); } void bli_rntm_set_ways_from_rntm - ( - dim_t m, - dim_t n, - dim_t k, - rntm_t* rntm - ) + ( + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm + ) { dim_t nt = bli_rntm_num_threads( rntm ); @@ -205,6 +219,11 @@ void bli_rntm_set_ways_from_rntm if ( ic < 1 ) ic = 1; if ( jr < 1 ) jr = 1; if ( ir < 1 ) ir = 1; + + // auto factorization is to be disabled if BLIS_IC_NT/BLIS_JC_NT env + // variables are set irrespective of whether num_threads is modified + // or not. This ensures that preset factorization is prioritized. + auto_factor = FALSE; } // Now we use the values of nt_set and ways_set to determine how to @@ -238,7 +257,7 @@ void bli_rntm_set_ways_from_rntm pc = 1; bli_thread_partition_2x2( nt, m*BLIS_THREAD_RATIO_M, - n*BLIS_THREAD_RATIO_N, &ic, &jc ); + n*BLIS_THREAD_RATIO_N, &ic, &jc ); for ( ir = BLIS_THREAD_MAX_IR ; ir > 1 ; ir-- ) { @@ -276,12 +295,12 @@ void bli_rntm_set_ways_from_rntm } void bli_rntm_set_ways_from_rntm_sup - ( - dim_t m, - dim_t n, - dim_t k, - rntm_t* rntm - ) + ( + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm + ) { dim_t nt = bli_rntm_num_threads( rntm ); @@ -326,6 +345,11 @@ void bli_rntm_set_ways_from_rntm_sup if ( ic < 1 ) ic = 1; if ( jr < 1 ) jr = 1; if ( ir < 1 ) ir = 1; + + // auto factorization is to be disabled if BLIS_IC_NT/BLIS_JC_NT env + // variables are set irrespective of whether num_threads is modified + // or not. This ensures that preset factorization is prioritized. 
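(Aside: the bli_rntm_init_from_global() change above re-reads omp_get_max_threads() on every call so that a later omp_set_num_threads() from the application takes effect, and it folds the value into global_rntm while holding the mutex. A stripped-down sketch of that query-then-update-under-lock pattern follows; rntm_like_t, global_cfg and cfg_mutex are invented stand-ins, not BLIS types.)

#include <omp.h>
#include <pthread.h>
#include <stdio.h>

/* Invented stand-in for the parts of rntm_t that matter here. */
typedef struct { int num_threads; } rntm_like_t;

static rntm_like_t     global_cfg = { .num_threads = 1 };
static pthread_mutex_t cfg_mutex  = PTHREAD_MUTEX_INITIALIZER;

/* Copy the global configuration into a caller-local one, refreshing the thread
   count from OpenMP first (the application may have called omp_set_num_threads()
   since the last copy was taken). */
static void cfg_init_from_global( rntm_like_t* local )
{
    int nt = omp_get_max_threads();   /* query before taking the lock */

    pthread_mutex_lock( &cfg_mutex );
    global_cfg.num_threads = nt;      /* later parallel regions see the updated value */
    *local = global_cfg;
    pthread_mutex_unlock( &cfg_mutex );
}

int main( void )
{
    rntm_like_t local;

    omp_set_num_threads( 8 );         /* application-level request */
    cfg_init_from_global( &local );
    printf( "local copy will use %d threads\n", local.num_threads );
    return 0;
}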
+ auto_factor = FALSE; } // Now we use the values of nt_set and ways_set to determine how to @@ -359,9 +383,9 @@ void bli_rntm_set_ways_from_rntm_sup pc = 1; //bli_thread_partition_2x2( nt, m*BLIS_THREAD_SUP_RATIO_M, - // n*BLIS_THREAD_SUP_RATIO_N, &ic, &jc ); + // n*BLIS_THREAD_SUP_RATIO_N, &ic, &jc ); bli_thread_partition_2x2( nt, m, - n, &ic, &jc ); + n, &ic, &jc ); //printf( "bli_rntm_set_ways_from_rntm_sup(): jc = %d ic = %d\n", (int)jc, (int)ic ); #if 0 @@ -406,9 +430,9 @@ void bli_rntm_set_ways_from_rntm_sup } void bli_rntm_print - ( - rntm_t* rntm - ) + ( + rntm_t* rntm + ) { dim_t af = bli_rntm_auto_factor( rntm ); @@ -420,35 +444,35 @@ void bli_rntm_print dim_t jr = bli_rntm_jr_ways( rntm ); dim_t ir = bli_rntm_ir_ways( rntm ); - printf( "rntm contents nt jc pc ic jr ir\n" ); + printf( "rntm contents nt jc pc ic jr ir\n" ); printf( "autofac? %1d | %4d%4d%4d%4d%4d%4d\n", (int)af, - (int)nt, (int)jc, (int)pc, - (int)ic, (int)jr, (int)ir ); + (int)nt, (int)jc, (int)pc, + (int)ic, (int)jr, (int)ir ); } // ----------------------------------------------------------------------------- dim_t bli_rntm_calc_num_threads_in - ( - bszid_t* restrict bszid_cur, - rntm_t* restrict rntm - ) + ( + bszid_t* restrict bszid_cur, + rntm_t* restrict rntm + ) { - /* // bp algorithm: - bszid_t bszids[7] = { BLIS_NC, // level 0: 5th loop - BLIS_KC, // level 1: 4th loop + /* // bp algorithm: + bszid_t bszids[7] = { BLIS_NC, // level 0: 5th loop + BLIS_KC, // level 1: 4th loop BLIS_NO_PART, // level 2: pack B - BLIS_MC, // level 3: 3rd loop + BLIS_MC, // level 3: 3rd loop BLIS_NO_PART, // level 4: pack A - BLIS_NR, // level 5: 2nd loop - BLIS_MR, // level 6: 1st loop - BLIS_KR // level 7: ukr loop - - ... // pb algorithm: - BLIS_NR, // level 5: 2nd loop - BLIS_MR, // level 6: 1st loop - BLIS_KR // level 7: ukr loop - }; */ + BLIS_NR, // level 5: 2nd loop + BLIS_MR, // level 6: 1st loop + BLIS_KR // level 7: ukr loop + + ... // pb algorithm: + BLIS_NR, // level 5: 2nd loop + BLIS_MR, // level 6: 1st loop + BLIS_KR // level 7: ukr loop + }; */ dim_t n_threads_in = 1; // Starting with the current element of the bszids array (pointed @@ -477,7 +501,7 @@ dim_t bli_rntm_calc_num_threads_in for ( ; *bszid_cur != BLIS_KR; bszid_cur++ ) { const bszid_t bszid = *bszid_cur; - dim_t cur_way = 1; + dim_t cur_way = 1; // We assume bszid is in {NC,KC,MC,NR,MR,KR} if it is not // BLIS_NO_PART. @@ -498,12 +522,12 @@ dim_t bli_rntm_calc_num_threads_in //application is available in global_rntm data structure. 
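(Aside: when no explicit ways are set, both bli_rntm_set_ways_from_rntm() and the _sup variant hand the remaining thread count to bli_thread_partition_2x2(), which factors nt into an ic x jc grid whose shape tracks the (optionally ratio-weighted) m and n dimensions. The helper below is a simplified illustration of that kind of factorisation, not the algorithm BLIS actually uses: it walks the divisors of nt and keeps the pair whose ic:jc ratio is closest to m:n.)

#include <stdio.h>
#include <math.h>

/* Pick ic, jc with ic*jc == nt so that ic/jc is as close as possible to m/n.
   Simplified illustration of what a 2x2 thread-partitioning helper does. */
static void partition_2x2( int nt, double m, double n, int* ic, int* jc )
{
    double target = ( m > 0.0 && n > 0.0 ) ? m / n : 1.0;
    double best   = INFINITY;

    *ic = nt; *jc = 1;
    for ( int i = 1; i <= nt; ++i )
    {
        if ( nt % i ) continue;                       /* need an exact factorisation */
        int    j     = nt / i;
        double ratio = (double)i / (double)j;
        double err   = fabs( log( ratio / target ) ); /* compare shapes on a log scale */
        if ( err < best ) { best = err; *ic = i; *jc = j; }
    }
}

int main( void )
{
    int ic, jc;
    partition_2x2( 64, 8000.0, 2000.0, &ic, &jc );    /* tall problem: expect ic > jc */
    printf( "nt=64, m=8000, n=2000  ->  ic=%d, jc=%d\n", ic, jc );
    return 0;
}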
void bli_nthreads_optimum( - obj_t* a, - obj_t* b, - obj_t* c, - opid_t family, - rntm_t* rntm - ) + obj_t* a, + obj_t* b, + obj_t* c, + opid_t family, + rntm_t* rntm + ) { #ifndef BLIS_ENABLE_MULTITHREADING return; @@ -517,100 +541,147 @@ void bli_nthreads_optimum( if( family == BLIS_GEMM && bli_obj_is_double(c)) { - dim_t m = bli_obj_length(c); dim_t n = bli_obj_width(c); dim_t k = bli_obj_width_after_trans(a); - if( k >= 128) { - if(n <= 15) n_threads_ideal = 8; - else n_threads_ideal = 16; + if(n <= 15) + { + if(m < 128) n_threads_ideal = 8; + else if(m < 256) n_threads_ideal = 16; + else if(m < 512) n_threads_ideal = 32; + else n_threads_ideal = 64; + }else if (n <= 64) + { + if(m < 128) n_threads_ideal = 16; + else if(m < 256) n_threads_ideal = 32; + else n_threads_ideal = 64; + }else{ + if(m < 256) n_threads_ideal = 32; + else n_threads_ideal = 64; + } } else - { - if(m > 10000) - { - - /* if(n >= 96) n_threads_ideal = 16; */ - /* else n_threads_ideal = 8; */ - - // current logic is only limiting threads to - // less or equal to 64 - limits performance. - - // To deal with larger matrix sizes we need to use - // large number of threads to improve performance - - // Need to derive this upperTH - and - // if matrix -sizes are larger and user wants - // to use higher number of threads - that should be allowed. - - // if (n > UpperTH) n_threads_ideal = n_threads; - if (n > 200 ) n_threads_ideal = 64; - else if ( n > 120 ) n_threads_ideal = 32; - else if ( n > 40 ) n_threads_ideal = 16; - else if ( n > 10 ) n_threads_ideal = 8; - else /* if ( n <= 10) */ n_threads_ideal = 4; - } - else if( m > 1000) - { - if (n <= 10) n_threads_ideal = 4; - else if ( n <= 40 ) n_threads_ideal = 8; - else if ( n <= 120 ) n_threads_ideal = 16; - else if ( n <= 200 ) n_threads_ideal = 32; - else n_threads_ideal = 64; - - /* if(n < 15) n_threads_ideal = 4; */ - /* else n_threads_ideal = 8; */ - } - else if(m > 210) - { - if(n < 10) n_threads_ideal = 1; - else n_threads_ideal = 4; - } - else if(m > 150) - { - if(n < 15) n_threads_ideal = 1; - else n_threads_ideal = 4; - } - else - { - if(n < 20) n_threads_ideal = 1; - else n_threads_ideal = 4; - } + { + if(m > 10000) + { + // current logic is only limiting threads to + // less or equal to 64 - limits performance. + // To deal with larger matrix sizes we need to use + // large number of threads to improve performance + // Need to derive this upperTH - and + // if matrix -sizes are larger and user wants + // to use higher number of threads - that should be allowed. 
+ + // if (n > UpperTH) n_threads_ideal = n_threads; + if (n > 200 ) n_threads_ideal = 64; + else if ( n > 120 ) n_threads_ideal = 32; + else if ( n > 40 ) n_threads_ideal = 16; + else if ( n > 10 ) n_threads_ideal = 8; + else n_threads_ideal = 4; + } + else if( m > 1000) + { + if (n <= 10) n_threads_ideal = 4; + else if ( n <= 512 ) n_threads_ideal = 8; + else if ( n <= 1024 ) n_threads_ideal = 16; + else if ( n <= 2048 ) n_threads_ideal = 32; + else n_threads_ideal = 64; + } + else if(m > 210) + { + if(n < 10) n_threads_ideal = 4; + else if(n <= 512) n_threads_ideal = 8; + else if(n <= 1024) n_threads_ideal = 16; + else if(n <= 2048) n_threads_ideal = 32; + else n_threads_ideal = 64; + } + else if(m > 150) + { + if(n < 10) n_threads_ideal = 2; + else if(n <= 512) n_threads_ideal = 8; + else if(n <= 1024) n_threads_ideal = 16; + else if(n <= 2048) n_threads_ideal = 32; + else n_threads_ideal = 64; + } + else if( ( m < 34) && (k < 68) && ( n < 34)) + { + n_threads_ideal = 1; + } + else + { //(m<150 && k<128) + if(n < 20) n_threads_ideal = 1; + if(n < 64) n_threads_ideal = 4; + else n_threads_ideal = 8; + } } + } + else if( family == BLIS_GEMM && bli_obj_is_dcomplex(c)) + { + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + dim_t k = bli_obj_width_after_trans(a); + if((m<=128 || n<=128 || k<=128) && ((m+n+k) <= 400) ) + { + n_threads_ideal = 8; + } + else if((m<=256 || n<=256 || k<=256) && ((m+n+k) <= 800) ) + { + n_threads_ideal = 16; + } } else if( family == BLIS_SYRK && bli_obj_is_double(c)) { - dim_t n = bli_obj_length(c); - dim_t k = bli_obj_width_after_trans(a); - - if( (( n <= 10) && ( k < 700)) || - (( n <= 20) && ( k <= 190)) || - (( n <= 40) && ( k <= 80)) || - (( n <= 50) && ( k <= 40)) || - (( n <= 60) && ( k <= 20)) - ) - n_threads_ideal = 1; - else - n_threads_ideal = n_threads; + dim_t n = bli_obj_length(c); + dim_t k = bli_obj_width_after_trans(a); + + if( (( n <= 10) && ( k < 700)) || + (( n <= 20) && ( k <= 190)) || + (( n <= 40) && ( k <= 80)) || + (( n <= 50) && ( k <= 40)) || + (( n <= 60) && ( k <= 20)) + ) + n_threads_ideal = 1; + else + n_threads_ideal = n_threads; } - else if( family == BLIS_TRSM && bli_obj_is_double(c)) + else if( family == BLIS_TRSM && bli_obj_is_double(c) ) { dim_t m = bli_obj_length(c); dim_t n = bli_obj_width(c); - if(m<=512 && n<=512) - n_threads_ideal = 4; +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + if ( (m <= 300) && (n <= 300) ) + n_threads_ideal = 8; + else if ( (m <= 400) && (n <= 400) ) + n_threads_ideal = 16; + else if ( (m <= 900) && (n <= 900) ) + n_threads_ideal = 32; +#else + if ( (m <= 512) && (n <= 512) ) + n_threads_ideal = 4; +#endif + } + else if( family == BLIS_TRSM && bli_obj_is_dcomplex(c)) + { + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + + if((m>=64) && (m<=256) && (n>=64) && (n<=256)) + { + n_threads_ideal = 8; + } } else if( family == BLIS_GEMMT && bli_obj_is_double(c) ) { dim_t n = bli_obj_length(c); dim_t k = bli_obj_width_after_trans(a); dim_t product = (n*k)>>4; /* product is derived based on n and k */ - // Limit the number thread for smaller sizes: + + //Limit the number thread for smaller sizes: if(product <= 346) { n_threads_ideal = 1; @@ -621,6 +692,99 @@ void bli_nthreads_optimum( n_threads_ideal = n_threads; } } + else if( family == BLIS_TRMM && bli_obj_is_double(c)) + { + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + + if(( n <= 32) && (m <= 32)) + { + n_threads_ideal=1; + /*If Side is Left*/ + }else + { + //Left Side + if(bli_obj_is_triangular(a)) + { + if((m < 300)) + 
{ + if (n < 1000) + { + n_threads_ideal=8; + }else if (n < 2000) + { + n_threads_ideal=16; + }else if (n < 3000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + }else if(m < 600) + { + if (n < 2000) + { + n_threads_ideal=16; + }else if (n < 3000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + }else + { + if(n < 1000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + } + }else//Right Side + { + if((n < 300)) + { + if (m < 1000) + { + n_threads_ideal=8; + }else if (m < 2000) + { + n_threads_ideal=16; + }else if (m < 3000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + }else if(n < 600) + { + if (m < 2000) + { + n_threads_ideal=16; + }else if (m < 3000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + }else + { + if(m < 1000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + } + } + } + } dim_t n_threads_opt = bli_min(n_threads, n_threads_ideal); @@ -630,4 +794,84 @@ void bli_nthreads_optimum( return; } + +// Calculates the optimum number of threads along with the factorization +// (ic, jc) using m, n, k dimensions. This function modifies only the local +// copy of rntm with optimum threads. Since global rntm remains unchanged the +// num_threads set by application is available in global_rntm data structure. +err_t bli_smart_threading_sup + ( + obj_t* a, + obj_t* b, + obj_t* c, + opid_t family, + rntm_t* rntm, + cntx_t* cntx + ) +{ + // By default smart threading should be disabled. + err_t ret_val = BLIS_FAILURE; + +#ifndef BLIS_ENABLE_MULTITHREADING + return ret_val; +#endif + + dim_t n_threads = bli_rntm_num_threads( rntm ); + + // For non-openmp based threading, n_threads could be -1. + if ( ( n_threads == -1 ) || ( n_threads == 1 ) ) return ret_val; + + dim_t ic_way = bli_rntm_ic_ways( rntm ); + dim_t jc_way = bli_rntm_jc_ways( rntm ); + + // Dont enable smart threading if the user supplied the factorization. + if( ( ic_way > 0 ) || ( jc_way > 0 ) ) return ret_val; + + // Only supporting sgemm for now. + if ( ( family == BLIS_GEMM ) && bli_obj_is_float( c ) ) + { + dim_t k = bli_obj_width_after_trans(a); + dim_t m = 0; + dim_t n = 0; + + bool trans_A_for_kernel = FALSE; + + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + const bool is_rrr_rrc_rcr_crr = ( + stor_id == BLIS_RRR || + stor_id == BLIS_RRC || + stor_id == BLIS_RCR || + stor_id == BLIS_CRR + ); + + // The A and B matrices are swapped based on the storage type in + // var1n2m. Need to account for this when determining ic and jc + // based on m and n dimensions of A and B. + if ( is_rrr_rrc_rcr_crr ) + { + m = bli_obj_length( c ); + n = bli_obj_width( c ); + trans_A_for_kernel = bli_obj_has_trans( a ); + } + else + { + m = bli_obj_width( c ); + n = bli_obj_length( c ); + trans_A_for_kernel = bli_obj_has_trans( b ); + } + + // Take default path if transpose is enabled for A matrix. + if ( trans_A_for_kernel == FALSE ) + { + // A successfull call to smart threading api implies smart + // factorization and possibly native -> SUP path conversion. + // Optimal thread selection is not supported yet. 
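(Aside: every branch of bli_nthreads_optimum() above has the same shape — derive an "ideal" thread count from the problem dimensions via tuned thresholds, then clamp it with n_threads_opt = bli_min(n_threads, n_threads_ideal) so the user's limit is never exceeded. The sketch below shows only that shape; its thresholds are invented, whereas the real, per-operation and per-datatype cut-offs are the ones in the patch.)

#include <stdio.h>

typedef long dim_t;

static dim_t min_dim( dim_t a, dim_t b ) { return a < b ? a : b; }

/* Toy dimension-based heuristic: small problems do not benefit from (and are often
   hurt by) many threads, so cap the ideal count before clamping it to the number of
   threads the caller requested. The thresholds here are invented for illustration. */
static dim_t pick_num_threads( dim_t m, dim_t n, dim_t k, dim_t n_threads_requested )
{
    dim_t ideal;

    if      ( m <= 32   && n <= 32   && k <= 32 ) ideal = 1;
    else if ( m <= 256  && n <= 256             ) ideal = 8;
    else if ( m <= 1024 && n <= 1024            ) ideal = 16;
    else                                          ideal = n_threads_requested;

    return min_dim( n_threads_requested, ideal );
}

int main( void )
{
    printf( "tiny problem   : %ld threads\n", pick_num_threads(   24,   24,  24, 64 ) );
    printf( "medium problem : %ld threads\n", pick_num_threads(  200,  200, 512, 64 ) );
    printf( "large problem  : %ld threads\n", pick_num_threads( 4000, 4000, 512, 64 ) );
    return 0;
}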
+ ret_val = bli_gemm_smart_threading_sup( bli_obj_dt( c ), + bli_obj_elem_size( c ), + is_rrr_rrc_rcr_crr, m, n, k, n_threads, + cntx, rntm ); + } + } + return ret_val; +} #endif // AOCL_DYNAMIC diff --git a/frame/base/bli_rntm.h b/frame/base/bli_rntm.h index 5e8e236af6..e28463c5ab 100644 --- a/frame/base/bli_rntm.h +++ b/frame/base/bli_rntm.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -400,6 +400,16 @@ void bli_nthreads_optimum opid_t family, rntm_t* rntm ); + +err_t bli_smart_threading_sup + ( + obj_t* a, + obj_t* b, + obj_t* c, + opid_t family, + rntm_t* rntm, + cntx_t* cntx + ); #endif #endif diff --git a/frame/compat/CMakeLists.txt b/frame/compat/CMakeLists.txt index 7c20f5100c..48b66acbcb 100644 --- a/frame/compat/CMakeLists.txt +++ b/frame/compat/CMakeLists.txt @@ -1,17 +1,12 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE -${CMAKE_CURRENT_SOURCE_DIR}/bla_amax.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_amin.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_asum.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_axpy.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_copy.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_dot.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm3m.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemmt.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_gemv.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_ger.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_hemm.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_hemv.c @@ -20,8 +15,6 @@ ${CMAKE_CURRENT_SOURCE_DIR}/bla_her2.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_her2k.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_herk.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_nrm2.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_scal.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_swap.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_symm.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_symv.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_syr.c @@ -30,7 +23,6 @@ ${CMAKE_CURRENT_SOURCE_DIR}/bla_syr2k.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_syrk.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_trmm.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_trmv.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_trsm.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_trsv.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm_batch.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_axpby.c @@ -40,6 +32,38 @@ ${CMAKE_CURRENT_SOURCE_DIR}/bla_omatcopy2.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_omatadd.c ) +# Select AMD specific sources for AMD configurations. 
+if(${TARGET_ARCH} STREQUAL zen OR +${TARGET_ARCH} STREQUAL zen2 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bla_amax_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_axpy_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_copy_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_dot_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemv_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_scal_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_swap_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_trsm_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bla_amax.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_axpy.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_copy.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_dot.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemv.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_scal.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_swap.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_trsm.c + ) +endif() + #Add all subdirectories # add_subdirectory(attic) # add_subdirectory(blis) diff --git a/frame/compat/bla_amax.c b/frame/compat/bla_amax.c index fabed6e72d..b1cf77e7b8 100644 --- a/frame/compat/bla_amax.c +++ b/frame/compat/bla_amax.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -98,211 +98,5 @@ f77_int PASTEF772(i,chx,blasname) \ } #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC - -f77_int isamax_ - ( - const f77_int* n, - const float* x, const f77_int* incx - ) -{ - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx); - - dim_t n0; - float* x0; - inc_t incx0; - gint_t bli_index; - f77_int f77_index; - - /* If the vector is empty, return an index of zero. This early check - is needed to emulate netlib BLAS. Without it, bli_?amaxv() will - return 0, which ends up getting incremented to 1 (below) before - being returned, which is not what we want. */ - if ( *n < 1 || *incx <= 0 ) { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "isamax_: vector empty"); - return 0; - } - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. 
*/ - - x0 = ((float*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((float*)x); - incx0 = ( inc_t )(*incx); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel */ - bli_samaxv_zen_int - ( - n0, - x0, incx0, - &bli_index, - NULL - ); - } - else - { - PASTEMAC2(s,amaxv,BLIS_TAPI_EX_SUF) - ( - n0, - x0, incx0, - &bli_index, - NULL, - NULL - ); - } - - /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) - index. Also, if the BLAS integer size differs from the BLIS - integer size, that typecast occurs here. */ - f77_index = bli_index + 1; - - /* Finalize BLIS. */ -// bli_finalize_auto(); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - - return f77_index; -} - -f77_int idamax_ - ( - const f77_int* n, - const double* x, const f77_int* incx - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx); - - dim_t n0; - double* x0; - inc_t incx0; - gint_t bli_index; - f77_int f77_index; - - /* If the vector is empty, return an index of zero. This early check - is needed to emulate netlib BLAS. Without it, bli_?amaxv() will - return 0, which ends up getting incremented to 1 (below) before - being returned, which is not what we want. */ - if ( *n < 1 || *incx <= 0 ) { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "idamax_: vector empty"); - return 0; - } - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((double*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((double*)x); - incx0 = ( inc_t )(*incx); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). 
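(Aside: the block being removed here gated the fast path on specific architecture IDs — zen, zen2, zen3 — whereas the new *_amd.c files gate it on the capability query bli_cpuid_is_avx_supported(), which also covers future AVX-capable parts. The sketch below shows the general shape of such a wrapper, a runtime check choosing between an optimised kernel and a generic fallback; the function names are made up and both bodies are deliberately identical because only the dispatch is being illustrated.)

#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>

/* Stand-ins for "hand-written AVX kernel" and "reference implementation". */
static size_t amax_avx_kernel( size_t n, const double* x )   /* pretend this uses AVX */
{
    size_t best = 0;
    for ( size_t i = 1; i < n; ++i ) if ( fabs( x[ i ] ) > fabs( x[ best ] ) ) best = i;
    return best;
}
static size_t amax_reference( size_t n, const double* x )
{
    size_t best = 0;
    for ( size_t i = 1; i < n; ++i ) if ( fabs( x[ i ] ) > fabs( x[ best ] ) ) best = i;
    return best;
}

/* Placeholder for a cached capability query such as bli_cpuid_is_avx_supported(). */
static bool cpu_has_avx( void ) { return __builtin_cpu_supports( "avx2" ); }

/* Wrapper: capability-based dispatch with a safe fallback, as in the *_amd.c files. */
static size_t amax( size_t n, const double* x )
{
    if ( cpu_has_avx() ) return amax_avx_kernel( n, x );
    else                 return amax_reference( n, x );
}

int main( void )
{
    const double x[] = { 1.0, 7.5, -2.0, 3.25 };
    printf( "index of max-magnitude element: %zu\n", amax( 4, x ) );
    return 0;
}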
- arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel */ - bli_damaxv_zen_int - ( - n0, - x0, incx0, - &bli_index, - NULL - ); - } - else - { - PASTEMAC2(d,amaxv,BLIS_TAPI_EX_SUF) - ( - n0, - x0, incx0, - &bli_index, - NULL, - NULL - ); - } - - /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) - index. Also, if the BLAS integer size differs from the BLIS - integer size, that typecast occurs here. */ - f77_index = bli_index + 1; - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return f77_index; -} - -INSERT_GENTFUNC_BLAS_CZ( amax, amaxv ) -#else INSERT_GENTFUNC_BLAS( amax, amaxv ) #endif -#endif diff --git a/frame/compat/bla_amax_amd.c b/frame/compat/bla_amax_amd.c new file mode 100644 index 0000000000..7f1a771f7c --- /dev/null +++ b/frame/compat/bla_amax_amd.c @@ -0,0 +1,295 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype_x, chx, blasname, blisname ) \ +\ +f77_int PASTEF772(i,chx,blasname) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ + AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(chx), *n, *incx) \ +\ + dim_t n0; \ + ftype_x* x0; \ + inc_t incx0; \ + gint_t bli_index; \ + f77_int f77_index; \ +\ + /* If the vector is empty, return an index of zero. This early check + is needed to emulate netlib BLAS. Without it, bli_?amaxv() will + return 0, which ends up getting incremented to 1 (below) before + being returned, which is not what we want. 
*/ \ + if ( *n < 1 || *incx <= 0 ) { \ + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "iamax_: vector empty") \ + return 0; \ + }\ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype_x*)x, *incx, x0, incx0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(chx,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + n0, \ + x0, incx0, \ + &bli_index, \ + NULL, \ + NULL \ + ); \ +\ + /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) + index. Also, if the BLAS integer size differs from the BLIS + integer size, that typecast occurs here. */ \ + f77_index = bli_index + 1; \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + return f77_index; \ +} + +#ifdef BLIS_ENABLE_BLAS + +f77_int isamax_ + ( + const f77_int* n, + const float* x, const f77_int* incx + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx); + + dim_t n0; + float* x0; + inc_t incx0; + gint_t bli_index; + f77_int f77_index; + + /* If the vector is empty, return an index of zero. This early check + is needed to emulate netlib BLAS. Without it, bli_?amaxv() will + return 0, which ends up getting incremented to 1 (below) before + being returned, which is not what we want. */ + if ( *n < 1 || *incx <= 0 ) { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "isamax_: vector empty"); + return 0; + } + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((float*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel */ + bli_samaxv_zen_int + ( + n0, + x0, incx0, + &bli_index, + NULL + ); + } + else + { + PASTEMAC2(s,amaxv,BLIS_TAPI_EX_SUF) + ( + n0, + x0, incx0, + &bli_index, + NULL, + NULL + ); + } + + /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) + index. Also, if the BLAS integer size differs from the BLIS + integer size, that typecast occurs here. */ + f77_index = bli_index + 1; + + /* Finalize BLIS. 
*/ +// bli_finalize_auto(); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + + return f77_index; +} + +f77_int idamax_ + ( + const f77_int* n, + const double* x, const f77_int* incx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx); + + dim_t n0; + double* x0; + inc_t incx0; + gint_t bli_index; + f77_int f77_index; + + /* If the vector is empty, return an index of zero. This early check + is needed to emulate netlib BLAS. Without it, bli_?amaxv() will + return 0, which ends up getting incremented to 1 (below) before + being returned, which is not what we want. */ + if ( *n < 1 || *incx <= 0 ) { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "idamax_: vector empty"); + return 0; + } + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((double*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel */ + bli_damaxv_zen_int + ( + n0, + x0, incx0, + &bli_index, + NULL + ); + } + else + { + PASTEMAC2(d,amaxv,BLIS_TAPI_EX_SUF) + ( + n0, + x0, incx0, + &bli_index, + NULL, + NULL + ); + } + + /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) + index. Also, if the BLAS integer size differs from the BLIS + integer size, that typecast occurs here. */ + f77_index = bli_index + 1; + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return f77_index; +} + +INSERT_GENTFUNC_BLAS_CZ( amax, amaxv ) + +#endif diff --git a/frame/compat/bla_axpy.c b/frame/compat/bla_axpy.c index 41885e95d6..1a30f417b3 100644 --- a/frame/compat/bla_axpy.c +++ b/frame/compat/bla_axpy.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. 
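(Aside: the amax wrappers above — both the macro-generated ones and the hand-written isamax_/idamax_ — perform the same three BLAS-compatibility steps before any kernel runs: return 0 for an empty vector or non-positive increment as netlib does, rebase the pointer when the stride is negative so a relative stride can be used, and convert the zero-based result to a one-based Fortran index. A compact reference version of those steps is sketched below in plain C over doubles, without the PASTEF77/PASTEMAC machinery.)

#include <math.h>
#include <stdio.h>

/* Reference idamax-like wrapper showing the BLAS -> BLIS convention fixes:
   empty-vector early return, negative-stride pointer rebasing, 1-based result. */
static long ref_idamax( long n, const double* x, long incx )
{
    /* Netlib convention: empty vector or non-positive increment -> index 0. */
    if ( n < 1 || incx <= 0 ) return 0;

    /* With a negative stride, BLAS hands us the first storage element but wants the
       vector traversed back-to-front; rebase the pointer to the logical start so a
       simple relative stride works. (incx <= 0 was already rejected above for amax,
       so this branch is shown only to mirror the axpy/copy-style wrappers.) */
    const double* x0 = ( incx < 0 ) ? x + ( n - 1 ) * ( -incx ) : x;

    /* Zero-based search over |x[i]|, as the BLIS kernel would do. */
    long best = 0;
    for ( long i = 1; i < n; ++i )
        if ( fabs( x0[ i * incx ] ) > fabs( x0[ best * incx ] ) ) best = i;

    /* Convert the zero-based BLIS index to a one-based BLAS (Fortran) index. */
    return best + 1;
}

int main( void )
{
    const double x[] = { 1.0, -9.0, 3.0, 2.0 };
    printf( "idamax = %ld (1-based)\n", ref_idamax( 4, x, 1 ) );   /* prints 2 */
    return 0;
}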
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -87,399 +87,6 @@ void PASTEF77(ch,blasname) \ #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC -void saxpy_ -( - const f77_int* n, - const float* alpha, - const float* x, const f77_int* incx, - float* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, (float*)alpha, *incx, *incy) - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ - // bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - x0 = ((float*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((float*)x); - incx0 = ( inc_t )(*incx); - } - if ( *incy < 0 ) - { - y0 = ((float*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((float*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - bli_saxpyv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (float*)alpha, - x0, incx0, - y0, incy0, - NULL - ); - - } - else - { - PASTEMAC2(s,axpyv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - (float*)alpha, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - - } - /* Finalize BLIS. */ - // bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); -} - -void daxpy_ -( - const f77_int* n, - const double* alpha, - const double* x, const f77_int* incx, - double* y, const f77_int* incy - ) -{ - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, (double*)alpha, *incx, *incy) - /* Initialize BLIS. */ - // bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. 
(Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - x0 = ((double*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((double*)x); - incx0 = ( inc_t )(*incx); - } - if ( *incy < 0 ) - { - y0 = ((double*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((double*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - bli_daxpyv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (double*)alpha, - x0, incx0, - y0, incy0, - NULL - ); - - } - else - { - PASTEMAC2(d,axpyv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - (double*)alpha, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - // bli_finalize_auto(); -} - -void caxpy_ -( - const f77_int* n, - const scomplex* alpha, - const scomplex* x, const f77_int* incx, - scomplex* y, const f77_int* incy - ) -{ - dim_t n0; - scomplex* x0; - scomplex* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, (scomplex*)alpha, *incx, *incy) - - /* Initialize BLIS. */ - // bli_init_auto(); - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. 
*/ - x0 = ((scomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((scomplex*)x); - incx0 = ( inc_t )(*incx); - } - if ( *incy < 0 ) - { - y0 = ((scomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((scomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - bli_caxpyv_zen_int5 - ( - BLIS_NO_CONJUGATE, - n0, - (scomplex*)alpha, - x0, incx0, - y0, incy0, - NULL - ); - - } - else - { - PASTEMAC2(c,axpyv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - (scomplex*)alpha, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - // bli_finalize_auto(); -} - -void zaxpy_ -( - const f77_int* n, - const dcomplex* alpha, - const dcomplex* x, const f77_int* incx, - dcomplex* y, const f77_int* incy - ) -{ - dim_t n0; - dcomplex* x0; - dcomplex* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, (dcomplex*)alpha, *incx, *incy) - - /* Initialize BLIS. */ - // bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - x0 = ((dcomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((dcomplex*)x); - incx0 = ( inc_t )(*incx); - } - if ( *incy < 0 ) - { - y0 = ((dcomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((dcomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). 
- arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - bli_zaxpyv_zen_int5 - ( - BLIS_NO_CONJUGATE, - n0, - (dcomplex*)alpha, - x0, incx0, - y0, incy0, - NULL - ); - - } - else - { - PASTEMAC2(z,axpyv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - (dcomplex*)alpha, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - // bli_finalize_auto(); -} - -#else INSERT_GENTFUNC_BLAS( axpy, axpyv ) -#endif #endif diff --git a/frame/compat/bla_axpy_amd.c b/frame/compat/bla_axpy_amd.c new file mode 100644 index 0000000000..8a9f0280c6 --- /dev/null +++ b/frame/compat/bla_axpy_amd.c @@ -0,0 +1,462 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, (void*)alpha, *incx, *incy) \ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. 
*/ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n0, \ + (ftype*)alpha, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#ifdef BLIS_ENABLE_BLAS + +void saxpy_ +( + const f77_int* n, + const float* alpha, + const float* x, const f77_int* incx, + float* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, (float*)alpha, *incx, *incy) + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ + // bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + x0 = ((float*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((float*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((float*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_saxpyv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (float*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(s,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (float*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + + } + /* Finalize BLIS. */ + // bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + +void daxpy_ +( + const f77_int* n, + const double* alpha, + const double* x, const f77_int* incx, + double* y, const f77_int* incy + ) +{ + dim_t n0; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, (double*)alpha, *incx, *incy) + /* Initialize BLIS. */ + // bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. 
The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + x0 = ((double*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((double*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((double*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_daxpyv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (double*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(d,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (double*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + // bli_finalize_auto(); +} + +void caxpy_ +( + const f77_int* n, + const scomplex* alpha, + const scomplex* x, const f77_int* incx, + scomplex* y, const f77_int* incy + ) +{ + dim_t n0; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, (scomplex*)alpha, *incx, *incy) + + /* Initialize BLIS. */ + // bli_init_auto(); + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + x0 = ((scomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((scomplex*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. 
+ if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_caxpyv_zen_int5 + ( + BLIS_NO_CONJUGATE, + n0, + (scomplex*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(c,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (scomplex*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + // bli_finalize_auto(); +} + +void zaxpy_ +( + const f77_int* n, + const dcomplex* alpha, + const dcomplex* x, const f77_int* incx, + dcomplex* y, const f77_int* incy + ) +{ + dim_t n0; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, (dcomplex*)alpha, *incx, *incy) + + /* Initialize BLIS. */ + // bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + x0 = ((dcomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((dcomplex*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_zaxpyv_zen_int5 + ( + BLIS_NO_CONJUGATE, + n0, + (dcomplex*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(z,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (dcomplex*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + // bli_finalize_auto(); +} + + + +#endif diff --git a/frame/compat/bla_copy.c b/frame/compat/bla_copy.c index 61df88cf1e..74baba689c 100644 --- a/frame/compat/bla_copy.c +++ b/frame/compat/bla_copy.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -88,211 +88,5 @@ void PASTEF77(ch,blasname) \ } #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC - -void scopy_ -( - const f77_int* n, - const float* x, const f77_int* incx, - float* y, const f77_int* incy -) -{ - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy) - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if (*n < 0) - n0 = (dim_t)0; - else - n0 = (dim_t)(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if (*incx < 0) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (float*)((x)+(n0 - 1)*(-*incx)); - incx0 = (inc_t)(*incx); - - } - else - { - x0 = (float*)(x); - incx0 = (inc_t)(*incx); - } - - if (*incy < 0) - { - y0 = (y)+(n0 - 1)*(-*incy); - incy0 = (inc_t)(*incy); - - } - else - { - y0 = (y); - incy0 = (inc_t)(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel */ - bli_scopyv_zen_int - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL - ); - } - else - { - PASTEMAC2(s, copyv, BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - /* Finalize BLIS. */ -// bli_finalize_auto(); -} - -void dcopy_ -( - const f77_int* n, - const double* x, const f77_int* incx, - double* y, const f77_int* incy -) -{ - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy) - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if (*n < 0) - n0 = (dim_t)0; - else - n0 = (dim_t)(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if (*incx < 0) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. 
(Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (double*)((x)+(n0 - 1)*(-*incx)); - incx0 = (inc_t)(*incx); - - } - else - { - x0 = (double*)(x); - incx0 = (inc_t)(*incx); - } - - if (*incy < 0) - { - y0 = (y)+(n0 - 1)*(-*incy); - incy0 = (inc_t)(*incy); - - } - else - { - y0 = (y); - incy0 = (inc_t)(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel */ - bli_dcopyv_zen_int - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL - ); - } - else - { - PASTEMAC2(d, copyv, BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - } - - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - /* Finalize BLIS. */ -// bli_finalize_auto(); -} - -INSERT_GENTFUNC_BLAS_CZ(copy, copyv) -#else INSERT_GENTFUNC_BLAS(copy, copyv) #endif -#endif diff --git a/frame/compat/bla_copy_amd.c b/frame/compat/bla_copy_amd.c new file mode 100644 index 0000000000..8dc4d5287c --- /dev/null +++ b/frame/compat/bla_copy_amd.c @@ -0,0 +1,285 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, *incx, *incy) \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv(n0, (ftype*)x, *incx, x0, incx0); \ + bli_convert_blas_incv(n0, (ftype*)y, *incy, y0, incy0); \ + \ + /* Call BLIS interface. */ \ + PASTEMAC2(ch, blisname, BLIS_TAPI_EX_SUF) \ + (\ + BLIS_NO_CONJUGATE, \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ + \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#ifdef BLIS_ENABLE_BLAS + +void scopy_ +( + const f77_int* n, + const float* x, const f77_int* incx, + float* y, const f77_int* incy +) +{ + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy) + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if (*n < 0) + n0 = (dim_t)0; + else + n0 = (dim_t)(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if (*incx < 0) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (float*)((x)+(n0 - 1)*(-*incx)); + incx0 = (inc_t)(*incx); + + } + else + { + x0 = (float*)(x); + incx0 = (inc_t)(*incx); + } + + if (*incy < 0) + { + y0 = (y)+(n0 - 1)*(-*incy); + incy0 = (inc_t)(*incy); + + } + else + { + y0 = (y); + incy0 = (inc_t)(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. 
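As an illustration of the dispatch that follows (and that recurs in every *_amd.c wrapper in this patch): a run-time capability probe selects either a hand-written AVX kernel or the context-driven reference path. In the stand-alone sketch below, cpu_has_avx(), copyv_avx_kernel(), and copyv_reference() are placeholders; the real wrappers use bli_cpuid_is_avx_supported() and the kernels shown in this diff.

#include <stdbool.h>
#include <stdio.h>

/* Placeholders standing in for an AVX-optimized kernel and the generic,
   context-derived reference path. */
static void copyv_avx_kernel( int n, const float* x, float* y )
{ for ( int i = 0; i < n; i++ ) y[i] = x[i]; puts( "avx kernel" ); }

static void copyv_reference( int n, const float* x, float* y )
{ for ( int i = 0; i < n; i++ ) y[i] = x[i]; puts( "reference path" ); }

/* Placeholder capability probe. */
static bool cpu_has_avx( void ) { return true; }

/* One binary, two code paths: the choice is made per call at run time, so the
   same library runs on both AVX and non-AVX machines. */
static void copyv_dispatch( int n, const float* x, float* y )
{
    if ( cpu_has_avx() )
        copyv_avx_kernel( n, x, y );
    else
        copyv_reference( n, x, y );
}

int main( void )
{
    float x[4] = { 1, 2, 3, 4 }, y[4];
    copyv_dispatch( 4, x, y );
    return 0;
}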
+ if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel */ + bli_scopyv_zen_int + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else + { + PASTEMAC2(s, copyv, BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ +// bli_finalize_auto(); +} + +void dcopy_ +( + const f77_int* n, + const double* x, const f77_int* incx, + double* y, const f77_int* incy +) +{ + dim_t n0; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy) + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if (*n < 0) + n0 = (dim_t)0; + else + n0 = (dim_t)(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if (*incx < 0) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (double*)((x)+(n0 - 1)*(-*incx)); + incx0 = (inc_t)(*incx); + + } + else + { + x0 = (double*)(x); + incx0 = (inc_t)(*incx); + } + + if (*incy < 0) + { + y0 = (y)+(n0 - 1)*(-*incy); + incy0 = (inc_t)(*incy); + + } + else + { + y0 = (y); + incy0 = (inc_t)(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel */ + bli_dcopyv_zen_int + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else + { + PASTEMAC2(d, copyv, BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } + + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ +// bli_finalize_auto(); +} + +INSERT_GENTFUNC_BLAS_CZ(copy, copyv) + +#endif diff --git a/frame/compat/bla_dot.c b/frame/compat/bla_dot.c index 2a0f815217..3c4d8c538f 100644 --- a/frame/compat/bla_dot.c +++ b/frame/compat/bla_dot.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -90,663 +90,11 @@ ftype PASTEF772(ch,blasname,chc) \ } #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC -float sdot_ - ( - const f77_int* n, - const float* x, const f77_int* incx, - const float* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy); - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - float rho; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((float*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((float*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((float*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((float*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_sdotv_zen_int10 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(s,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return rho; -} - -double ddot_ - ( - const f77_int* n, - const double* x, const f77_int* incx, - const double* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - double rho; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. 
*/ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((double*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((double*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((double*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((double*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_ddotv_zen_int10 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(d,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return rho; -} -#else INSERT_GENTFUNCDOTR_BLAS( dot, dotv ) -#endif #ifdef BLIS_ENABLE_BLAS #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL -#ifdef BLIS_CONFIG_EPYC -scomplex cdotu_ - ( - const f77_int* n, - const scomplex* x, const f77_int* incx, - const scomplex* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, *incx, *incy); - dim_t n0; - scomplex* x0; - scomplex* y0; - inc_t incx0; - inc_t incy0; - scomplex rho; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. 
Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((scomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((scomplex*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((scomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((scomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_cdotv_zen_int5 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return rho; -} - -dcomplex zdotu_ - ( - const f77_int* n, - const dcomplex* x, const f77_int* incx, - const dcomplex* y, const f77_int* incy - ) -{ - dim_t n0; - dcomplex* x0; - dcomplex* y0; - inc_t incx0; - inc_t incy0; - dcomplex rho; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, *incx, *incy); - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((dcomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((dcomplex*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((dcomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((dcomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). 
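The predicate used for this check is the substantive change in these files: the removed wrappers enable the optimized kernels only when the detected micro-architecture is Zen, Zen2, or Zen3, whereas the replacement *_amd.c wrappers enable them whenever the CPU reports AVX support, so unrecognized AVX-capable parts no longer drop to the reference path. A condensed side-by-side, assuming blis.h and using only identifiers that appear in this diff:

#include "blis.h"

/* Condensed comparison of the two run-time checks seen in this diff. */
static bool fast_path_old( void )
{
    /* Removed predicate: only the explicitly recognized Zen generations. */
    arch_t id = bli_arch_query_id();
    return ( id == BLIS_ARCH_ZEN3 ) ||
           ( id == BLIS_ARCH_ZEN2 ) ||
           ( id == BLIS_ARCH_ZEN );
}

static bool fast_path_new( void )
{
    /* Added predicate: any CPU that reports AVX support. */
    return ( bli_cpuid_is_avx_supported() == TRUE );
}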
- arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_zdotv_zen_int5 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - - return rho; -} - - -scomplex cdotc_ - ( - const f77_int* n, - const scomplex* x, const f77_int* incx, - const scomplex* y, const f77_int* incy - ) -{ - dim_t n0; - scomplex* x0; - scomplex* y0; - inc_t incx0; - inc_t incy0; - scomplex rho; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, *incx, *incy); - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((scomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((scomplex*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((scomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((scomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_cdotv_zen_int5 - ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - /* Finalize BLIS. 
*/ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - - return rho; -} - -dcomplex zdotc_ - ( - const f77_int* n, - const dcomplex* x, const f77_int* incx, - const dcomplex* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, *incx, *incy); - dim_t n0; - dcomplex* x0; - dcomplex* y0; - inc_t incx0; - inc_t incy0; - dcomplex rho; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((dcomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((dcomplex*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((dcomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((dcomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_zdotv_zen_int5 - ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - - - - - /* Finalize BLIS. */ -// bli_finalize_auto(); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - - return rho; -} -#else INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) -#endif #else // For the "intel" complex return type, use a hidden parameter to return the result #undef GENTFUNCDOT @@ -801,8 +149,8 @@ void PASTEF772(ch,blasname,chc) \ } INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) -#endif -#endif +#endif // BLIS_DISABLE_COMPLEX_RETURN_INTEL +#endif // BLIS_ENABLE_BLAS // -- "Black sheep" dot product function definitions -- @@ -876,4 +224,4 @@ double PASTEF77(d,sdot) return rho; } -#endif +#endif // BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_dot_amd.c b/frame/compat/bla_dot_amd.c new file mode 100644 index 0000000000..0cdaa6535b --- /dev/null +++ b/frame/compat/bla_dot_amd.c @@ -0,0 +1,841 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNCDOT +#define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ +\ +ftype PASTEF772(ch,blasname,chc) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, *incx, *incy); \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ + ftype rho; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_conjx, \ + BLIS_NO_CONJUGATE, \ + n0, \ + x0, incx0, \ + y0, incy0, \ + &rho, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +\ + return rho; \ +} + +#ifdef BLIS_ENABLE_BLAS +float sdot_ + ( + const f77_int* n, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy); + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + float rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. 
*/ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((float*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((float*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((float*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_sdotv_zen_int10 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(s,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return rho; +} + +double ddot_ + ( + const f77_int* n, + const double* x, const f77_int* incx, + const double* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); + dim_t n0; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + double rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((double*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((double*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((double*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. 
+ // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_ddotv_zen_int10 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(d,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return rho; +} + +#ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL +scomplex cdotu_ + ( + const f77_int* n, + const scomplex* x, const f77_int* incx, + const scomplex* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, *incx, *incy); + dim_t n0; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + scomplex rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((scomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((scomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_cdotv_zen_int5 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return rho; +} + +dcomplex zdotu_ + ( + const f77_int* n, + const dcomplex* x, const f77_int* incx, + const dcomplex* y, const f77_int* incy + ) +{ + dim_t n0; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + dcomplex rho; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, *incx, *incy); + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. 
*/ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((dcomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((dcomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_zdotv_zen_int5 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + + return rho; +} + + +scomplex cdotc_ + ( + const f77_int* n, + const scomplex* x, const f77_int* incx, + const scomplex* y, const f77_int* incy + ) +{ + dim_t n0; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + scomplex rho; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, *incx, *incy); + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. 
*/ + + x0 = ((scomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((scomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_cdotv_zen_int5 + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + + return rho; +} + +dcomplex zdotc_ + ( + const f77_int* n, + const dcomplex* x, const f77_int* incx, + const dcomplex* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, *incx, *incy); + dim_t n0; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + dcomplex rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((dcomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((dcomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_zdotv_zen_int5 + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + + + + + /* Finalize BLIS. 
*/ +// bli_finalize_auto(); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + + return rho; +} + +#else // BLIS_DISABLE_COMPLEX_RETURN_INTEL +// For the "intel" complex return type, use a hidden parameter to return the result +#undef GENTFUNCDOT +#define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ +\ +void PASTEF772(ch,blasname,chc) \ + ( \ + ftype* rhop, \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, *incx, *incy); \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ + ftype rho; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_conjx, \ + BLIS_NO_CONJUGATE, \ + n0, \ + x0, incx0, \ + y0, incy0, \ + &rho, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + bli_finalize_auto(); \ +\ + *rhop = rho; \ +} + +INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) +#endif // BLIS_DISABLE_COMPLEX_RETURN_INTEL + + + +// -- "Black sheep" dot product function definitions -- + +// Input vectors stored in single precision, computed in double precision, +// with result returned in single precision. +float PASTEF77(sd,sdot) + ( + const f77_int* n, + const float* sb, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ) +{ + return ( float ) + ( + ( double )(*sb) + + PASTEF77(d,sdot) + ( + n, + x, incx, + y, incy + ) + ); +} + +// Input vectors stored in single precision, computed in double precision, +// with result returned in double precision. +double PASTEF77(d,sdot) + ( + const f77_int* n, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ) +{ + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + double rho; + dim_t i; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); + /* Initialization of BLIS is not required. */ + + /* Convert/typecast negative values of n to zero. */ + bli_convert_blas_dim1( *n, n0 ); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + bli_convert_blas_incv( n0, (float*)x, *incx, x0, incx0 ); + bli_convert_blas_incv( n0, (float*)y, *incy, y0, incy0 ); + + rho = 0.0; + + for ( i = 0; i < n0; i++ ) + { + float* chi1 = x0 + (i )*incx0; + float* psi1 = y0 + (i )*incy0; + + bli_ddots( (( double )(*chi1)), + (( double )(*psi1)), rho ); + } + + /* Finalization of BLIS is not required, because initialization was + not required. */ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + + return rho; +} + +#endif diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 50aa931a82..406ff69d53 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 22, Advanced Micro Devices, Inc. 
All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -300,513 +300,9 @@ void PASTEF77(ch,blasname) \ #endif #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC -void dgemm_ -( - const f77_char* transa, - const f77_char* transb, - const f77_int* m, - const f77_int* n, - const f77_int* k, - const double* alpha, - const double* a, const f77_int* lda, - const double* b, const f77_int* ldb, - const double* beta, - double* c, const f77_int* ldc -) -{ - - - - trans_t blis_transa; - trans_t blis_transb; - dim_t m0, n0, k0; - - /* Initialize BLIS. */ - bli_init_auto(); - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), *transa, *transb, *m, *n, *k, \ - (void*)alpha, *lda, *ldb, (void*)beta, *ldc); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemm) - ( - MKSTR(d), - MKSTR(gemm), - transa, - transb, - m, - n, - k, - lda, - ldb, - ldc - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_trans(*transa, &blis_transa); - bli_param_map_netlib_to_blis_trans(*transb, &blis_transb); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1(*m, m0); - bli_convert_blas_dim1(*n, n0); - bli_convert_blas_dim1(*k, k0); - - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const inc_t rs_c = 1; - const inc_t cs_c = *ldc; - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (!bamdzen) - { - // This code is duplicated below, however we don't want to move it out of - // this IF block as it will affect the performance on Zen architetures - // Also this is temporary fix which will be replaced later. - const num_t dt = BLIS_DOUBLE; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); - bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); - - bli_obj_init_finish_1x1(dt, (double *)alpha, &alphao); - bli_obj_init_finish_1x1(dt, (double *)beta, &betao); - - bli_obj_init_finish(dt, m0_a, n0_a, (double *)a, rs_a, cs_a, &ao); - bli_obj_init_finish(dt, m0_b, n0_b, (double *)b, rs_b, cs_b, &bo); - bli_obj_init_finish(dt, m0, n0, (double *)c, rs_c, cs_c, &co); - - bli_obj_set_conjtrans(blis_transa, &ao); - bli_obj_set_conjtrans(blis_transb, &bo); - - // Will call parallelized dgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. 
*/ - bli_finalize_auto(); - return; - } - - if((k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) - { - bli_dgemm_ref_k1_nn( m0, n0, k0, - (double*)alpha, - (double*)a, *lda, - (double*)b, *ldb, - (double*)beta, - c, *ldc - ); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS */ - bli_finalize_auto(); - - return; - } - - if (n0 == 1) - { - if (bli_is_notrans(blis_transa)) - { - bli_dgemv_unf_var2( - BLIS_NO_TRANSPOSE, - bli_extract_conj(blis_transb), - m0, k0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, - (double*)beta, - c, rs_c, - ((void*)0) - ); - } - else - { - bli_dgemv_unf_var1( - blis_transa, - bli_extract_conj(blis_transb), - k0, m0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, - (double*)beta, - c, rs_c, - ((void*)0) - ); - } - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - - return; - } - else if (m0 == 1) - { - if (bli_is_notrans(blis_transb)) - { - bli_dgemv_unf_var1( - blis_transb, - bli_extract_conj(blis_transa), - n0, k0, - (double*)alpha, - (double*)b, cs_b, rs_b, - (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, - (double*)beta, - c, cs_c, - ((void*)0) - ); - } - else - { - bli_dgemv_unf_var2( - blis_transb, - bli_extract_conj(blis_transa), - k0, n0, - (double*)alpha, - (double*)b, cs_b, rs_b, - (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, - (double*)beta, - c, cs_c, - ((void*)0) - ); - } - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - return; - } - - const num_t dt = BLIS_DOUBLE; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); - bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); - - bli_obj_init_finish_1x1(dt, (double*)alpha, &alphao); - bli_obj_init_finish_1x1(dt, (double*)beta, &betao); - - bli_obj_init_finish(dt, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao); - bli_obj_init_finish(dt, m0_b, n0_b, (double*)b, rs_b, cs_b, &bo); - bli_obj_init_finish(dt, m0, n0, (double*)c, rs_c, cs_c, &co); - - bli_obj_set_conjtrans(blis_transa, &ao); - bli_obj_set_conjtrans(blis_transb, &bo); - - //cntx_t* cntx = bli_gks_query_cntx(); - //dim_t nt = bli_thread_get_num_threads(); // get number of threads - bool nt = bli_thread_get_is_parallel(); // Check if parallel dgemm is invoked. - - // if m0 is large and (n0 & k0) < 10 - SMALL GEMM - ST is better - // - -#ifdef AOCL_DYNAMIC - if (nt && ((n0 > 10 ) || (k0 > 10)) ) -#else - if (nt) -#endif - { - // Will call parallelized dgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - - // The code below will be called when number of threads = 1. 
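The single-threaded block that follows here (and again in the new bla_gemm_amd.c) picks between the small-matrix kernels, the sup path, and the native path purely from the problem shape. A minimal standalone sketch of that decision logic, with an illustrative helper name and the thresholds copied from the code below; it is not part of the patch:

#include <stdbool.h>

typedef long dim_t;   /* stand-in for the BLIS dim_t integer type */

typedef enum { DGEMM_PATH_SMALL, DGEMM_PATH_SUP, DGEMM_PATH_NATIVE } dgemm_path_t;

/* Shape-based routing used once the thread count is 1: prefer the
   small-matrix kernels when every pairwise "slack" (m+n-k, m+k-n, n+k-m)
   is under 2000, or when both n and k are tiny; otherwise try the sup
   path and let the caller fall back to the native path on failure. */
static dgemm_path_t choose_dgemm_st_path( dim_t m0, dim_t n0, dim_t k0 )
{
    bool small_shape =
        ( ( (m0 + n0 - k0) < 2000 ) &&
          ( (m0 + k0 - n0) < 2000 ) &&
          ( (n0 + k0 - m0) < 2000 ) ) ||
        ( ( n0 <= 10 ) && ( k0 <= 10 ) );

    if ( small_shape ) return DGEMM_PATH_SMALL;  /* bli_dgemm_small / bli_dgemm_small_At */
    return DGEMM_PATH_SUP;                       /* bli_gemmsup, then bli_gemmnat if not handled */
}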
- -#ifdef BLIS_ENABLE_SMALL_MATRIX - - //if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2)) - if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) || - ((n0 <= 10) && (k0 <=10)) ) - { - err_t status; - if (bli_is_notrans(blis_transa)) - { - status = bli_dgemm_small( &alphao, - &ao, - &bo, - &betao, - &co, - NULL, //cntx, - NULL - ); - } - else - { - status = bli_dgemm_small_At ( &alphao, - &ao, - &bo, - &betao, - &co, - NULL, //cntx, - NULL - ); - } - - if (status == BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - - return; - } - } - -#endif //#ifdef BLIS_ENABLE_SMALL_MATRIX - - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - return; - } - - // fall back on native path when dgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - - - /* PASTEMAC(gemm, BLIS_OAPI_EX_SUF) */ - /* ( */ - /* &alphao, */ - /* &ao, */ - /* &bo, */ - /* &betao, */ - /* &co, */ - /* NULL, */ - /* NULL */ - /* ); */ - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); -} // end of dgemm_ - -void zgemm_ - ( - const f77_char* transa, - const f77_char* transb, - const f77_int* m, - const f77_int* n, - const f77_int* k, - const dcomplex* alpha, - const dcomplex* a, const f77_int* lda, - const dcomplex* b, const f77_int* ldb, - const dcomplex* beta, - dcomplex* c, const f77_int* ldc - ) -{ - trans_t blis_transa; - trans_t blis_transb; - dim_t m0, n0, k0; - - /* Initialize BLIS. */ - bli_init_auto(); - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, - (void*)alpha, *lda, *ldb, (void*)beta, *ldc); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemm) - ( - MKSTR(z), - MKSTR(gemm), - transa, - transb, - m, - n, - k, - lda, - ldb, - ldc - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - bli_convert_blas_dim1( *k, k0 ); - - /* Set the row and column strides of the matrix operands. 
*/ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const inc_t rs_c = 1; - const inc_t cs_c = *ldc; - - const num_t dt = BLIS_DCOMPLEX; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); - bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); - - bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); - bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); - - bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); - bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); - - bli_obj_set_conjtrans( blis_transa, &ao ); - bli_obj_set_conjtrans( blis_transb, &bo ); - - // default instance peformance tuning is done in zgemm. - // Single instance tuning is done based on env set. - dim_t single_instance = bli_env_get_var( "BLIS_SINGLE_INSTANCE", -1 ); - - //dim_t nt = bli_thread_get_num_threads(); // get number of threads - bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. - if ( nt ) - { - // Will call parallelized zgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - - // The code below will be called when number of threads = 1. -#if ENABLE_INDUCED_METHOD - /* 3m_sqp is optimal for certain matrix shapes. - Initial study that it works well for square sizes and sizes closer to square shape. - - * Usage of 3m_sqp is restricted to sizes, where it is found efficient compared to native, sup and other induced method. - * Further investigation is necessary to make the usage choices more generic. */ - bool sqp_on = false; - if( (m0 == n0 ) && ( n0 == k0 ) && ( m0 == 128 ) ) - { - sqp_on = true; - } - - // current range of sizes used for 3m_sqp to be expaned after evaluation. - if( ( m0 >= 4200) && ( m0 <= 4600 ) && ( ( n0 >= 326 ) || (n0 <= 1600 ) ) - && ( k0 == 1120 ) ) //to be tuned further. - { - sqp_on = true; - } - - if( ( blis_transb == BLIS_NO_TRANSPOSE) && ( sqp_on == true ) ) - { - //sqp algo is found better for n > 40 - if(bli_gemm_sqp(&alphao, &ao, &bo, &betao, &co, NULL, NULL)==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } - } -#endif//ENABLE_INDUCED_METHOD - -// native tuning resulted in better numbers compared to sup in constrained multi-instance -// sup has been enabled for single instance cases. - if(single_instance==1) - { - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if(status==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } - - } - // fall back on native path when zgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - /* Finalize BLIS. 
*/ - bli_finalize_auto(); -}// end of zgemm_ - - -INSERT_GENTFUNC_BLAS_SC( gemm, gemm ) -#else INSERT_GENTFUNC_BLAS( gemm,gemm ) -#endif -// Observed a regression in dgemm with this function addition. -// Disabling temporarily. -#if 0 +#if 1 void dzgemm_ ( const f77_char* transa, @@ -883,7 +379,7 @@ void dzgemm_ bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); - bli_obj_init_finish( dt_a, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt_a, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao ); bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); diff --git a/frame/compat/bla_gemm.h b/frame/compat/bla_gemm.h index 25aef8d11f..c9ea83149a 100644 --- a/frame/compat/bla_gemm.h +++ b/frame/compat/bla_gemm.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -54,8 +55,7 @@ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ); #ifdef BLIS_ENABLE_BLAS -// Disabling temporarily -#if 0 +#if 1 BLIS_EXPORT_BLAS void dzgemm_ ( const f77_char* transa, \ diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c new file mode 100644 index 0000000000..99d7371778 --- /dev/null +++ b/frame/compat/bla_gemm_amd.c @@ -0,0 +1,867 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Define BLAS-to-BLIS interfaces. 
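The interfaces defined below follow the Fortran BLAS calling convention: every argument is passed by pointer and matrices are column-major with an explicit leading dimension. As a usage sketch (not part of the patch), a C caller of the exported dgemm_ symbol might look like this, assuming a build with BLIS_ENABLE_BLAS and 32-bit BLAS integers (an ILP64 build would use a 64-bit integer type instead):

#include <stdio.h>

/* Fortran-style prototype generated by this file; link against the BLIS
   library produced by this tree. */
extern void dgemm_( const char* transa, const char* transb,
                    const int* m, const int* n, const int* k,
                    const double* alpha, const double* a, const int* lda,
                    const double* b, const int* ldb,
                    const double* beta, double* c, const int* ldc );

int main( void )
{
    /* C = alpha*A*B + beta*C with 2x2 column-major matrices. */
    int    m = 2, n = 2, k = 2, lda = 2, ldb = 2, ldc = 2;
    double alpha = 1.0, beta = 0.0;
    double a[] = { 1.0, 3.0, 2.0, 4.0 };   /* column-major [[1,2],[3,4]] */
    double b[] = { 5.0, 7.0, 6.0, 8.0 };   /* column-major [[5,6],[7,8]] */
    double c[4];

    dgemm_( "N", "N", &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc );

    /* Expect column-major [[19,22],[43,50]]. */
    printf( "%g %g\n%g %g\n", c[0], c[2], c[1], c[3] );
    return 0;
}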
+// +#define ENABLE_INDUCED_METHOD 0 +#ifdef BLIS_BLAS3_CALLS_TAPI + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ + inc_t rs_a, cs_a; \ + inc_t rs_b, cs_b; \ + inc_t rs_c, cs_c; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + rs_a = 1; \ + cs_a = *lda; \ + rs_b = 1; \ + cs_b = *ldb; \ + rs_c = 1; \ + cs_c = *ldc; \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_transa, \ + blis_transb, \ + m0, \ + n0, \ + k0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, cs_b, \ + (ftype*)beta, \ + (ftype*)c, rs_c, cs_c, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ +\ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. 
*/ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + if( n0 == 1 ) \ + { \ + if(bli_is_notrans(blis_transa)) \ + { \ + PASTEMAC(ch,gemv_unf_var2)( \ + BLIS_NO_TRANSPOSE, \ + bli_extract_conj(blis_transb), \ + m0, k0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a,\ + (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ + (ftype*) beta, \ + c, rs_c, \ + NULL \ + ); \ + } \ + else \ + { \ + PASTEMAC(ch,gemv_unf_var1)( \ + blis_transa, \ + bli_extract_conj(blis_transb), \ + k0, m0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ + (ftype*)beta, \ + c, rs_c, \ + NULL \ + ); \ + } \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + return; \ + } \ + else if( m0 == 1 ) \ + { \ + if(bli_is_notrans(blis_transb)) \ + { \ + PASTEMAC(ch,gemv_unf_var1)( \ + blis_transb, \ + bli_extract_conj(blis_transa), \ + n0, k0, \ + (ftype*)alpha, \ + (ftype*)b, cs_b, rs_b, \ + (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ + (ftype*)beta, \ + c, cs_c, \ + NULL \ + ); \ + } \ + else \ + { \ + PASTEMAC(ch,gemv_unf_var2)( \ + blis_transb, \ + bli_extract_conj(blis_transa), \ + k0, n0, \ + (ftype*)alpha, \ + (ftype*)b, cs_b, rs_b, \ + (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ + (ftype*)beta, \ + c, cs_c, \ + NULL \ + ); \ + } \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + return; \ + } \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( blis_transb, &bo ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} +#endif + +#ifdef BLIS_ENABLE_BLAS +void dgemm_ +( + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const double* alpha, + const double* a, const f77_int* lda, + const double* b, const f77_int* ldb, + const double* beta, + double* c, const f77_int* ldc +) +{ + + + + trans_t blis_transa; + trans_t blis_transb; + dim_t m0, n0, k0; + + /* Initialize BLIS. */ + bli_init_auto(); + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemm) + ( + MKSTR(d), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ + bli_param_map_netlib_to_blis_trans(*transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(*transb, &blis_transb); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1(*m, m0); + bli_convert_blas_dim1(*n, n0); + bli_convert_blas_dim1(*k, k0); + + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + // This code is duplicated below, however we don't want to move it out of + // this IF block as it will affect the performance on Zen architetures + // Also this is temporary fix which will be replaced later. + const num_t dt = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); + bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); + + bli_obj_init_finish_1x1(dt, (double *)alpha, &alphao); + bli_obj_init_finish_1x1(dt, (double *)beta, &betao); + + bli_obj_init_finish(dt, m0_a, n0_a, (double *)a, rs_a, cs_a, &ao); + bli_obj_init_finish(dt, m0_b, n0_b, (double *)b, rs_b, cs_b, &bo); + bli_obj_init_finish(dt, m0, n0, (double *)c, rs_c, cs_c, &co); + + bli_obj_set_conjtrans(blis_transa, &ao); + bli_obj_set_conjtrans(blis_transb, &bo); + + // Will call parallelized dgemm code - sup & native + PASTEMAC(gemm, BLIS_OAPI_EX_SUF) + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + + if((k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) + { + bli_dgemm_ref_k1_nn( m0, n0, k0, + (double*)alpha, + (double*)a, *lda, + (double*)b, *ldb, + (double*)beta, + c, *ldc + ); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS */ + bli_finalize_auto(); + + return; + } + + if (n0 == 1) + { + if (bli_is_notrans(blis_transa)) + { + bli_dgemv_unf_var2( + BLIS_NO_TRANSPOSE, + bli_extract_conj(blis_transb), + m0, k0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (double*)beta, + c, rs_c, + ((void*)0) + ); + } + else + { + bli_dgemv_unf_var1( + blis_transa, + bli_extract_conj(blis_transb), + k0, m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (double*)beta, + c, rs_c, + ((void*)0) + ); + } + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + return; + } + else if (m0 == 1) + { + if (bli_is_notrans(blis_transb)) + { + bli_dgemv_unf_var1( + blis_transb, + bli_extract_conj(blis_transa), + n0, k0, + (double*)alpha, + (double*)b, cs_b, rs_b, + (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, + (double*)beta, + c, cs_c, + ((void*)0) + ); + } + else + { + bli_dgemv_unf_var2( + blis_transb, + bli_extract_conj(blis_transa), + k0, n0, + (double*)alpha, + (double*)b, cs_b, rs_b, + (double*)a, bli_is_notrans(blis_transa) ? 
cs_a : rs_a, + (double*)beta, + c, cs_c, + ((void*)0) + ); + } + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + return; + } + + const num_t dt = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); + bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); + + bli_obj_init_finish_1x1(dt, (double*)alpha, &alphao); + bli_obj_init_finish_1x1(dt, (double*)beta, &betao); + + bli_obj_init_finish(dt, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao); + bli_obj_init_finish(dt, m0_b, n0_b, (double*)b, rs_b, cs_b, &bo); + bli_obj_init_finish(dt, m0, n0, (double*)c, rs_c, cs_c, &co); + + bli_obj_set_conjtrans(blis_transa, &ao); + bli_obj_set_conjtrans(blis_transb, &bo); + + //cntx_t* cntx = bli_gks_query_cntx(); + //dim_t nt = bli_thread_get_num_threads(); // get number of threads + bool nt = bli_thread_get_is_parallel(); // Check if parallel dgemm is invoked. + +#ifdef AOCL_DYNAMIC + //For smaller sizes dgemm_small is perfoming better + if (nt && (((m0 >32) || (n0>32) || (k0>32)) && ((m0+n0+k0)>150)) ) +#else + if (nt) +#endif + { + // Will call parallelized dgemm code - sup & native + PASTEMAC(gemm, BLIS_OAPI_EX_SUF) + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + + // The code below will be called when number of threads = 1. + +#ifdef BLIS_ENABLE_SMALL_MATRIX + + //if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2)) + if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) || + ((n0 <= 10) && (k0 <=10)) ) + { + err_t status = BLIS_FAILURE; + if (bli_is_notrans(blis_transa)) + { + status = bli_dgemm_small( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + else + { + status = bli_dgemm_small_At ( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + + return; + } + } + +#endif //#ifdef BLIS_ENABLE_SMALL_MATRIX + + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + return; + } + + // fall back on native path when dgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + + + /* PASTEMAC(gemm, BLIS_OAPI_EX_SUF) */ + /* ( */ + /* &alphao, */ + /* &ao, */ + /* &bo, */ + /* &betao, */ + /* &co, */ + /* NULL, */ + /* NULL */ + /* ); */ + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. 
*/ + bli_finalize_auto(); +} // end of dgemm_ + +void zgemm_ + ( + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const dcomplex* alpha, + const dcomplex* a, const f77_int* lda, + const dcomplex* b, const f77_int* ldb, + const dcomplex* beta, + dcomplex* c, const f77_int* ldc + ) +{ + trans_t blis_transa; + trans_t blis_transb; + dim_t m0, n0, k0; + + /* Initialize BLIS. */ + bli_init_auto(); + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemm) + ( + MKSTR(z), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + bli_convert_blas_dim1( *k, k0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + const num_t dt = BLIS_DCOMPLEX; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); + + bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); + bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); + + bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); + bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); + + bli_obj_set_conjtrans( blis_transa, &ao ); + bli_obj_set_conjtrans( blis_transb, &bo ); + + // default instance peformance tuning is done in zgemm. + // Single instance tuning is done based on env set. + //dim_t single_instance = bli_env_get_var( "BLIS_SINGLE_INSTANCE", -1 ); + + //dim_t nt = bli_thread_get_num_threads(); // get number of threads + bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. + +#ifdef BLIS_ENABLE_SMALL_MATRIX + + if( ( (nt == 0) && (m0 <= 512 ) && ( n0 <= 512 ) && ( k0 <= 512 ) ) || + ( (nt == 1) && ((( m0 <= 32)||(n0 <= 32)||(k0 <=32)) && ((m0+n0+k0)<=100)) ) + ) + { + err_t status = BLIS_NOT_YET_IMPLEMENTED; + if (bli_is_notrans(blis_transa)) + { + status = bli_zgemm_small(&alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + else + { + status = bli_zgemm_small_At(&alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } +#endif + // fall back on native path when zgemm is not handled in sup path. 
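In the zgemm_ path above, the small-matrix kernels are tried only when BLIS_ENABLE_SMALL_MATRIX is defined and a gate on the threading mode and problem size passes; otherwise execution falls through to the native path. A condensed sketch of that gate (illustrative function name; thresholds copied from the block above):

#include <stdbool.h>

typedef long dim_t;   /* stand-in for the BLIS dim_t integer type */

/* Single-threaded calls take the small path for anything up to
   512x512x512; multithreaded calls only do so when at least one
   dimension is <= 32 and the total m+n+k does not exceed 100. */
static bool zgemm_use_small_path( bool is_parallel, dim_t m0, dim_t n0, dim_t k0 )
{
    if ( !is_parallel )
        return ( m0 <= 512 ) && ( n0 <= 512 ) && ( k0 <= 512 );

    return ( ( m0 <= 32 ) || ( n0 <= 32 ) || ( k0 <= 32 ) ) &&
           ( ( m0 + n0 + k0 ) <= 100 );
}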
+ bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ + bli_finalize_auto(); +}// end of zgemm_ + + +INSERT_GENTFUNC_BLAS_SC( gemm, gemm ) + + +// Observed a regression in dgemm with this function addition. +// Disabling temporarily. +#if 1 +void dzgemm_ + ( + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const dcomplex* alpha, + const double* a, const f77_int* lda, + const dcomplex* b, const f77_int* ldb, + const dcomplex* beta, + dcomplex* c, const f77_int* ldc + ) +{ + + trans_t blis_transa; + trans_t blis_transb; + dim_t m0, n0, k0; + + /* Initialize BLIS. */ + bli_init_auto(); + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemm) + ( + MKSTR(z), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + bli_convert_blas_dim1( *k, k0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + const num_t dt = BLIS_DCOMPLEX; + const num_t dt_a = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); + + bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); + bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); + + bli_obj_init_finish( dt_a, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); + bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); + + bli_obj_set_conjtrans( blis_transa, &ao ); + bli_obj_set_conjtrans( blis_transb, &bo ); + + // fall back on native path when zgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ + bli_finalize_auto(); +}// end of dzgemm_ +#endif +#endif diff --git a/frame/compat/bla_gemv.c b/frame/compat/bla_gemv.c index e9b210bbc1..9dba1b43c4 100644 --- a/frame/compat/bla_gemv.c +++ b/frame/compat/bla_gemv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -147,844 +147,5 @@ void PASTEF77(ch,blasname) \ #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC -void dgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const double* alpha, - const double* a, const f77_int* lda, - const double* x, const f77_int* incx, - const double* beta, - double* y, const f77_int* incy - ) -{ - trans_t blis_transa; - dim_t m0, n0; - dim_t m_y, n_x; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - inc_t rs_a, cs_a; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemv) - ( - MKSTR(d), - MKSTR(gemv), - transa, - m, - n, - lda, - incx, - incy - ); - - if (*m == 0 || *n == 0) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; - else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; - else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - //bli_check_error_code( BLIS_INVALID_TRANS ); - blis_transa = BLIS_NO_TRANSPOSE; - } - - /* Convert/typecast negative values of m and n to zero. */ - if ( *m < 0 ) m0 = ( dim_t )0; - else m0 = ( dim_t )(*m); - - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* Determine the dimensions of x and y so we can adjust the increments, - if necessary.*/ - if ( bli_does_notrans( blis_transa ) ) - { - m_y = m0; - n_x = n0; - } - else - { - m_y = n0; - n_x = m0; - } - - /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to - provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ - if ( m_y > 0 && n_x == 0 ) - { - /* Finalize BLIS. */ - // bli_finalize_auto(); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - x0 = ((double*)x) + (n_x-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((double*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((double*)y) + (m_y-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((double*)y); - incy0 = ( inc_t )(*incy); - } - - /* Set the row and column strides of A. */ - rs_a = 1; - cs_a = *lda; - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. 
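The dispatch test being deleted here enumerated specific Zen architecture IDs, whereas its replacement in bla_gemv_amd.c keys off runtime AVX support. A side-by-side sketch of the two predicates, using only the query functions that appear in this patch (the wrapper names are illustrative):

#include "blis.h"   /* arch_t, BLIS_ARCH_*, bli_arch_query_id(), bli_cpuid_is_avx_supported() */

/* Old gate (removed): hand-written AVX2 kernels only on known Zen families. */
static bool use_zen_kernels_old( void )
{
    arch_t id = bli_arch_query_id();
    return ( id == BLIS_ARCH_ZEN3 ) || ( id == BLIS_ARCH_ZEN2 ) || ( id == BLIS_ARCH_ZEN );
}

/* New gate (bla_gemv_amd.c / bla_gemm_amd.c): any AVX-capable CPU qualifies;
   everything else uses the kernels derived from the runtime context. */
static bool use_avx_kernels_new( void )
{
    return bli_cpuid_is_avx_supported() != FALSE;
}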
- // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - /* Call BLIS interface. */ - PASTEMAC2(d,gemv,BLIS_TAPI_EX_SUF) - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (double*)alpha, - (double*)a, rs_a, cs_a, - x0, incx0, - (double*)beta, - y0, incy0, - NULL, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Call variants based on transpose value. */ - if(bli_does_notrans(blis_transa)) - { - //variant_2 is chosen for column-storage - // and uses axpyf-based implementation - bli_dgemv_unf_var2 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (double*)alpha, - (double*)a, rs_a, cs_a, - x0, incx0, - (double*)beta, - y0, incy0, - NULL - ); - } - else - { - //var_1 is chosen for row-storage - //and uses dotxf-based implementation - bli_dgemv_unf_var1 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (double*)alpha, - (double*)a, rs_a, cs_a, - x0, incx0, - (double*)beta, - y0, incy0, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); -} - -void sgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const float* alpha, - const float* a, const f77_int* lda, - const float* x, const f77_int* incx, - const float* beta, - float* y, const f77_int* incy - ) -{ - trans_t blis_transa; - dim_t m0, n0; - dim_t m_y, n_x; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - inc_t rs_a, cs_a; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemv) - ( - MKSTR(s), - MKSTR(gemv), - transa, - m, - n, - lda, - incx, - incy - ); - - if (*m == 0 || *n == 0) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; - else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; - else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - //bli_check_error_code( BLIS_INVALID_TRANS ); - blis_transa = BLIS_NO_TRANSPOSE; - } - - /* Convert/typecast negative values of m and n to zero. */ - if ( *m < 0 ) m0 = ( dim_t )0; - else m0 = ( dim_t )(*m); - - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* Determine the dimensions of x and y so we can adjust the increments, - if necessary.*/ - if ( bli_does_notrans( blis_transa ) ) - { - m_y = m0; - n_x = n0; - } - else - { - m_y = n0; - n_x = m0; - } - - /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to - provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. 
Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ - if ( m_y > 0 && n_x == 0 ) - { - /* Finalize BLIS. */ - // bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - x0 = ((float*)x) + (n_x-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((float*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((float*)y) + (m_y-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((float*)y); - incy0 = ( inc_t )(*incy); - } - - /* Set the row and column strides of A. */ - rs_a = 1; - cs_a = *lda; - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - /* Call BLIS interface. */ - PASTEMAC2(s,gemv,BLIS_TAPI_EX_SUF) - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (float*)alpha, - (float*)a, rs_a, cs_a, - x0, incx0, - (float*)beta, - y0, incy0, - NULL, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Call variants based on transpose value. */ - if(bli_does_notrans(blis_transa)) - { - bli_sgemv_unf_var2 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (float*)alpha, - (float*)a, rs_a, cs_a, - x0, incx0, - (float*)beta, - y0, incy0, - NULL - ); - } - else - { - bli_sgemv_unf_var1 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (float*)alpha, - (float*)a, rs_a, cs_a, - x0, incx0, - (float*)beta, - y0, incy0, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); -} - - -void cgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const scomplex* alpha, - const scomplex* a, const f77_int* lda, - const scomplex* x, const f77_int* incx, - const scomplex* beta, - scomplex* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); - - trans_t blis_transa; - dim_t m0, n0; - dim_t m_y, n_x; - scomplex* x0; - scomplex* y0; - inc_t incx0; - inc_t incy0; - inc_t rs_a, cs_a; - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemv) - ( - MKSTR(c), - MKSTR(gemv), - transa, - m, - n, - lda, - incx, - incy - ); - - if (*m == 0 || *n == 0) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - if( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; - else if( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; - else if( * transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - // bli_check_error_code( BLIS_INVALID_TRANS ); - blis_transa = BLIS_NO_TRANSPOSE; - } - - /* Convert/typecast negative values of m and n to zero. 
*/ - if( *m < 0 ) m0 = (dim_t)0; - else m0 = (dim_t)(*m); - - if( *n < 0 ) n0 = (dim_t)0; - else n0 = (dim_t)(*n); - - /* Determine the dimensions of x and y so we can adjust the increments, - if necessary.*/ - if( bli_does_notrans( blis_transa ) ) { m_y = m0, n_x = n0; } - else { m_y = n0; n_x = m0; } - - /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to - provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ - - if ( m_y > 0 && n_x == 0 ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if( *incx < 0 ) - { - x0 = ((scomplex*)x) + (n_x-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((scomplex*)x); - incx0 = (inc_t)(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((scomplex*)y) + (m_y-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((scomplex*)y); - incy0 = ( inc_t )(*incy); - } - - /* Set the row and column strides of A. */ - rs_a = 1; - cs_a = *lda; - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if( m_y == 1 ) - { - conj_t conja = bli_extract_conj(blis_transa); - scomplex rho; - if (bamdzen) - { - bli_cdotv_zen_int5 - ( - conja, - BLIS_NO_CONJUGATE, - n_x, - (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, - x0, incx0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) - ( - conja, - BLIS_NO_CONJUGATE, - n_x, - (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, - x0, incx0, - &rho, - NULL, - NULL - ); - } - - scomplex yval = *y0; - if(!bli_ceq0(*beta)) - { - bli_cscals( *beta, yval ); - } - else - { - bli_csetsc( 0.0, 0.0, &yval); - } - if(!bli_ceq0(*alpha)) - { - bli_caxpys( *alpha, rho, yval); - } - y0->real = yval.real; - y0->imag = yval.imag; - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - if (bamdzen == 0) - { - /* Call BLIS interface. 
*/ - PASTEMAC2(c,gemv,BLIS_TAPI_EX_SUF) - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - x0, incx0, - (scomplex*)beta, - y0, incy0, - NULL, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* call variants based on transpose value */ - if( bli_does_notrans( blis_transa ) ) - { - bli_cgemv_unf_var2 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - x0, incx0, - (scomplex*)beta, - y0, incy0, - NULL - ); - } - else - { - bli_cgemv_unf_var1 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - x0, incx0, - (scomplex*)beta, - y0, incy0, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); -} - - -void zgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const dcomplex* alpha, - const dcomplex* a, const f77_int* lda, - const dcomplex* x, const f77_int* incx, - const dcomplex* beta, - dcomplex* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); - - trans_t blis_transa; - dim_t m0, n0; - dim_t m_y, n_x; - dcomplex* x0; - dcomplex* y0; - inc_t incx0; - inc_t incy0; - inc_t rs_a, cs_a; - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemv) - ( - MKSTR(z), - MKSTR(gemv), - transa, - m, - n, - lda, - incx, - incy - ); - - if (*m == 0 || *n == 0) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - if( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; - else if( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; - else if( * transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - // bli_check_error_code( BLIS_INVALID_TRANS ); - blis_transa = BLIS_NO_TRANSPOSE; - } - - /* Convert/typecast negative values of m and n to zero. */ - if( *m < 0 ) m0 = (dim_t)0; - else m0 = (dim_t)(*m); - - if( *n < 0 ) n0 = (dim_t)0; - else n0 = (dim_t)(*n); - - /* Determine the dimensions of x and y so we can adjust the increments, - if necessary.*/ - if( bli_does_notrans( blis_transa ) ) { m_y = m0, n_x = n0; } - else { m_y = n0; n_x = m0; } - - /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to - provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ - - if ( m_y > 0 && n_x == 0 ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. 
*/ - if( *incx < 0 ) - { - x0 = ((dcomplex*)x) + (n_x-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((dcomplex*)x); - incx0 = (inc_t)(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((dcomplex*)y) + (m_y-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((dcomplex*)y); - incy0 = ( inc_t )(*incy); - } - - /* Set the row and column strides of A. */ - rs_a = 1; - cs_a = *lda; - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if( m_y == 1 ) - { - conj_t conja = bli_extract_conj(blis_transa); - dcomplex rho; - - if (bamdzen) - { - bli_zdotv_zen_int5 - ( - conja, - BLIS_NO_CONJUGATE, - n_x, - (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, - x0, incx0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) - ( - conja, - BLIS_NO_CONJUGATE, - n_x, - (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, - x0, incx0, - &rho, - NULL, - NULL - ); - } - - dcomplex yval = *y0; - if(!bli_zeq0(*beta)) - { - bli_zscals( *beta, yval ); - } - else - { - bli_zsetsc( 0.0, 0.0, &yval); - } - if(!bli_zeq0(*alpha)) - { - bli_zaxpys( *alpha, rho, yval); - } - y0->real = yval.real; - y0->imag = yval.imag; - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - if (bamdzen == 0) - { - /* Call BLIS interface. */ - PASTEMAC2(z,gemv,BLIS_TAPI_EX_SUF) - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - x0, incx0, - (dcomplex*)beta, - y0, incy0, - NULL, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* call variants based on transpose value */ - if( bli_does_notrans( blis_transa ) ) - { - bli_zgemv_unf_var2 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - x0, incx0, - (dcomplex*)beta, - y0, incy0, - NULL - ); - } - else - { - bli_zgemv_unf_var1 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - x0, incx0, - (dcomplex*)beta, - y0, incy0, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); -} - - -#else INSERT_GENTFUNC_BLAS( gemv, gemv ) #endif -#endif diff --git a/frame/compat/bla_gemv_amd.c b/frame/compat/bla_gemv_amd.c new file mode 100644 index 0000000000..354f45fe1b --- /dev/null +++ b/frame/compat/bla_gemv_amd.c @@ -0,0 +1,963 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); \ + trans_t blis_transa; \ + dim_t m0, n0; \ + dim_t m_y, n_x; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ + inc_t rs_a, cs_a; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + m, \ + n, \ + lda, \ + incx, \ + incy \ + ); \ +\ + if (*m == 0 || *n == 0) { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + return; \ + } \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ +\ + /* Convert/typecast negative values of m and n to zero. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ \ + bli_set_dims_with_trans( blis_transa, m0, n0, &m_y, &n_x ); \ +\ + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ \ + if ( m_y > 0 && n_x == 0 ) \ + { \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +\ + return; \ + } \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. 
*/ \ + bli_convert_blas_incv( n_x, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( m_y, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Set the row and column strides of A. */ \ + rs_a = 1; \ + cs_a = *lda; \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_transa, \ + BLIS_NO_CONJUGATE, \ + m0, \ + n0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + x0, incx0, \ + (ftype*)beta, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + + +#ifdef BLIS_ENABLE_BLAS +void dgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const double* alpha, + const double* a, const f77_int* lda, + const double* x, const f77_int* incx, + const double* beta, + double* y, const f77_int* incy + ) +{ + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemv) + ( + MKSTR(d), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + //bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if ( *m < 0 ) m0 = ( dim_t )0; + else m0 = ( dim_t )(*m); + + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if ( bli_does_notrans( blis_transa ) ) + { + m_y = m0; + n_x = n0; + } + else + { + m_y = n0; + n_x = m0; + } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + if ( m_y > 0 && n_x == 0 ) + { + /* Finalize BLIS. */ + // bli_finalize_auto(); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. 
*/ + if ( *incx < 0 ) + { + x0 = ((double*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((double*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((double*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + /* Call BLIS interface. */ + PASTEMAC2(d,gemv,BLIS_TAPI_EX_SUF) + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (double*)alpha, + (double*)a, rs_a, cs_a, + x0, incx0, + (double*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Call variants based on transpose value. */ + if(bli_does_notrans(blis_transa)) + { + //variant_2 is chosen for column-storage + // and uses axpyf-based implementation + bli_dgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (double*)alpha, + (double*)a, rs_a, cs_a, + x0, incx0, + (double*)beta, + y0, incy0, + NULL + ); + } + else + { + //var_1 is chosen for row-storage + //and uses dotxf-based implementation + bli_dgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (double*)alpha, + (double*)a, rs_a, cs_a, + x0, incx0, + (double*)beta, + y0, incy0, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + +void sgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const float* alpha, + const float* a, const f77_int* lda, + const float* x, const f77_int* incx, + const float* beta, + float* y, const f77_int* incy + ) +{ + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemv) + ( + MKSTR(s), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + //bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if ( *m < 0 ) m0 = ( dim_t )0; + else m0 = ( dim_t )(*m); + + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if ( bli_does_notrans( blis_transa ) ) + { + m_y = m0; + n_x = n0; + } + else + { + m_y = n0; + n_x = m0; + } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. 
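+      (A concrete instance: transa = 'T' with *m = 0 and *n = 5 gives
+      m_y = 5 and n_x = 0, and the reference BLAS returns immediately
+      instead of scaling y by beta.)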
Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + if ( m_y > 0 && n_x == 0 ) + { + /* Finalize BLIS. */ + // bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + x0 = ((float*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((float*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((float*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + /* Call BLIS interface. */ + PASTEMAC2(s,gemv,BLIS_TAPI_EX_SUF) + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (float*)alpha, + (float*)a, rs_a, cs_a, + x0, incx0, + (float*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Call variants based on transpose value. */ + if(bli_does_notrans(blis_transa)) + { + bli_sgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (float*)alpha, + (float*)a, rs_a, cs_a, + x0, incx0, + (float*)beta, + y0, incy0, + NULL + ); + } + else + { + bli_sgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (float*)alpha, + (float*)a, rs_a, cs_a, + x0, incx0, + (float*)beta, + y0, incy0, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + + +void cgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const scomplex* alpha, + const scomplex* a, const f77_int* lda, + const scomplex* x, const f77_int* incx, + const scomplex* beta, + scomplex* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); + + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemv) + ( + MKSTR(c), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if( * transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + // bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. 
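+      (For example, *m = -3 is treated here as m0 = 0, so the remaining
+      logic simply sees an empty operation.)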
*/ + if( *m < 0 ) m0 = (dim_t)0; + else m0 = (dim_t)(*m); + + if( *n < 0 ) n0 = (dim_t)0; + else n0 = (dim_t)(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if( bli_does_notrans( blis_transa ) ) { m_y = m0, n_x = n0; } + else { m_y = n0; n_x = m0; } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + + if ( m_y > 0 && n_x == 0 ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if( *incx < 0 ) + { + x0 = ((scomplex*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((scomplex*)x); + incx0 = (inc_t)(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; + + if( m_y == 1 ) + { + conj_t conja = bli_extract_conj(blis_transa); + scomplex rho; + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_cdotv_zen_int5 + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL, + NULL + ); + } + + scomplex yval = *y0; + if(!bli_ceq0(*beta)) + { + bli_cscals( *beta, yval ); + } + else + { + bli_csetsc( 0.0, 0.0, &yval); + } + if(!bli_ceq0(*alpha)) + { + bli_caxpys( *alpha, rho, yval); + } + y0->real = yval.real; + y0->imag = yval.imag; + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + if (bli_cpuid_is_avx_supported() == FALSE) + { + /* Call BLIS interface. 
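+
+      After macro pasting this is essentially a call of the form
+
+          bli_cgemv_ex( blis_transa, BLIS_NO_CONJUGATE, m0, n0,
+                        (scomplex*)alpha, (scomplex*)a, rs_a, cs_a,
+                        x0, incx0, (scomplex*)beta, y0, incy0,
+                        NULL, NULL );
+
+      assuming BLIS_TAPI_EX_SUF expands to _ex; this expert interface derives
+      its kernels from the runtime context rather than taking the
+      bli_cgemv_unf_var1 and bli_cgemv_unf_var2 paths used on AVX-capable
+      machines below.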
*/ + PASTEMAC2(c,gemv,BLIS_TAPI_EX_SUF) + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + x0, incx0, + (scomplex*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* call variants based on transpose value */ + if( bli_does_notrans( blis_transa ) ) + { + bli_cgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + x0, incx0, + (scomplex*)beta, + y0, incy0, + NULL + ); + } + else + { + bli_cgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + x0, incx0, + (scomplex*)beta, + y0, incy0, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + + +void zgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const dcomplex* alpha, + const dcomplex* a, const f77_int* lda, + const dcomplex* x, const f77_int* incx, + const dcomplex* beta, + dcomplex* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); + + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemv) + ( + MKSTR(z), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if( * transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + // bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if( *m < 0 ) m0 = (dim_t)0; + else m0 = (dim_t)(*m); + + if( *n < 0 ) n0 = (dim_t)0; + else n0 = (dim_t)(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if( bli_does_notrans( blis_transa ) ) { m_y = m0, n_x = n0; } + else { m_y = n0; n_x = m0; } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + + if ( m_y > 0 && n_x == 0 ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. 
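+      (Note that x is adjusted using its post-transpose length n_x and y
+      using m_y, the lengths computed above, rather than the raw *m and *n.)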
*/ + if( *incx < 0 ) + { + x0 = ((dcomplex*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((dcomplex*)x); + incx0 = (inc_t)(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; + + if( m_y == 1 ) + { + conj_t conja = bli_extract_conj(blis_transa); + dcomplex rho; + + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_zdotv_zen_int5 + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL, + NULL + ); + } + + dcomplex yval = *y0; + if(!bli_zeq0(*beta)) + { + bli_zscals( *beta, yval ); + } + else + { + bli_zsetsc( 0.0, 0.0, &yval); + } + if(!bli_zeq0(*alpha)) + { + bli_zaxpys( *alpha, rho, yval); + } + y0->real = yval.real; + y0->imag = yval.imag; + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + if (bli_cpuid_is_avx_supported() == FALSE) + { + /* Call BLIS interface. */ + PASTEMAC2(z,gemv,BLIS_TAPI_EX_SUF) + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + x0, incx0, + (dcomplex*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* call variants based on transpose value */ + if( bli_does_notrans( blis_transa ) ) + { + bli_zgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + x0, incx0, + (dcomplex*)beta, + y0, incy0, + NULL + ); + } + else + { + bli_zgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + x0, incx0, + (dcomplex*)beta, + y0, incy0, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + + + +#endif diff --git a/frame/compat/bla_scal.c b/frame/compat/bla_scal.c index 30fd857bc7..b9651577eb 100644 --- a/frame/compat/bla_scal.c +++ b/frame/compat/bla_scal.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -93,171 +93,5 @@ void PASTEF772(chx,cha,blasname) \ } #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC - -void sscal_ - ( - const f77_int* n, - const float* alpha, - float* x, const f77_int* incx - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', (void *) alpha, *n, *incx ); - dim_t n0; - float* x0; - inc_t incx0; - /* Initialize BLIS. */ - //bli_init_auto(); - - if (*n == 0 || alpha == NULL) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. 
(Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = (x); - incx0 = ( inc_t )(*incx); - } - /* Call BLIS kernel */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - bli_sscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (float *)alpha, - x0, incx0, - NULL - ); - } - else{ - PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE,\ - n0, \ - (float *)alpha,\ - x0, incx0,\ - NULL, \ - NULL \ - );\ - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) -} - -void dscal_ - ( - const f77_int* n, - const double* alpha, - double* x, const f77_int* incx - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', (void *)alpha, *n, *incx ); - dim_t n0; - double* x0; - inc_t incx0; - - /* Initialize BLIS */ - //bli_init_auto(); - - if (*n == 0 || alpha == NULL) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Convert typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = (x); - incx0 = ( inc_t )(*incx); - } - /* Call BLIS kernel */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen){ - bli_dscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (double*) alpha, - x0, incx0, - NULL - ); - } - else{ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE,\ - n0, \ - (double *)alpha,\ - x0, incx0,\ - NULL, \ - NULL \ - );\ - } - - /* Finalize BLIS. 
*/ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) -} - -INSERT_GENTFUNCSCAL_BLAS_CZ( scal, scalv ) -#else INSERT_GENTFUNCSCAL_BLAS( scal, scalv ) #endif -#endif diff --git a/frame/compat/bla_scal_amd.c b/frame/compat/bla_scal_amd.c new file mode 100644 index 0000000000..178776a149 --- /dev/null +++ b/frame/compat/bla_scal_amd.c @@ -0,0 +1,260 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNCSCAL +#define GENTFUNCSCAL( ftype_x, ftype_a, chx, cha, blasname, blisname ) \ +\ +void PASTEF772(chx,cha,blasname) \ + ( \ + const f77_int* n, \ + const ftype_a* alpha, \ + ftype_x* x, const f77_int* incx \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ + dim_t n0; \ + ftype_x* x0; \ + inc_t incx0; \ + ftype_x alpha_cast; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + if (*n == 0 || alpha == NULL) { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + return ; \ + } \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype_x*)x, *incx, x0, incx0 ); \ +\ + /* NOTE: We do not natively implement BLAS's csscal/zdscal in BLIS. + that is, we just always sub-optimally implement those cases + by casting alpha to ctype_x (potentially the complex domain) and + using the homogeneous datatype instance according to that type. */ \ + PASTEMAC2(cha,chx,copys)( *alpha, alpha_cast ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(chx,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n0, \ + &alpha_cast, \ + x0, incx0, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. 
*/ \ + bli_finalize_auto(); \ +} + +#ifdef BLIS_ENABLE_BLAS + +void sscal_ + ( + const f77_int* n, + const float* alpha, + float* x, const f77_int* incx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', (void *) alpha, *n, *incx ); + dim_t n0; + float* x0; + inc_t incx0; + /* Initialize BLIS. */ + //bli_init_auto(); + + if (*n == 0 || alpha == NULL) { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = (x); + incx0 = ( inc_t )(*incx); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + bli_sscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (float *)alpha, + x0, incx0, + NULL + ); + } + else{ + PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE,\ + n0, \ + (float *)alpha,\ + x0, incx0,\ + NULL, \ + NULL \ + );\ + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) +} + +void dscal_ + ( + const f77_int* n, + const double* alpha, + double* x, const f77_int* incx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', (void *)alpha, *n, *incx ); + dim_t n0; + double* x0; + inc_t incx0; + + /* Initialize BLIS */ + //bli_init_auto(); + + if (*n == 0 || alpha == NULL) { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Convert typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. 
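+      (For instance, with n0 = 3 and *incx = -1 the code below hands BLIS
+      the address of x[2] together with the stride -1, so the scaling runs
+      over x[2], x[1], x[0].)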
Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = (x); + incx0 = ( inc_t )(*incx); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE){ + bli_dscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (double*) alpha, + x0, incx0, + NULL + ); + } + else{ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE,\ + n0, \ + (double *)alpha,\ + x0, incx0,\ + NULL, \ + NULL \ + );\ + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) +} + +INSERT_GENTFUNCSCAL_BLAS_CZ( scal, scalv ) + +#endif diff --git a/frame/compat/bla_swap.c b/frame/compat/bla_swap.c index 6ecb360f95..d653426478 100644 --- a/frame/compat/bla_swap.c +++ b/frame/compat/bla_swap.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -83,190 +83,5 @@ void PASTEF77(ch,blasname) \ } #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC - -void sswap_ - ( - const f77_int* n, - float* x, const f77_int* incx, - float* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_SWAP_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy); - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = (x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = (y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = (y); - incy0 = ( inc_t )(*incy); - } - - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { -/* Call BLIS kernel */ - bli_sswapv_zen_int8 - ( - n0, - x0, incx0, - y0, incy0, - NULL - ); - } - else{ - PASTEMAC2(s,swapv,BLIS_TAPI_EX_SUF) \ - ( \ - n0, \ - x0, incx0, \ - y0, incy0, \ - NULL, \ - NULL \ - ); \ - } - - /* Finalize BLIS. 
*/ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) -} - -void dswap_ - ( - const f77_int* n, - double* x, const f77_int* incx, - double* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_SWAP_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = (x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = (y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = (y); - incy0 = ( inc_t )(*incy); - } - - - /* Call BLIS kernel */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - bli_dswapv_zen_int8 - ( - n0, - x0, incx0, - y0, incy0, - NULL - ); - } - else{ - PASTEMAC2(d,swapv,BLIS_TAPI_EX_SUF) \ - ( \ - n0, \ - x0, incx0, \ - y0, incy0, \ - NULL, \ - NULL \ - ); \ - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) -} - -INSERT_GENTFUNC_BLAS_CZ( swap, swapv ) - -#else INSERT_GENTFUNC_BLAS( swap, swapv ) #endif -#endif diff --git a/frame/compat/bla_swap_amd.c b/frame/compat/bla_swap_amd.c new file mode 100644 index 0000000000..617c78a4aa --- /dev/null +++ b/frame/compat/bla_swap_amd.c @@ -0,0 +1,268 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#ifdef BLIS_ENABLE_BLAS + +void sswap_ + ( + const f77_int* n, + float* x, const f77_int* incx, + float* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_SWAP_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy); + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. 
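+      The same adjustment is applied independently to y further below, so
+      both vectors are walked in the same reversed order and the elements
+      exchanged are the same pairs the reference BLAS would swap.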
*/ + + x0 = (x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = (x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = (y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = (y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + /* Call BLIS kernel */ + bli_sswapv_zen_int8 + ( + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else{ + PASTEMAC2(s,swapv,BLIS_TAPI_EX_SUF) \ + ( \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) +} + +void dswap_ + ( + const f77_int* n, + double* x, const f77_int* incx, + double* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_SWAP_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); + dim_t n0; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = (x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = (y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = (y); + incy0 = ( inc_t )(*incy); + } + + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + bli_dswapv_zen_int8 + ( + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else{ + PASTEMAC2(d,swapv,BLIS_TAPI_EX_SUF) \ + ( \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) +} + +INSERT_GENTFUNC_BLAS_CZ( swap, swapv ) + + +#endif diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index 654d3530d2..fea7ba6f17 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -380,1167 +380,5 @@ void PASTEF77(ch,blasname) \ #endif #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC - -void strsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const float* alpha, - const float* a, const f77_int* lda, - float* b, const f77_int* ldb -) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd', - *side, *uploa,*transa, *diaga, *m, *n, - (void*)alpha,*lda, *ldb); - - side_t blis_side; - uplo_t blis_uploa; - trans_t blis_transa; - diag_t blis_diaga; - dim_t m0, n0; - conj_t conja = BLIS_NO_CONJUGATE ; - - /* Initialize BLIS. */ - bli_init_auto(); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(trsm) - ( - MKSTR(s), - MKSTR(trsm), - side, - uploa, - transa, - diaga, - m, - n, - lda, - ldb - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_side( *side, &blis_side ); - bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const num_t dt = BLIS_FLOAT; - - if( n0 == 1 ) - { - if( blis_side == BLIS_LEFT ) - { - if(bli_is_notrans(blis_transa)) - { - bli_strsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (float*)alpha, - (float*)a, rs_a, cs_a, - (float*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - bli_strsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (float*)alpha, - (float*)a, rs_a, cs_a, - (float*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) - { - /* b = alpha * b; */ - bli_sscalv_ex - ( - conja, - m0, - (float*)alpha, - b, rs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - float inva = 1.0/ *a; - for(dim_t indx = 0; indx < m0; indx ++) - { - b[indx] = ( inva * b[indx] ); - } - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( m0 == 1 ) - { - if(blis_side == BLIS_RIGHT) - { - if(bli_is_notrans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_strsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (float*)alpha, - (float*)a, cs_a, rs_a, - (float*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_strsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (float*)alpha, - (float*)a, cs_a, rs_a, - (float*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) - { - /* b = alpha * b; */ - bli_sscalv_ex - ( - conja, - n0, - (float*)alpha, - b, cs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - float inva = 1.0/ *a; - 
for(dim_t indx = 0; indx < n0; indx ++) - { - b[indx*cs_b] = (inva * b[indx*cs_b] ); - } - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - const struc_t struca = BLIS_TRIANGULAR; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - - dim_t mn0_a; - - bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); - - bli_obj_init_finish_1x1( dt, (float*)alpha, &alphao ); - - bli_obj_init_finish( dt, mn0_a, mn0_a, (float*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0, n0, (float*)b, rs_b, cs_b, &bo ); - - bli_obj_set_uplo( blis_uploa, &ao ); - bli_obj_set_diag( blis_diaga, &ao ); - bli_obj_set_conjtrans( blis_transa, &ao ); - - bli_obj_set_struc( struca, &ao ); - - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_strsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=1000 && n0<=1000) || - (nt && (m0+n0)<320) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } -#endif - } - bli_trsmnat - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) - /* Finalize BLIS. */ - bli_finalize_auto(); -} - -void dtrsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const double* alpha, - const double* a, const f77_int* lda, - double* b, const f77_int* ldb -) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd', - *side, *uploa,*transa, *diaga, *m, *n, - (void*)alpha,*lda, *ldb); - - side_t blis_side; - uplo_t blis_uploa; - trans_t blis_transa; - diag_t blis_diaga; - dim_t m0, n0; - conj_t conja = BLIS_NO_CONJUGATE ; - - /* Initialize BLIS. */ - bli_init_auto(); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(trsm) - ( - MKSTR(d), - MKSTR(trsm), - side, - uploa, - transa, - diaga, - m, - n, - lda, - ldb - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_side( *side, &blis_side ); - bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - - /* Set the row and column strides of the matrix operands. 
*/ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const num_t dt = BLIS_DOUBLE; - - if( n0 == 1 ) - { - if( blis_side == BLIS_LEFT ) - { - if(bli_is_notrans(blis_transa)) - { - bli_dtrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - bli_dtrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) - { - /* b = alpha * b; */ - bli_dscalv_ex - ( - conja, - m0, - (double*)alpha, - b, rs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - double inva = 1.0/ *a; - for(dim_t indx = 0; indx < m0; indx ++) - { - b[indx] = ( inva * b[indx] ); - } - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( m0 == 1 ) - { - if(blis_side == BLIS_RIGHT) - { - if(bli_is_notrans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_dtrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (double*)alpha, - (double*)a, cs_a, rs_a, - (double*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_dtrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (double*)alpha, - (double*)a, cs_a, rs_a, - (double*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) - { - /* b = alpha * b; */ - bli_dscalv_ex - ( - conja, - n0, - (double*)alpha, - b, cs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - double inva = 1.0/ *a; - for(dim_t indx = 0; indx < n0; indx ++) - { - b[indx*cs_b] = (inva * b[indx*cs_b] ); - } - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - - const struc_t struca = BLIS_TRIANGULAR; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - - dim_t mn0_a; - - bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); - - bli_obj_init_finish_1x1( dt, (double*)alpha, &alphao ); - - bli_obj_init_finish( dt, mn0_a, mn0_a, (double*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0, n0, (double*)b, rs_b, cs_b, &bo ); - - bli_obj_set_uplo( blis_uploa, &ao ); - bli_obj_set_diag( blis_diaga, &ao ); - bli_obj_set_conjtrans( blis_transa, &ao ); - - bli_obj_set_struc( struca, &ao ); - - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_dtrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=1000 && n0<=1000) || - (nt && (m0+n0)<320) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. 
*/ - bli_finalize_auto(); - return; - } - } -#endif - } - bli_trsmnat - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) - /* Finalize BLIS. */ - bli_finalize_auto(); -} -#if 0 -void ztrsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const dcomplex* alpha, - const dcomplex* a, const f77_int* lda, - dcomplex* b, const f77_int* ldb -) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'z', - *side, *uploa,*transa, *diaga, *m, *n, - (void*)alpha,*lda, *ldb); - - side_t blis_side; - uplo_t blis_uploa; - trans_t blis_transa; - diag_t blis_diaga; - dim_t m0, n0; - conj_t conja = BLIS_NO_CONJUGATE; - - /* Initialize BLIS. */ - bli_init_auto(); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(trsm) - ( - MKSTR(z), - MKSTR(trsm), - side, - uploa, - transa, - diaga, - m, - n, - lda, - ldb - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_side( *side, &blis_side ); - bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const num_t dt = BLIS_DCOMPLEX; - - - if( n0 == 1 ) - { - if( blis_side == BLIS_LEFT ) - { - if(bli_is_notrans(blis_transa)) - { - bli_ztrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - (dcomplex*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - bli_ztrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - (dcomplex*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) - { - bli_zscalv_ex - ( - conja, - m0, - (dcomplex*)alpha, - (dcomplex*)b, rs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - dcomplex inva = {1.0, 0.0}; - dcomplex a_dup; - /** - * For conjugate transpose and non-unit diagonal - * kernel, negating imaginary part of A. - * As the dimension of A is 1x1, there's going to - * be only one 1 element of A. 
- */ - if(*transa == 'C' && *diaga == 'N') - { - a_dup.real = a->real; - a_dup.imag = a->imag * -1.0; - } - else - { - a_dup.real = a->real; - a_dup.imag = a->imag; - } - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_zinvscals(a_dup, inva); -#else - inva.real = a_dup.real; - inva.imag = a_dup.imag; -#endif - for(dim_t indx = 0; indx < m0; indx ++) - { -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_zscals(inva, b[indx]) -#else - - bli_zinvscals(inva, b[indx]) -#endif - } - - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( m0 == 1 ) - { - if(blis_side == BLIS_RIGHT) - { - if(bli_is_notrans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_ztrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (dcomplex*)alpha, - (dcomplex*)a, cs_a, rs_a, - (dcomplex*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_ztrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (dcomplex*)alpha, - (dcomplex*)a, cs_a, rs_a, - (dcomplex*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) - { - bli_zscalv_ex - ( - conja, - n0, - (dcomplex*)alpha, - (dcomplex*)b, cs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - dcomplex inva = {1.0, 0.0}; - dcomplex a_dup; - /** - * For conjugate transpose and non-unit diagonal - * kernel, negating imaginary part of A. - * As the dimension of A is 1x1, there's going to - * be only one 1 element of A. - */ - if(*transa == 'C' && *diaga == 'N') - { - a_dup.real = a->real; - a_dup.imag = a->imag * -1.0; - } - else - { - a_dup.real = a->real; - a_dup.imag = a->imag; - } - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_zinvscals(a_dup, inva); -#else - inva.real = a_dup.real; - inva.imag = a_dup.imag; -#endif - for(dim_t indx = 0; indx < n0; indx ++) - { -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_zscals(inva ,b[indx * cs_b]) -#else - - bli_zinvscals(inva ,b[indx * cs_b]) -#endif - } - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - - } - } - - const struc_t struca = BLIS_TRIANGULAR; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - - dim_t mn0_a; - - bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); - - bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); - - bli_obj_init_finish( dt, mn0_a, mn0_a, (dcomplex*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0, n0, (dcomplex*)b, rs_b, cs_b, &bo ); - - bli_obj_set_uplo( blis_uploa, &ao ); - bli_obj_set_diag( blis_diaga, &ao ); - bli_obj_set_conjtrans( blis_transa, &ao ); - - bli_obj_set_struc( struca, &ao ); - -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_ztrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=500 && n0<=500) || - (nt && (m0+n0)<128) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. 
*/ - bli_finalize_auto(); - return; - } - } -#endif - - bli_trsmnat - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) - /* Finalize BLIS. */ - bli_finalize_auto(); -} -#endif -#if 0 -void ctrsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const scomplex* alpha, - const scomplex* a, const f77_int* lda, - scomplex* b, const f77_int* ldb -) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 's', - *side, *uploa,*transa, *diaga, *m, *n, - (void*)alpha,*lda, *ldb); - - side_t blis_side; - uplo_t blis_uploa; - trans_t blis_transa; - diag_t blis_diaga; - dim_t m0, n0; - conj_t conja = BLIS_NO_CONJUGATE; - - /* Initialize BLIS. */ - bli_init_auto(); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(trsm) - ( - MKSTR(c), - MKSTR(trsm), - side, - uploa, - transa, - diaga, - m, - n, - lda, - ldb - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_side( *side, &blis_side ); - bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const num_t dt = BLIS_SCOMPLEX; - - - if( n0 == 1 ) - { - if( blis_side == BLIS_LEFT ) - { - if(bli_is_notrans(blis_transa)) - { - bli_ctrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - (scomplex*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - bli_ctrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - (scomplex*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) - { - bli_cscalv_ex - ( - conja, - m0, - (scomplex*)alpha, - (scomplex*)b, rs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - scomplex inva = {1.0, 0.0}; - scomplex a_dup; - /** - * For conjugate transpose and non-unit diagonal - * kernel, negating imaginary part of A. - * As the dimension of A is 1x1, there's going to - * be only one 1 element of A. 
- */ - if(*transa == 'C' && *diaga == 'N') - { - a_dup.real = a->real; - a_dup.imag = a->imag * -1.0; - } - else - { - a_dup.real = a->real; - a_dup.imag = a->imag; - } - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_cinvscals(a_dup, inva); -#else - inva.real = a_dup.real; - inva.imag = a_dup.imag; -#endif - - for(dim_t indx = 0; indx < m0; indx ++) - { -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_cscals(inva ,b[indx]) -#else - bli_cinvscals(inva, b[indx]) -#endif - } - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - - } - } - else if( m0 == 1 ) - { - if(blis_side == BLIS_RIGHT) - { - if(bli_is_notrans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_ctrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (scomplex*)alpha, - (scomplex*)a, cs_a, rs_a, - (scomplex*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_ctrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (scomplex*)alpha, - (scomplex*)a, cs_a, rs_a, - (scomplex*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) - { - bli_cscalv_ex - ( - conja, - n0, - (scomplex*)alpha, - (scomplex*)b, cs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - scomplex inva = {1.0, 0.0}; - scomplex a_dup; - /** - * For conjugate transpose and non-unit diagonal - * kernel, negating imaginary part of A. - * As the dimension of A is 1x1, there's going to - * be only one 1 element of A. - */ - if(*transa == 'C' && *diaga == 'N') - { - a_dup.real = a->real; - a_dup.imag = a->imag * -1.0; - } - else - { - a_dup.real = a->real; - a_dup.imag = a->imag; - } - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_cinvscals(a_dup, inva) -#else - inva.real = a_dup.real; - inva.imag = a_dup.imag; -#endif - for(dim_t indx = 0; indx < n0; indx ++) - { -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_cscals(inva ,b[indx * cs_b]) -#else - bli_cinvscals(inva, b[indx * cs_b]) -#endif - - } - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - - const struc_t struca = BLIS_TRIANGULAR; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - - dim_t mn0_a; - - bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); - - bli_obj_init_finish_1x1( dt, (scomplex*)alpha, &alphao ); - - bli_obj_init_finish( dt, mn0_a, mn0_a, (scomplex*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0, n0, (scomplex*)b, rs_b, cs_b, &bo ); - - bli_obj_set_uplo( blis_uploa, &ao ); - bli_obj_set_diag( blis_diaga, &ao ); - bli_obj_set_conjtrans( blis_transa, &ao ); - - bli_obj_set_struc( struca, &ao ); -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_ztrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=1000 && n0<=1000) || - (nt && (m0+n0)<320) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. 
*/ - bli_finalize_auto(); - return; - } - } -#endif - bli_trsmnat - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) - /* Finalize BLIS. */ - bli_finalize_auto(); -} -#endif -INSERT_GENTFUNC_BLAS_CZ( trsm, trsm ) -#else INSERT_GENTFUNC_BLAS( trsm, trsm ) #endif -#endif diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c new file mode 100644 index 0000000000..f479b5eac0 --- /dev/null +++ b/frame/compat/bla_trsm_amd.c @@ -0,0 +1,1591 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// + +#ifdef BLIS_BLAS3_CALLS_TAPI + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) \ +\ + side_t blis_side; \ + uplo_t blis_uploa; \ + trans_t blis_transa; \ + diag_t blis_diaga; \ + dim_t m0, n0; \ + inc_t rs_a, cs_a; \ + inc_t rs_b, cs_b; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + side, \ + uploa, \ + transa, \ + diaga, \ + m, \ + n, \ + lda, \ + ldb \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); \ +\ + /* Typecast BLAS integers to BLIS integers. 
*/ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + rs_a = 1; \ + cs_a = *lda; \ + rs_b = 1; \ + cs_b = *ldb; \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_side, \ + blis_uploa, \ + blis_transa, \ + blis_diaga, \ + m0, \ + n0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, cs_b, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) \ + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *side, *uploa, \ + *transa, *diaga, *m, *n, (void*)alpha, *lda, *ldb); \ + side_t blis_side; \ + uplo_t blis_uploa; \ + trans_t blis_transa; \ + diag_t blis_diaga; \ + dim_t m0, n0; \ + ftype a_conj; \ + conj_t conja = BLIS_NO_CONJUGATE ; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + side, \ + uploa, \ + transa, \ + diaga, \ + m, \ + n, \ + lda, \ + ldb \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Set the row and column strides of the matrix operands. 
*/ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* ----------------------------------------------------------- */ \ + /* TRSM API: AX = B, where X = B */ \ + /* CALL TRSV when X & B are vector and when A is Matrix */ \ + /* Case 1: LEFT : TRSM, B(mxn) = A(mxm) * X(mxn) */ \ + /* Case 2: RIGHT : TRSM, B(mxn) = X(mxn) * A(nxn) */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | | A | X | B | Implementation | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | LEFT | mxm | mxn | mxn | | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | n = 1 | mxm | mx1 | mx1 | TRSV | */ \ + /* | m = 1 | 1x1 | 1xn | 1xn | INVSCALS | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | | X | A | B | Implementation | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | RIGHT | mxn | nxn | mxn | | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | n = 1 | mx1 | 1x1 | mx1 | Transpose and INVSCALS| */ \ + /* | m = 1 | 1xn | nxn | 1xn | Transpose and TRSV | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* If Transpose(A) uplo = lower then uplo = higher */ \ + /* If Transpose(A) uplo = higher then uplo = lower */ \ + /* ----------------------------------------------------------- */ \ +\ + if( n0 == 1 ) \ + { \ + if( blis_side == BLIS_LEFT ) \ + { \ + if(bli_is_notrans(blis_transa)) \ + { \ + PASTEMAC(ch, trsv_unf_var2) \ + ( \ + blis_uploa, \ + blis_transa, \ + blis_diaga, \ + m0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, \ + NULL \ + ); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + else if(bli_is_trans(blis_transa)) \ + { \ + PASTEMAC(ch, trsv_unf_var1) \ + ( \ + blis_uploa, \ + blis_transa, \ + blis_diaga, \ + m0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, \ + NULL \ + ); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + } \ + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) \ + { \ + /* b = alpha * b; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + conja, \ + m0, \ + (ftype*)alpha, \ + b, rs_b, \ + NULL, \ + NULL \ + ); \ + if(blis_diaga == BLIS_NONUNIT_DIAG) \ + { \ + conja = bli_extract_conj( blis_transa ); \ + PASTEMAC(ch,copycjs)( conja, *a, a_conj ); \ + for(int indx = 0; indx < m0; indx ++) \ + { \ + PASTEMAC(ch,invscals)( a_conj, b[indx] ); \ + } \ + }\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + } \ + else if( m0 == 1 ) \ + { \ + if(blis_side == BLIS_RIGHT) \ + { \ + if(bli_is_notrans(blis_transa)) \ + { \ + if(blis_uploa == BLIS_UPPER) \ + blis_uploa = BLIS_LOWER; \ + else \ + blis_uploa = BLIS_UPPER; \ + PASTEMAC(ch, trsv_unf_var1)( \ + blis_uploa, \ + blis_transa, \ + blis_diaga, \ + n0, \ + (ftype*)alpha, \ + (ftype*)a, cs_a, rs_a, \ + (ftype*)b, cs_b, \ + NULL); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + else if(bli_is_trans(blis_transa)) \ + { \ + if(blis_uploa == BLIS_UPPER) \ + blis_uploa = BLIS_LOWER; \ + else \ + blis_uploa = BLIS_UPPER; \ + PASTEMAC(ch, trsv_unf_var2)( \ + blis_uploa, \ + blis_transa, \ + blis_diaga, \ + n0, \ + (ftype*)alpha, \ + (ftype*)a, cs_a, rs_a, \ + (ftype*)b, cs_b, \ + NULL); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + 
} \ + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) \ + { \ + /* b = alpha * b; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + conja, \ + n0, \ + (ftype*)alpha, \ + b, cs_b, \ + NULL, \ + NULL \ + ); \ + if(blis_diaga == BLIS_NONUNIT_DIAG) \ + { \ + conja = bli_extract_conj( blis_transa ); \ + PASTEMAC(ch,copycjs)( conja, *a, a_conj ); \ + for(int indx = 0; indx < n0; indx ++) \ + { \ + PASTEMAC(ch,invscals)( a_conj, b[indx*cs_b] ); \ + }\ + } \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + } \ +\ + const struc_t struca = BLIS_TRIANGULAR; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn0_a; \ +\ + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ +\ + bli_obj_init_finish( dt, mn0_a, mn0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)b, rs_b, cs_b, &bo ); \ +\ + bli_obj_set_uplo( blis_uploa, &ao ); \ + bli_obj_set_diag( blis_diaga, &ao ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ +\ + bli_obj_set_struc( struca, &ao ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + blis_side, \ + &alphao, \ + &ao, \ + &bo, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + +#ifdef BLIS_ENABLE_BLAS + +void strsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const float* alpha, + const float* a, const f77_int* lda, + float* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 's', + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE ; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(s), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. 
*/ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_FLOAT; + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_strsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (float*)alpha, + (float*)a, rs_a, cs_a, + (float*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_strsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (float*)alpha, + (float*)a, rs_a, cs_a, + (float*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + /* b = alpha * b; */ + bli_sscalv_ex + ( + conja, + m0, + (float*)alpha, + b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + float inva = 1.0/ *a; + for(dim_t indx = 0; indx < m0; indx ++) + { + b[indx] = ( inva * b[indx] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_strsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (float*)alpha, + (float*)a, cs_a, rs_a, + (float*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_strsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (float*)alpha, + (float*)a, cs_a, rs_a, + (float*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + /* b = alpha * b; */ + bli_sscalv_ex + ( + conja, + n0, + (float*)alpha, + b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + float inva = 1.0/ *a; + for(dim_t indx = 0; indx < n0; indx ++) + { + b[indx*cs_b] = (inva * b[indx*cs_b] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (float*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (float*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (float*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* bli_strsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. 
+ * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=1000 && n0<=1000) || + (nt && (m0+n0)<320) ) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } + } +#endif + + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. */ + bli_finalize_auto(); +} + +void dtrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const double* alpha, + const double* a, const f77_int* lda, + double* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd', + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE ; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(d), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. 
*/ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_DOUBLE; + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_dtrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_dtrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + /* b = alpha * b; */ + bli_dscalv_ex + ( + conja, + m0, + (double*)alpha, + b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + double inva = 1.0/ *a; + for(dim_t indx = 0; indx < m0; indx ++) + { + b[indx] = ( inva * b[indx] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_dtrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (double*)alpha, + (double*)a, cs_a, rs_a, + (double*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_dtrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (double*)alpha, + (double*)a, cs_a, rs_a, + (double*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + /* b = alpha * b; */ + bli_dscalv_ex + ( + conja, + n0, + (double*)alpha, + b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + double inva = 1.0/ *a; + for(dim_t indx = 0; indx < n0; indx ++) + { + b[indx*cs_b] = (inva * b[indx*cs_b] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (double*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (double*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (double*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* bli_dtrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. 
+ * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if ((nt == 0 && m0 <= 1000 && n0 <= 1000) || + (nt && (m0 + n0) < 320)) + { + err_t status; + status = bli_trsm_small( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } + + // bli_trsm_small_mt is performing better than native multithread + // for certain sizes of m & n. +#ifdef BLIS_ENABLE_OPENMP + rntm_t rntm; + bli_rntm_init_from_global( &rntm ); + + // Query the total number of threads from the rntm_t object. + dim_t n_threads = bli_rntm_num_threads( &rntm ); + if ( ( (n_threads > 1) && (m0 <= 1500) && (n0 <= 1500) ) || + ( (n_threads == 32) && (m0 <= 2300) && (n0 <= 2300) ) || + ( (n_threads == 16) && (m0 <= 3800) && (n0 <= 3800) ) || + ( (n_threads == 8) && (m0 <= 2800) && (n0 <= 2800) ) || + ( (n_threads == 4) && (m0 <= 2000) && (n0 <= 2000) ) || + ( (n_threads == 2) && (m0 <= 2000) && (n0 <= 2000) ) ) + { + err_t status; + status = bli_trsm_small_mt( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL); + + if ( status == BLIS_SUCCESS ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } +#endif// BLIS_ENABLE_OPENMP + } // bli_cpuid_is_avx_supported +#endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM + + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. */ + bli_finalize_auto(); +} + + +void ztrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const dcomplex* alpha, + const dcomplex* a, const f77_int* lda, + dcomplex* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'z', + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(z), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. 
*/ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_DCOMPLEX; + + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_ztrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + (dcomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_ztrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + (dcomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + bli_zscalv_ex + ( + conja, + m0, + (dcomplex*)alpha, + (dcomplex*)b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + dcomplex inva = {1.0, 0.0}; + dcomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. + */ + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zinvscals(a_dup, inva); +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + for(dim_t indx = 0; indx < m0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zscals(inva, b[indx]) +#else + + bli_zinvscals(inva, b[indx]) +#endif + } + + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ztrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (dcomplex*)alpha, + (dcomplex*)a, cs_a, rs_a, + (dcomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ztrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (dcomplex*)alpha, + (dcomplex*)a, cs_a, rs_a, + (dcomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + bli_zscalv_ex + ( + conja, + n0, + (dcomplex*)alpha, + (dcomplex*)b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + dcomplex inva = {1.0, 0.0}; + dcomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. 
+ */ + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zinvscals(a_dup, inva); +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + for(dim_t indx = 0; indx < n0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zscals(inva ,b[indx * cs_b]) +#else + + bli_zinvscals(inva ,b[indx * cs_b]) +#endif + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + + } + } + + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (dcomplex*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* bli_ztrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + + if(((nt==0) && (m0<=500) && (n0<=500)) || + (nt && ((m0+n0)<128))) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } + } // bli_cpuid_is_avx_supported} +#endif + + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. */ + bli_finalize_auto(); +} + + +void ctrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const scomplex* alpha, + const scomplex* a, const f77_int* lda, + scomplex* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'c', + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(c), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. 
*/ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_SCOMPLEX; + + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_ctrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + (scomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_ctrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + (scomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + bli_cscalv_ex + ( + conja, + m0, + (scomplex*)alpha, + (scomplex*)b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + scomplex inva = {1.0, 0.0}; + scomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. + */ + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cinvscals(a_dup, inva); +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + + for(dim_t indx = 0; indx < m0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cscals(inva ,b[indx]) +#else + bli_cinvscals(inva, b[indx]) +#endif + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ctrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (scomplex*)alpha, + (scomplex*)a, cs_a, rs_a, + (scomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ctrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (scomplex*)alpha, + (scomplex*)a, cs_a, rs_a, + (scomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + bli_cscalv_ex + ( + conja, + n0, + (scomplex*)alpha, + (scomplex*)b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + scomplex inva = {1.0, 0.0}; + scomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. 
+ */ + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cinvscals(a_dup, inva) +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + for(dim_t indx = 0; indx < n0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cscals(inva ,b[indx * cs_b]) +#else + bli_cinvscals(inva, b[indx * cs_b]) +#endif + + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (scomplex*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (scomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (scomplex*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* bli_ztrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=1000 && n0<=1000) || + (nt && (m0+n0)<320) ) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } + } // bli_cpuid_is_avx_supported +#endif + + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. */ + bli_finalize_auto(); +} + +#endif diff --git a/frame/include/bli_config_macro_defs.h b/frame/include/bli_config_macro_defs.h index d00df2f0be..c9e597c9a6 100644 --- a/frame/include/bli_config_macro_defs.h +++ b/frame/include/bli_config_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -260,5 +260,11 @@ #endif +#ifdef BLIS_OS_WINDOWS + #define BLIS_TLS_TYPE __declspec(thread) +#else + #define BLIS_TLS_TYPE __thread +#endif + #endif diff --git a/frame/include/bli_gentfunc_macro_defs.h b/frame/include/bli_gentfunc_macro_defs.h index 1bac7aa7c4..49c79cb8ae 100644 --- a/frame/include/bli_gentfunc_macro_defs.h +++ b/frame/include/bli_gentfunc_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -67,6 +67,10 @@ GENTFUNC( scomplex, c, blasname, blisname ) GENTFUNC( scomplex, c, blasname, blisname ) \ GENTFUNC( dcomplex, z, blasname, blisname ) +#define INSERT_GENTFUNC_BLAS_C( blasname, blisname ) \ +\ +GENTFUNC( scomplex, c, blasname, blisname ) + // -- Basic one-operand macro with real domain only -- diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index 770f5c5378..9d45aec1ab 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -89,10 +89,14 @@ typedef unsigned long int guint_t; // -- Boolean type -- // NOTE: bool_t is no longer used and has been replaced with C99's bool type. +// Not defining the bool type for C++ code in windows platform to avoid +// duplicate definition build error. #ifdef _WIN32 +#ifndef __cplusplus #undef bool typedef gint_t bool; #endif +#endif // BLIS uses TRUE and FALSE macro constants as possible boolean values, but we // define these macros in terms of true and false, respectively, which are // defined by C99 in stdbool.h. @@ -931,10 +935,11 @@ typedef enum BLIS_TRMM, BLIS_TRSM, BLIS_GEMMT, + BLIS_GEMM_MD, BLIS_NOID } opid_t; -#define BLIS_NUM_LEVEL3_OPS 11 +#define BLIS_NUM_LEVEL3_OPS 12 // -- Blocksize ID type -- diff --git a/frame/thread/bli_l3_decor_openmp.c b/frame/thread/bli_l3_decor_openmp.c index 0bf3ad8547..b01c208a30 100644 --- a/frame/thread/bli_l3_decor_openmp.c +++ b/frame/thread/bli_l3_decor_openmp.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -140,6 +140,17 @@ void bli_l3_thread_decorator bli_l3_thrinfo_create_root( tid, gl_comm, rntm_p, cntl_use, &thread ); #if 1 + // Reset the progress state to 0 as we are starting new operations. + // This counter track running progress in current thread. + tls_aoclprogress_counter = 0; + + // We send the update only after certain threshold is reached, + // The thresold is defined as AOCL_PROGRESS_FREQUENCY. + // This variable stores the counter value when last update was sent. + // It is compared with current counter value to see if it is time to + // send the next update. + tls_aoclprogress_last_update = 0; + func ( alpha, diff --git a/frame/thread/bli_l3_decor_single.c b/frame/thread/bli_l3_decor_single.c index 12f27ad873..444583e73e 100644 --- a/frame/thread/bli_l3_decor_single.c +++ b/frame/thread/bli_l3_decor_single.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -115,7 +115,18 @@ void bli_l3_thread_decorator // Create the root node of the thread's thrinfo_t structure. bli_l3_thrinfo_create_root( tid, gl_comm, rntm_p, cntl_use, &thread ); - + + // Reset the progress state to 0 as we are starting new operations. + // This counter track running progress in current thread. + tls_aoclprogress_counter = 0; + + // We send the update only after certain threshold is reached, + // The thresold is defined as AOCL_PROGRESS_FREQUENCY. + // This variable stores the counter value when last update was sent. + // It is compared with current counter value to see if it is time to + // send the next update. + tls_aoclprogress_last_update = 0; + func ( alpha, diff --git a/frame/thread/bli_thread.c b/frame/thread/bli_thread.c index 159a9e802e..097d136e7e 100644 --- a/frame/thread/bli_thread.c +++ b/frame/thread/bli_thread.c @@ -1604,11 +1604,23 @@ void bli_thread_set_num_threads( dim_t n_threads ) // We must ensure that global_rntm has been initialized. bli_init_once(); + if ( n_threads <= 0 ) + { + n_threads = 1; + } + // Acquire the mutex protecting global_rntm. bli_pthread_mutex_lock( &global_rntm_mutex ); bli_rntm_set_num_threads_only( n_threads, &global_rntm ); +#ifdef BLIS_ENABLE_OPENMP + // In the function bli_rntm_init_from_global() we extract n_threads + // using the API omp_get_max_threads(). Following step ensures that + // omp_get_max_threads returns the same value as set here. + omp_set_num_threads( n_threads ); +#endif + // Release the mutex protecting global_rntm. bli_pthread_mutex_unlock( &global_rntm_mutex ); } @@ -1633,20 +1645,46 @@ void bli_thread_init_rntm_from_env // Try to read BLIS_NUM_THREADS first. nt = bli_env_get_var( "BLIS_NUM_THREADS", -1 ); - // If BLIS_NUM_THREADS was not set, try to read OMP_NUM_THREADS. - if ( nt == -1 ) - nt = bli_env_get_var( "OMP_NUM_THREADS", -1 ); #ifdef BLIS_ENABLE_OPENMP - // If both environment variables are not set - - // number of threads can also be set by the application by calling omp_set_num_threads(nt) - // The next parallel region when encountered will run with number of threads set by the above API. - // We can know about the number of threads by using the API "omp_get_max_threads()" - if (nt == -1) nt = omp_get_max_threads(); - // If application is multithreaded and number of threads is set using omp_set_num_threads(nt) + + // Scenarios: + // 1. If BLIS_NUM_THREADS is set with valid value, set the nt using omp_set_num_threads(nt) + // so that this value can be fetched inside BLIS API as well. + // 2. If BLIS_NUM_THREADS is not set, then if Application is multithreaded and issued + // omp_set_num_threads(nt) with desired number of threads, + // omp_get_max_threads() API will fetch the number of threads set earlier. + // 3. If BLIS_NUM_THREADS is not set, omp_set_num_threads(nt) is not called by the application, + // but only OMP_NUM_THREADS is set, + // omp_get_max_threads() API will fetch the value of OMP_NUM_THREADS. + // 4. If both environment variables are not set, or if they are set with invalid values, and + // omp_set_num_threads(nt) is not issued by application, + // omp_get_max_threads() API will return the number of the cores in the current context. + // // BLIS will rntm->num_threads will also get initialized with the same value. // However if omp_set_nested is false - BLIS APIs called from parallel threads will run in sequential. 
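A minimal usage sketch of bli_thread_set_num_threads() as modified earlier in this hunk (editor's illustration, not part of the patch; it assumes an OpenMP build with BLIS_NUM_THREADS and OMP_NUM_THREADS unset):

    #include "blis.h"

    int main( void )
    {
        // global_rntm is updated and omp_set_num_threads( 8 ) is issued,
        // so a later omp_get_max_threads() inside BLIS also reports 8.
        bli_thread_set_num_threads( 8 );

        // A non-positive request is clamped to 1 by the new guard.
        bli_thread_set_num_threads( 0 );

        return 0;
    }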
// But if nested parallelism is enabled - Then each application will launch MT BLIS. + // + // Order of precedence used for number of threads: + // 1. valid value set for BLIS_NUM_THREADS environment variable + // 2. omp_set_num_threads(nt) issued by the application + // 3. valid value set for OMP_NUM_THREADS environment variable + // 4. Number of cores + // + // Note: If nt is not a valid value for omp_set_num_threads(nt) API, number of threads would be set to 1. + // omp_get_max_threads() API will return 1. + // + // OMP_NUM_THREADS environment variable is applicable only when OpenMP is enabled. + + if(nt > 0) + { + omp_set_num_threads(nt); + } + else + { + nt = omp_get_max_threads(); + } + #endif // Read the environment variables for the number of threads (ways // of parallelism) for each individual loop. diff --git a/frame/thread/bli_thrinfo_sup.c b/frame/thread/bli_thrinfo_sup.c index e67e8b6426..8ce714547c 100644 --- a/frame/thread/bli_thrinfo_sup.c +++ b/frame/thread/bli_thrinfo_sup.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -167,8 +167,23 @@ thrinfo_t* bli_thrinfo_sup_create_for_cntl thrcomm_t* static_comms[ BLIS_NUM_STATIC_COMMS ]; thrcomm_t** new_comms = NULL; + + const bool packa = bli_rntm_pack_a( rntm ); + const bool packb = bli_rntm_pack_b( rntm ); + dim_t parent_nt_in = 0; + + // thrinfo ocomm is not created when neither packa nor packb is + // enabled. Need to derive parent_nt_in without depending on ocomm in + // those cases. + if ( packa || packb ) + { + parent_nt_in = bli_thread_num_threads( thread_par ); + } + else + { + parent_nt_in = bli_rntm_calc_num_threads_in( bszid_par, rntm ); + } - const dim_t parent_nt_in = bli_thread_num_threads( thread_par ); const dim_t parent_n_way = bli_thread_n_way( thread_par ); const dim_t parent_comm_id = bli_thread_ocomm_id( thread_par ); const dim_t parent_work_id = bli_thread_work_id( thread_par ); @@ -193,50 +208,75 @@ thrinfo_t* bli_thrinfo_sup_create_for_cntl //printf( "thread %d: child_n_way = %d child_nt_in = %d parent_n_way = %d (bszid = %d->%d)\n", (int)child_comm_id, (int)child_nt_in, (int)child_n_way, (int)parent_n_way, (int)bli_cntl_bszid( cntl_par ), (int)bszid_chl ); - // The parent's chief thread creates a temporary array of thrcomm_t - // pointers. - if ( bli_thread_am_ochief( thread_par ) ) + thrinfo_t* thread_chl = NULL; + + // The communicators are only used when either packa or packb is + // enabled. This means that the communicator creation along with the + // overhead from the barriers (required for synchronizing comm across + // threads) are not required when both packa and packb are disabled. + if ( packa || packb ) { - if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) - new_comms = bli_malloc_intl( parent_n_way * sizeof( thrcomm_t* ) ); - else - new_comms = static_comms; - } + // The parent's chief thread creates a temporary array of thrcomm_t + // pointers. + if ( bli_thread_am_ochief( thread_par ) ) + { + if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) + new_comms = bli_malloc_intl( parent_n_way * sizeof( thrcomm_t* ) ); + else + new_comms = static_comms; + } + + // Broadcast the temporary array to all threads in the parent's + // communicator. 
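The net effect of the gating introduced in this hunk can be summarized by the following condensed sketch (editor's paraphrase using only identifiers that appear in the hunk; the barriers and the static-comm fallback are elided):

    if ( packa || packb )
    {
        // The chief thread allocates the thrcomm_t* array, it is broadcast to
        // all threads, child chiefs create their communicator, and the node is
        // created with free_comm = TRUE; barriers synchronize these steps.
        thread_chl = bli_thrinfo_create( rntm, new_comms[ parent_work_id ],
                                         child_comm_id, child_n_way,
                                         child_work_id, TRUE, *bszid_chl, NULL );
    }
    else
    {
        // Neither A nor B is packed: no communicator and no barriers are
        // needed, so the node gets a NULL ocomm and free_comm = FALSE.
        thread_chl = bli_thrinfo_create( rntm, NULL, child_comm_id, child_n_way,
                                         child_work_id, FALSE, *bszid_chl, NULL );
    }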
+ new_comms = bli_thread_broadcast( thread_par, new_comms ); + + // Chiefs in the child communicator allocate the communicator + // object and store it in the array element corresponding to the + // parent's work id. + if ( child_comm_id == 0 ) + new_comms[ parent_work_id ] = bli_thrcomm_create( rntm, child_nt_in ); + + bli_thread_barrier( thread_par ); + + // All threads create a new thrinfo_t node using the communicator + // that was created by their chief, as identified by parent_work_id. + thread_chl = bli_thrinfo_create + ( + rntm, // rntm + new_comms[ parent_work_id ], // ocomm + child_comm_id, // ocomm_id + child_n_way, // n_way + child_work_id, // work_id + TRUE, // free_comm + *bszid_chl, // bszid + NULL // sub_node + ); + + bli_thread_barrier( thread_par ); - // Broadcast the temporary array to all threads in the parent's - // communicator. - new_comms = bli_thread_broadcast( thread_par, new_comms ); - - // Chiefs in the child communicator allocate the communicator - // object and store it in the array element corresponding to the - // parent's work id. - if ( child_comm_id == 0 ) - new_comms[ parent_work_id ] = bli_thrcomm_create( rntm, child_nt_in ); - - bli_thread_barrier( thread_par ); - - // All threads create a new thrinfo_t node using the communicator - // that was created by their chief, as identified by parent_work_id. - thrinfo_t* thread_chl = bli_thrinfo_create - ( - rntm, // rntm - new_comms[ parent_work_id ], // ocomm - child_comm_id, // ocomm_id - child_n_way, // n_way - child_work_id, // work_id - TRUE, // free_comm - *bszid_chl, // bszid - NULL // sub_node - ); - - bli_thread_barrier( thread_par ); - - // The parent's chief thread frees the temporary array of thrcomm_t - // pointers. - if ( bli_thread_am_ochief( thread_par ) ) + // The parent's chief thread frees the temporary array of thrcomm_t + // pointers. + if ( bli_thread_am_ochief( thread_par ) ) + { + if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) + bli_free_intl( new_comms ); + } + } + else { - if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) - bli_free_intl( new_comms ); + // No communicator is reqiured in cases where neither packa nor + // packb is enabled. + thread_chl = bli_thrinfo_create + ( + rntm, // rntm + NULL, // ocomm + child_comm_id, // ocomm_id + child_n_way, // n_way + child_work_id, // work_id + FALSE, // free_comm + *bszid_chl, // bszid + NULL // sub_node + ); } return thread_chl; diff --git a/frame/util/CMakeLists.txt b/frame/util/CMakeLists.txt index c20d7c525d..13fd53fc52 100644 --- a/frame/util/CMakeLists.txt +++ b/frame/util/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -13,4 +13,5 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_unb_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_update.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_api_wrap.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_progress.c ) diff --git a/frame/util/bli_util.h b/frame/util/bli_util.h index 3c4e5722af..f7be273526 100644 --- a/frame/util/bli_util.h +++ b/frame/util/bli_util.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -63,3 +63,6 @@ // Header file define different formats of BLAS APIs- uppercase with // and without underscore, lowercase without underscore. #include "bli_util_api_wrap.h" + +// Public interface for the progress feature +#include "bli_util_progress.h" \ No newline at end of file diff --git a/frame/util/bli_util_api_wrap.c b/frame/util/bli_util_api_wrap.c index 128fba8b87..81300761fb 100644 --- a/frame/util/bli_util_api_wrap.c +++ b/frame/util/bli_util_api_wrap.c @@ -39,7 +39,8 @@ #include "bli_util_api_wrap.h" // wrapper functions to support additional symbols -#ifdef BLIS_ENABLE_API_WRAPPER +#ifndef BLIS_ENABLE_NO_UNDERSCORE_API +#ifndef BLIS_ENABLE_UPPERCASE_API void CAXPY(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { caxpy_( n, ca, cx, incx, cy, incy); @@ -3221,3 +3222,4 @@ void CAXPBY_( const f77_int* n, const scomplex* alpha, const scomplex *x, con } #endif +#endif diff --git a/frame/util/bli_util_api_wrap.h b/frame/util/bli_util_api_wrap.h index f0aff49ff2..78f088e28e 100644 --- a/frame/util/bli_util_api_wrap.h +++ b/frame/util/bli_util_api_wrap.h @@ -35,7 +35,8 @@ // file define different formats of BLAS APIs- uppercase with // and without underscore, lowercase without underscore. -#ifdef BLIS_ENABLE_API_WRAPPER +#ifndef BLIS_ENABLE_NO_UNDERSCORE_API +#ifndef BLIS_ENABLE_UPPERCASE_API //Level 1 APIs BLIS_EXPORT_BLIS void SROTG(float *sa, float *sb, float *c, float *s); @@ -1729,3 +1730,4 @@ BLIS_EXPORT_BLIS void ZOMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols #endif +#endif diff --git a/frame/util/bli_util_progress.c b/frame/util/bli_util_progress.c new file mode 100644 index 0000000000..4097eb1126 --- /dev/null +++ b/frame/util/bli_util_progress.c @@ -0,0 +1,56 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include "blis.h"
+
+// The progress feature periodically updates the user with the current state
+// of the operation. Progress is maintained for each thread separately; the
+// following variable stores the number of elements processed so far in the
+// current thread, using thread-local storage.
+BLIS_TLS_TYPE dim_t tls_aoclprogress_counter;
+
+// Stores the counter value at which the last update was sent; it is used to
+// implement the update frequency.
+BLIS_TLS_TYPE dim_t tls_aoclprogress_last_update;
+
+
+// AOCL_progress_ptr holds the pointer to the callback function.
+// By default it is set to NULL, which effectively disables the
+// progress feature.
+AOCL_progress_callback AOCL_progress_ptr = NULL;
+
+void AOCL_BLIS_set_progress(AOCL_progress_callback func)
+{
+    AOCL_progress_ptr = func;
+}
\ No newline at end of file
diff --git a/frame/util/bli_util_progress.h b/frame/util/bli_util_progress.h
new file mode 100644
index 0000000000..0e2a63eb1c
--- /dev/null
+++ b/frame/util/bli_util_progress.h
@@ -0,0 +1,74 @@
+/*
+
+   BLIS
+   An object-based framework for developing high-performance BLAS-like
+   libraries.
+
+   Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#ifndef BLI_UTIL_PROGRESS_H
+#define BLI_UTIL_PROGRESS_H
+
+// Public interface for the end user. 
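A usage sketch of this public interface (the callback type and the prototype of AOCL_BLIS_set_progress() follow just below): an application registers a callback once and then receives periodic updates from BLIS. This is an editor's illustration only; the callback body and its return value are arbitrary, since the patch does not define a return-value convention.

    #include <stdio.h>
    #include "blis.h"

    // 'api' holds the operation name of length 'lapi'; 'progress' is the number
    // of elements processed so far by thread 'current_thread' of 'total_threads'.
    dim_t my_progress( char* api, dim_t lapi, dim_t progress,
                       dim_t current_thread, dim_t total_threads )
    {
        printf( "%.*s: %lld elements done (thread %lld of %lld)\n",
                ( int )lapi, api, ( long long )progress,
                ( long long )current_thread, ( long long )total_threads );
        return 0;
    }

    int main( void )
    {
        AOCL_BLIS_set_progress( my_progress );
        // ... subsequent large BLIS operations report progress via my_progress ...
        return 0;
    }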
+
+typedef dim_t (*AOCL_progress_callback)(char *api,
+                                        dim_t lapi,
+                                        dim_t progress,
+                                        dim_t current_thread,
+                                        dim_t total_threads);
+
+BLIS_EXPORT_BLIS void AOCL_BLIS_set_progress(AOCL_progress_callback func);
+
+// Private interfaces for internal use
+
+extern AOCL_progress_callback AOCL_progress_ptr;
+
+extern BLIS_TLS_TYPE dim_t tls_aoclprogress_counter;
+extern BLIS_TLS_TYPE dim_t tls_aoclprogress_last_update;
+
+// Define the frequency of reporting (number of elements).
+// A progress update is sent only after this many elements
+// have been processed in the current thread.
+#define AOCL_PROGRESS_FREQUENCY 1e+9
+
+#define MAX_API_NAME_LEN 20
+
+// Macro to send an update using the datatype character and the api name.
+#define AOCL_PROGRESS_DT(dt, api, progress, tid, nt) \
+    char buf[MAX_API_NAME_LEN]; \
+    snprintf(buf, MAX_API_NAME_LEN, "%c%s", dt, api); \
+    (*AOCL_progress_ptr) (buf, strlen(buf), progress, tid, nt);
+
+// Macro to send an update using the api name alone.
+#define AOCL_PROGRESS_NAME(api, progress, tid, nt) \
+    char buf[MAX_API_NAME_LEN]; \
+    snprintf(buf, MAX_API_NAME_LEN, "%s", api); \
+    (*AOCL_progress_ptr) (buf, strlen(buf), progress, tid, nt);
+
+#endif // BLI_UTIL_PROGRESS_H
diff --git a/kernels/zen/1/CMakeLists.txt b/kernels/zen/1/CMakeLists.txt
index 669a3ba89a..434be490d5 100644
--- a/kernels/zen/1/CMakeLists.txt
+++ b/kernels/zen/1/CMakeLists.txt
@@ -3,6 +3,8 @@
 target_sources("${PROJECT_NAME}"
      PRIVATE
 ${CMAKE_CURRENT_SOURCE_DIR}/bli_amaxv_zen_int.c
+${CMAKE_CURRENT_SOURCE_DIR}/bli_axpbyv_zen_int.c
+${CMAKE_CURRENT_SOURCE_DIR}/bli_axpbyv_zen_int10.c
 ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyv_zen_int.c
 ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyv_zen_int10.c
 ${CMAKE_CURRENT_SOURCE_DIR}/bli_copyv_zen_int.c
diff --git a/kernels/zen/1/bli_axpbyv_zen_int.c b/kernels/zen/1/bli_axpbyv_zen_int.c
new file mode 100644
index 0000000000..05ef96175a
--- /dev/null
+++ b/kernels/zen/1/bli_axpbyv_zen_int.c
@@ -0,0 +1,1099 @@
+/*
+
+   BLIS
+   An object-based framework for developing high-performance BLAS-like
+   libraries.
+
+   Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +/* Union DS to access AVX registers */ +/* One 256-bit AVX register holds 8 SP elements */ +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + +/* One 256-bit AVX register holds 4 DP elements */ +typedef union +{ + __m256d v; + double d[4] __attribute__((aligned(64))); +} v4df_t; + +/** + * saxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are single precision vectors of length n. + * alpha & beta are scalers. + */ +void bli_saxpbyv_zen_int + ( + conj_t conjx, + dim_t n, + float* restrict alpha, + float* restrict x, inc_t incx, + float* restrict beta, + float* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 8; // number of elements per register + const dim_t n_iter_unroll = 4; // num of registers per iteration + + dim_t i; // iterator + + float* restrict x0; + float* restrict y0; + + v8sf_t alphav; + v8sf_t betav; + v8sf_t y0v, y1v, y2v, y3v; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) + return; + + // initialize local pointers + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + // broadcast alpha & beta to all elements of respective vector registers + alphav.v = _mm256_broadcast_ss( alpha ); + betav.v = _mm256_broadcast_ss( beta ); + + // unrolling and vectorizing + for ( i = 0; ( i + 31 ) < n; i += 32 ) + { + // loading input y + y0v.v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_ps( betav.v, y0v.v ); + y1v.v = _mm256_mul_ps( betav.v, y1v.v ); + y2v.v = _mm256_mul_ps( betav.v, y2v.v ); + y3v.v = _mm256_mul_ps( betav.v, y3v.v ); + + // y := y' + alpha * x + y0v.v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + y0v.v + ); + y1v.v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + y1v.v + ); + y2v.v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + y2v.v + ); + y3v.v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + y3v.v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), y0v.v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), y1v.v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), y2v.v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), y3v.v ); + + x0 += n_elem_per_reg * n_iter_unroll; + y0 += n_elem_per_reg * n_iter_unroll; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. 
+ // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + // if there are leftover iterations, perform them with scaler code + for ( ; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + else + { + // for non-unit increments, use scaler code + for ( i = 0; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} + +/** + * daxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are double precision vectors of length n. + * alpha & beta are scalers. + */ +void bli_daxpbyv_zen_int + ( + conj_t conjx, + dim_t n, + double* restrict alpha, + double* restrict x, inc_t incx, + double* restrict beta, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 4; // number of elements per register + const dim_t n_iter_unroll = 4; // number of registers per iteration + + dim_t i; // iterator + + double* restrict x0; + double* restrict y0; + + v4df_t alphav; + v4df_t betav; + v4df_t y0v, y1v, y2v, y3v; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // initialize local pointers + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + // broadcast alpha & beta to all elements of respective vector registers + alphav.v = _mm256_broadcast_sd( alpha ); + betav.v = _mm256_broadcast_sd( beta ); + + // unrolling and vectorizing + for ( i = 0; ( i + 15 ) < n; i += 16 ) + { + // loading input y + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_pd( betav.v, y0v.v ); + y1v.v = _mm256_mul_pd( betav.v, y1v.v ); + y2v.v = _mm256_mul_pd( betav.v, y2v.v ); + y3v.v = _mm256_mul_pd( betav.v, y3v.v ); + + // y := y' + alpha * x + // := beta * y + alpha * x + y0v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + y0v.v + ); + y1v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + y1v.v + ); + y2v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), + y2v.v + ); + y3v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), + y3v.v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), y2v.v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), y3v.v ); + + x0 += n_elem_per_reg * n_iter_unroll; + y0 += n_elem_per_reg * n_iter_unroll; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). 
+ _mm256_zeroupper(); + + // if there are leftover iterations, perform them with scaler code + for ( ; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + else + { + // for non-unit increments, use scaler code + for ( i = 0; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } +} + +/** + * caxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are simple complex vectors of length n. + * alpha & beta are scalers. + */ +void bli_caxpbyv_zen_int + ( + conj_t conjx, + dim_t n, + scomplex* restrict alpha, + scomplex* restrict x, inc_t incx, + scomplex* restrict beta, + scomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 8; // number of elements per register + + dim_t i; // iterator + + float* restrict x0; + float* restrict y0; + + float alphaR, alphaI, betaR, betaI; + + __m256 alphaRv; + __m256 alphaIv; + __m256 betaRv; + __m256 betaIv; + __m256 xv[4]; + __m256 yv[4]; + __m256 iv[4]; // intermediate registers + + conj_t conjx_use = conjx; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( c, eq0 )( *alpha ) && PASTEMAC( c, eq0 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // initialize local pointers + x0 = ( float* ) x; + y0 = ( float* ) y; + + alphaR = alpha->real; + alphaI = alpha->imag; + betaR = beta->real; + betaI = beta->imag; + + if ( incx == 1 && incy == 1 ) + { + //---------- Scalar algorithm BLIS_NO_CONJUGATE ------------- + // y = beta*y + alpha*x + // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR + ixI ) + // y = bR.yR + ibR.yI + ibI.yR - ibIyI + aR.xR + iaR.xI + iaI.xR - aI.xI + // y = ( bR.yR - bI.yI + aR.xR - aI.xI ) + + // i ( bR.yI + bI.yR + aR.xI + aI.xR ) + + // SIMD Algorithm BLIS_NO_CONJUGATE + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + // arv = aR aR aR aR aR aR aR aR + // aiv = -aI aI -aI aI -aI aI -aI aI + // brv = bR bR bR bR bR bR bR bR + // biv = -bI bI -bI bI -bI bI -bI bI + + // step 1: iv = brv * iv + // step 2: shuffle yv -> yv' + // step 3: FMA yv = biv * yv' + iv + // step 4: iv = arv * xv + // step 5: shuffle xv -> xv' + // step 6: FMA yv = aiv * xv' + iv + + //---------- Scalar algorithm BLIS_CONJUGATE ------------- + // y = beta*y + alpha*conj(x) + // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR - ixI ) + // y = bR.yR + ibR.yI + ibI.yR - bI.yI + aR.xR - iaR.xI + iaI.xR + aI.xI + // y = ( bR.yR - bI.yI + aR.xR + aI.xI ) + + // i ( bR.yI + bI.yR - aR.xI + aI.xR ) + + // SIMD Algorithm BLIS_CONJUGATE + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + // arv = aR -aR aR -aR aR -aR aR -aR + // aiv = aI aI aI aI aI aI aI aI + // brv = bR bR bR bR bR bR bR bR + // biv = -bI bI -bI bI -bI bI -bI bI + // + // step 1: iv = brv * iv + // step 2: shuffle yv -> yv' + // step 3: FMA yv = biv * yv' + iv + // step 4: iv = arv * xv + // step 5: shuffle xv -> xv' + // step 6: FMA yv = aiv * xv' + iv + + // broadcast alpha & beta to all elements of respective vector registers + if ( !bli_is_conj( conjx ) ) // If BLIS_NO_CONJUGATE + { + // alphaRv = aR aR 
aR aR aR aR aR aR + // alphaIv = -aI aI -aI aI -aI aI -aI aI + // betaRv = bR bR bR bR bR bR bR bR + // betaIv = -bI bI -bI bI -bI bI -bI bI + alphaRv = _mm256_broadcast_ss( &alphaR ); + alphaIv = _mm256_set_ps + ( + alphaI, -alphaI, alphaI, -alphaI, + alphaI, -alphaI, alphaI, -alphaI + ); + betaRv = _mm256_broadcast_ss( &betaR ); + betaIv = _mm256_set_ps + ( + betaI, -betaI, betaI, -betaI, + betaI, -betaI, betaI, -betaI + ); + } + else + { + // alphaRv = aR -aR aR -aR aR -aR aR -aR + // alphaIv = aI aI aI aI aI aI aI aI + // betaRv = bR bR bR bR bR bR bR bR + // betaIv = -bI bI -bI bI -bI bI -bI bI + alphaRv = _mm256_set_ps + ( + -alphaR, alphaR, -alphaR, alphaR, + -alphaR, alphaR, -alphaR, alphaR + ); + alphaIv = _mm256_broadcast_ss( &alphaI ); + betaRv = _mm256_broadcast_ss( &betaR ); + betaIv = _mm256_set_ps + ( + betaI, -betaI, betaI, -betaI, + betaI, -betaI, betaI, -betaI + ); + } + + // Processing 16 elements per loop, 8 FMAs + for ( i = 0; ( i + 15 ) < n; i += 16 ) + { + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_ps( betaRv, yv[0] ); + iv[1] = _mm256_mul_ps( betaRv, yv[1] ); + iv[2] = _mm256_mul_ps( betaRv, yv[2] ); + iv[3] = _mm256_mul_ps( betaRv, yv[3] ); + + // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 + yv[0] = _mm256_permute_ps( yv[0], 0xB1); + yv[1] = _mm256_permute_ps( yv[1], 0xB1); + yv[2] = _mm256_permute_ps( yv[2], 0xB1); + yv[3] = _mm256_permute_ps( yv[3], 0xB1); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_ps( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_ps( betaIv, yv[1], iv[1] ); + yv[2] = _mm256_fmadd_ps( betaIv, yv[2], iv[2] ); + yv[3] = _mm256_fmadd_ps( betaIv, yv[3], iv[3] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_ps( alphaRv, xv[0] ); + iv[1] = _mm256_mul_ps( alphaRv, xv[1] ); + iv[2] = _mm256_mul_ps( alphaRv, xv[2] ); + iv[3] = _mm256_mul_ps( alphaRv, xv[3] ); + + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + xv[0] = _mm256_permute_ps( xv[0], 0xB1); + xv[1] = _mm256_permute_ps( xv[1], 0xB1); + xv[2] = _mm256_permute_ps( xv[2], 0xB1); + xv[3] = _mm256_permute_ps( xv[3], 0xB1); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... 
+ yv[0] = _mm256_fmadd_ps( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( alphaIv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_ps( alphaIv, xv[2], yv[2] ); + yv[3] = _mm256_fmadd_ps( alphaIv, xv[3], yv[3] ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] ); + _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), yv[2] ); + _mm256_storeu_ps( (y0 + 3*n_elem_per_reg), yv[3] ); + + y0 += 4*n_elem_per_reg; + x0 += 4*n_elem_per_reg; + } + + // Processing 12 elements per loop, 6 FMAs + for ( ; ( i + 11 ) < n; i += 12 ) + { + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_ps( betaRv, yv[0] ); + iv[1] = _mm256_mul_ps( betaRv, yv[1] ); + iv[2] = _mm256_mul_ps( betaRv, yv[2] ); + + // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 + yv[0] = _mm256_permute_ps( yv[0], 0xB1); + yv[1] = _mm256_permute_ps( yv[1], 0xB1); + yv[2] = _mm256_permute_ps( yv[2], 0xB1); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_ps( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_ps( betaIv, yv[1], iv[1] ); + yv[2] = _mm256_fmadd_ps( betaIv, yv[2], iv[2] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_ps( alphaRv, xv[0] ); + iv[1] = _mm256_mul_ps( alphaRv, xv[1] ); + iv[2] = _mm256_mul_ps( alphaRv, xv[2] ); + + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + xv[0] = _mm256_permute_ps( xv[0], 0xB1); + xv[1] = _mm256_permute_ps( xv[1], 0xB1); + xv[2] = _mm256_permute_ps( xv[2], 0xB1); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_ps( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( alphaIv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_ps( alphaIv, xv[2], yv[2] ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] ); + _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), yv[2] ); + + y0 += 3*n_elem_per_reg; + x0 += 3*n_elem_per_reg; + } + + // Processing 16 elements per loop, 8 FMAs + for ( ; ( i + 7 ) < n; i += 8 ) + { + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_ps( betaRv, yv[0] ); + iv[1] = _mm256_mul_ps( betaRv, yv[1] ); + + // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 + yv[0] = _mm256_permute_ps( yv[0], 0xB1); + yv[1] = _mm256_permute_ps( yv[1], 0xB1); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_ps( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_ps( betaIv, yv[1], iv[1] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... 
+ iv[0] = _mm256_mul_ps( alphaRv, xv[0] ); + iv[1] = _mm256_mul_ps( alphaRv, xv[1] ); + + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + xv[0] = _mm256_permute_ps( xv[0], 0xB1); + xv[1] = _mm256_permute_ps( xv[1], 0xB1); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_ps( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( alphaIv, xv[1], yv[1] ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] ); + + y0 += 2*n_elem_per_reg; + x0 += 2*n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + if ( !bli_is_conj( conjx_use ) ) + { + for ( ; i < n ; ++i ) + { + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) + + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += 2; + y0 += 2; + } + } + else + { + for ( ; i < n ; ++i ) + { + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) - + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += 2; + y0 += 2; + } + } + } + else + { + // for non-unit increments, use scaler code + if ( !bli_is_conj( conjx_use ) ) + { + for ( i = 0; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) + + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += incx * 2; + y0 += incy * 2; + } + } + else + { + for ( i = 0; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) - + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += incx * 2; + y0 += incy * 2; + } + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} + +/** + * zaxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are double complex vectors of length n. + * alpha & beta are scalers. + */ +void bli_zaxpbyv_zen_int + ( + conj_t conjx, + dim_t n, + dcomplex* restrict alpha, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict beta, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 4; // number of elements per register + + dim_t i; // iterator + + double* restrict x0; + double* restrict y0; + + double alphaR, alphaI, betaR, betaI; + + __m256d alphaRv; + __m256d alphaIv; + __m256d betaRv; + __m256d betaIv; + __m256d xv[4]; + __m256d yv[4]; + __m256d iv[4]; // intermediate registers + + conj_t conjx_use = conjx; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. 
*/ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( c, eq0 )( *alpha ) && PASTEMAC( c, eq0 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // initialize local pointers + x0 = ( double* ) x; + y0 = ( double* ) y; + + alphaR = alpha->real; + alphaI = alpha->imag; + betaR = beta->real; + betaI = beta->imag; + + if ( incx == 1 && incy == 1 ) + { + //---------- Scalar algorithm BLIS_NO_CONJUGATE ------------- + // y = beta*y + alpha*x + // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR + ixI ) + // y = bR.yR + ibR.yI + ibI.yR - ibIyI + aR.xR + iaR.xI + iaI.xR - aI.xI + // y = ( bR.yR - bI.yI + aR.xR - aI.xI ) + + // i ( bR.yI + bI.yR + aR.xI + aI.xR ) + + // SIMD Algorithm BLIS_NO_CONJUGATE + // yv = yR1 yI1 yR2 yI2 + // yv' = yI1 yR1 yI2 yR2 + // xv = xR1 xI1 xR2 xI2 + // xv' = xI1 xR1 xI2 xR2 + // arv = aR aR aR aR + // aiv = -aI aI -aI aI + // brv = bR bR bR bR + // biv = -bI bI -bI bI + // + // step 1: iv = brv * iv + // step 2: shuffle yv -> yv' + // step 3: FMA yv = biv * yv' + iv + // step 4: iv = arv * xv + // step 5: shuffle xv -> xv' + // step 6: FMA yv = aiv * xv' + iv + + //---------- Scalar algorithm BLIS_CONJUGATE ------------- + // y = beta*y + alpha*conj(x) + // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR - ixI ) + // y = bR.yR + ibR.yI + ibI.yR - bI.yI + aR.xR - iaR.xI + iaI.xR + aI.xI + // y = ( bR.yR - bI.yI + aR.xR + aI.xI ) + + // i ( bR.yI + bI.yR - aR.xI + aI.xR ) + + // SIMD Algorithm BLIS_CONJUGATE + // yv = yR1 yI1 yR2 yI2 + // yv' = yI1 yR1 yI2 yR2 + // xv = xR1 xI1 xR2 xI2 + // xv' = xI1 xR1 xI2 xR2 + // arv = aR -aR aR -aR + // aiv = aI aI aI aI + // brv = bR bR bR bR + // biv = -bI bI -bI bI + // + // step 1: iv = brv * iv + // step 2: shuffle yv -> yv' + // step 3: FMA yv = biv * yv' + iv + // step 4: iv = arv * xv + // step 5: shuffle xv -> xv' + // step 6: FMA yv = aiv * xv' + iv + + // broadcast alpha & beta to all elements of respective vector registers + if ( !bli_is_conj( conjx ) ) + { + // alphaRv = aR aR aR aR + // alphaIv = -aI aI -aI aI + // betaRv = bR bR bR bR + // betaIv = -bI bI -bI bI + alphaRv = _mm256_broadcast_sd( &alphaR ); + alphaIv = _mm256_set_pd( alphaI, -alphaI, alphaI, -alphaI ); + betaRv = _mm256_broadcast_sd( &betaR ); + betaIv = _mm256_set_pd( betaI, -betaI, betaI, -betaI ); + } + else + { + // alphaRv = aR -aR aR -aR + // alphaIv = aI aI aI aI + // betaRv = bR bR bR bR + // betaIv = -bI bI -bI bI + alphaRv = _mm256_set_pd( -alphaR, alphaR, -alphaR, alphaR ); + alphaIv = _mm256_broadcast_sd( &alphaI ); + betaRv = _mm256_broadcast_sd( &betaR ); + betaIv = _mm256_set_pd( betaI, -betaI, betaI, -betaI ); + } + + // Processing 8 elements per loop, 8 FMAs + for ( i = 0; ( i + 7 ) < n; i += 8 ) + { + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... 
+ iv[0] = _mm256_mul_pd( betaRv, yv[0] ); + iv[1] = _mm256_mul_pd( betaRv, yv[1] ); + iv[2] = _mm256_mul_pd( betaRv, yv[2] ); + iv[3] = _mm256_mul_pd( betaRv, yv[3] ); + + // yv' = yI1 yR1 yI2 yR2 + yv[0] = _mm256_permute_pd( yv[0], 5); + yv[1] = _mm256_permute_pd( yv[1], 5); + yv[2] = _mm256_permute_pd( yv[2], 5); + yv[3] = _mm256_permute_pd( yv[3], 5); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_pd( betaIv, yv[1], iv[1] ); + yv[2] = _mm256_fmadd_pd( betaIv, yv[2], iv[2] ); + yv[3] = _mm256_fmadd_pd( betaIv, yv[3], iv[3] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + iv[1] = _mm256_mul_pd( alphaRv, xv[1] ); + iv[2] = _mm256_mul_pd( alphaRv, xv[2] ); + iv[3] = _mm256_mul_pd( alphaRv, xv[3] ); + + // xv' = xI1 xR1 xI2 xR2 + xv[0] = _mm256_permute_pd( xv[0], 5); + xv[1] = _mm256_permute_pd( xv[1], 5); + xv[2] = _mm256_permute_pd( xv[2], 5); + xv[3] = _mm256_permute_pd( xv[3], 5); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( alphaIv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_pd( alphaIv, xv[2], yv[2] ); + yv[3] = _mm256_fmadd_pd( alphaIv, xv[3], yv[3] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2] ); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), yv[3] ); + + y0 += 4*n_elem_per_reg; + x0 += 4*n_elem_per_reg; + } + + // Processing 6 elements per loop, 6 FMAs + for ( ; ( i + 5 ) < n; i += 6 ) + { + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_pd( betaRv, yv[0] ); + iv[1] = _mm256_mul_pd( betaRv, yv[1] ); + iv[2] = _mm256_mul_pd( betaRv, yv[2] ); + + // yv' = yI1 yR1 yI2 yR2 + yv[0] = _mm256_permute_pd( yv[0], 5); + yv[1] = _mm256_permute_pd( yv[1], 5); + yv[2] = _mm256_permute_pd( yv[2], 5); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_pd( betaIv, yv[1], iv[1] ); + yv[2] = _mm256_fmadd_pd( betaIv, yv[2], iv[2] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + iv[1] = _mm256_mul_pd( alphaRv, xv[1] ); + iv[2] = _mm256_mul_pd( alphaRv, xv[2] ); + + // xv' = xI1 xR1 xI2 xR2 + xv[0] = _mm256_permute_pd( xv[0], 5); + xv[1] = _mm256_permute_pd( xv[1], 5); + xv[2] = _mm256_permute_pd( xv[2], 5); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... 
+ yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( alphaIv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_pd( alphaIv, xv[2], yv[2] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2] ); + + y0 += 3*n_elem_per_reg; + x0 += 3*n_elem_per_reg; + } + + // Processing 4 elements per loop, 4 FMAs + for ( ; ( i + 3 ) < n; i += 4 ) + { + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_pd( betaRv, yv[0] ); + iv[1] = _mm256_mul_pd( betaRv, yv[1] ); + + // yv' = yI1 yR1 yI2 yR2 + yv[0] = _mm256_permute_pd( yv[0], 5); + yv[1] = _mm256_permute_pd( yv[1], 5); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_pd( betaIv, yv[1], iv[1] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + iv[1] = _mm256_mul_pd( alphaRv, xv[1] ); + + // xv' = xI1 xR1 xI2 xR2 + xv[0] = _mm256_permute_pd( xv[0], 5); + xv[1] = _mm256_permute_pd( xv[1], 5); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( alphaIv, xv[1], yv[1] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] ); + + y0 += 2*n_elem_per_reg; + x0 += 2*n_elem_per_reg; + } + + // Processing 2 elements per loop, 3 FMAs + for ( ; ( i + 1 ) < n; i += 2 ) + { + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_pd( betaRv, yv[0] ); + + // yv' = yI1 yR1 yI2 yR2 + yv[0] = _mm256_permute_pd( yv[0], 5); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + + // xv' = xI1 xR1 xI2 xR2 + xv[0] = _mm256_permute_pd( xv[0], 5); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); + + y0 += 1*n_elem_per_reg; + x0 += 1*n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). 
+ _mm256_zeroupper(); + + if ( !bli_is_conj( conjx_use ) ) + { + for ( ; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) + + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += 2; + y0 += 2; + } + } + else + { + for ( ; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) - + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += 2; + y0 += 2; + } + } + } + else + { + // for non-unit increments, use scaler code + if ( !bli_is_conj( conjx_use ) ) + { + for ( i = 0; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) + + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += incx * 2; + y0 += incy * 2; + } + } + else + { + for ( i = 0; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) - + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += incx * 2; + y0 += incy * 2; + } + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} \ No newline at end of file diff --git a/kernels/zen/1/bli_axpbyv_zen_int10.c b/kernels/zen/1/bli_axpbyv_zen_int10.c new file mode 100644 index 0000000000..bbfdaf0d6a --- /dev/null +++ b/kernels/zen/1/bli_axpbyv_zen_int10.c @@ -0,0 +1,709 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +/* Union DS to access AVX registers */ +/* One 256-bit AVX register holds 8 SP elements */ +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + +/* One 256-bit AVX register holds 4 DP elements */ +typedef union +{ + __m256d v; + double d[4] __attribute__((aligned(64))); +} v4df_t; + +/** + * saxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are single precision vectors of length n. + * alpha & beta are scalers. + */ +void bli_saxpbyv_zen_int10 + ( + conj_t conjx, + dim_t n, + float* restrict alpha, + float* restrict x, inc_t incx, + float* restrict beta, + float* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 8; // number of elements per register + + dim_t i; // iterator + + float* restrict x0; + float* restrict y0; + + v8sf_t alphav; + v8sf_t betav; + v8sf_t yv[10]; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // initialize local pointers + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + // broadcast alpha & beta to all elements of respective vector registers + alphav.v = _mm256_broadcast_ss( alpha ); + betav.v = _mm256_broadcast_ss( beta ); + + // Processing 80 elements per loop, 10 FMAs + for ( i = 0; ( i + 79 ) < n; i += 80 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + yv[4].v = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); + yv[5].v = _mm256_loadu_ps( y0 + 5*n_elem_per_reg ); + yv[6].v = _mm256_loadu_ps( y0 + 6*n_elem_per_reg ); + yv[7].v = _mm256_loadu_ps( y0 + 7*n_elem_per_reg ); + yv[8].v = _mm256_loadu_ps( y0 + 8*n_elem_per_reg ); + yv[9].v = _mm256_loadu_ps( y0 + 9*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); + yv[4].v = _mm256_mul_ps( betav.v, yv[4].v ); + yv[5].v = _mm256_mul_ps( betav.v, yv[5].v ); + yv[6].v = _mm256_mul_ps( betav.v, yv[6].v ); + yv[7].v = _mm256_mul_ps( betav.v, yv[7].v ); + yv[8].v = _mm256_mul_ps( betav.v, yv[8].v ); + yv[9].v = _mm256_mul_ps( betav.v, yv[9].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + 
( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + yv[4].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 4*n_elem_per_reg ), + yv[4].v + ); + yv[5].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 5*n_elem_per_reg ), + yv[5].v + ); + yv[6].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 6*n_elem_per_reg ), + yv[6].v + ); + yv[7].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 7*n_elem_per_reg ), + yv[7].v + ); + yv[8].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 8*n_elem_per_reg ), + yv[8].v + ); + yv[9].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 9*n_elem_per_reg ), + yv[9].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + _mm256_storeu_ps( ( y0 + 4*n_elem_per_reg ), yv[4].v ); + _mm256_storeu_ps( ( y0 + 5*n_elem_per_reg ), yv[5].v ); + _mm256_storeu_ps( ( y0 + 6*n_elem_per_reg ), yv[6].v ); + _mm256_storeu_ps( ( y0 + 7*n_elem_per_reg ), yv[7].v ); + _mm256_storeu_ps( ( y0 + 8*n_elem_per_reg ), yv[8].v ); + _mm256_storeu_ps( ( y0 + 9*n_elem_per_reg ), yv[9].v ); + + x0 += 10 * n_elem_per_reg; + y0 += 10 * n_elem_per_reg; + } + + // Processing 40 elements per loop, 5 FMAs + for ( ; ( i + 39 ) < n; i += 40 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + yv[4].v = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); + yv[4].v = _mm256_mul_ps( betav.v, yv[4].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + yv[4].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 4*n_elem_per_reg ), + yv[4].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + _mm256_storeu_ps( ( y0 + 4*n_elem_per_reg ), yv[4].v ); + + x0 += 5 * n_elem_per_reg; + y0 += 5 * n_elem_per_reg; + } + + // Processing 32 elements per loop, 4 FMAs + for ( ; ( i + 31 ) < n; i += 32 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + 
_mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + + x0 += 4 * n_elem_per_reg; + y0 += 4 * n_elem_per_reg; + } + + // Processing 16 elements per loop, 2 FMAs + for ( ; ( i + 15 ) < n; i += 16 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + + x0 += 2 * n_elem_per_reg; + y0 += 2 * n_elem_per_reg; + } + + // Processing 8 elements per loop, 1 FMA + for ( ; ( i + 7 ) < n; i += 8 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + + x0 += 1 * n_elem_per_reg; + y0 += 1 * n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + // if there are leftover iterations, perform them with scaler code + for ( ; i < n; i++ ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + else + { + // for non-unit increments, use scaler code + for ( i = 0; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} + +/** + * daxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are double precision vectors of length n. + * alpha & beta are scalers. 
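+ *
+ * For reference, the computation is equivalent to the following scalar
+ * loop (this is also what the non-unit-stride fallback below performs):
+ *
+ *     for ( dim_t j = 0; j < n; ++j )
+ *         y[ j*incy ] = (*beta) * y[ j*incy ] + (*alpha) * x[ j*incx ];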
+ */ +void bli_daxpbyv_zen_int10 + ( + conj_t conjx, + dim_t n, + double* restrict alpha, + double* restrict x, inc_t incx, + double* restrict beta, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 4; // number of elements per register + const dim_t n_iter_unroll = 10; // number of registers per iteration + + dim_t i; // iterator + + double* restrict x0; + double* restrict y0; + + v4df_t alphav; + v4df_t betav; + v4df_t y0v, y1v, y2v, y3v, y4v, y5v, y6v, y7v, y8v, y9v; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // initialize local pointers + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + // broadcast alpha & beta to all elements of respective vector registers + alphav.v = _mm256_broadcast_sd( alpha ); + betav.v = _mm256_broadcast_sd( beta ); + + // Using 10 FMAs per loop + for ( i = 0; ( i + 39 ) < n; i += 40 ) + { + // loading input y + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + y4v.v = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); + y5v.v = _mm256_loadu_pd( y0 + 5*n_elem_per_reg ); + y6v.v = _mm256_loadu_pd( y0 + 6*n_elem_per_reg ); + y7v.v = _mm256_loadu_pd( y0 + 7*n_elem_per_reg ); + y8v.v = _mm256_loadu_pd( y0 + 8*n_elem_per_reg ); + y9v.v = _mm256_loadu_pd( y0 + 9*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_pd( betav.v, y0v.v ); + y1v.v = _mm256_mul_pd( betav.v, y1v.v ); + y2v.v = _mm256_mul_pd( betav.v, y2v.v ); + y3v.v = _mm256_mul_pd( betav.v, y3v.v ); + y4v.v = _mm256_mul_pd( betav.v, y4v.v ); + y5v.v = _mm256_mul_pd( betav.v, y5v.v ); + y6v.v = _mm256_mul_pd( betav.v, y6v.v ); + y7v.v = _mm256_mul_pd( betav.v, y7v.v ); + y8v.v = _mm256_mul_pd( betav.v, y8v.v ); + y9v.v = _mm256_mul_pd( betav.v, y9v.v ); + + // y := y' + alpha * x + // := beta * y + alpha * x + y0v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + y0v.v + ); + y1v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + y1v.v + ); + y2v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), + y2v.v + ); + y3v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), + y3v.v + ); + y4v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 4*n_elem_per_reg ), + y4v.v + ); + y5v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 5*n_elem_per_reg ), + y5v.v + ); + y6v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 6*n_elem_per_reg ), + y6v.v + ); + y7v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 7*n_elem_per_reg ), + y7v.v + ); + y8v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 8*n_elem_per_reg ), + y8v.v + ); + y9v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 9*n_elem_per_reg ), + y9v.v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), y2v.v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), y3v.v ); + _mm256_storeu_pd( ( y0 + 4*n_elem_per_reg ), y4v.v ); + _mm256_storeu_pd( ( y0 + 5*n_elem_per_reg ), y5v.v ); + 
_mm256_storeu_pd( ( y0 + 6*n_elem_per_reg ), y6v.v ); + _mm256_storeu_pd( ( y0 + 7*n_elem_per_reg ), y7v.v ); + _mm256_storeu_pd( ( y0 + 8*n_elem_per_reg ), y8v.v ); + _mm256_storeu_pd( ( y0 + 9*n_elem_per_reg ), y9v.v ); + + x0 += n_elem_per_reg * n_iter_unroll; + y0 += n_elem_per_reg * n_iter_unroll; + } + + // Using 5 FMAs per loop + for ( ; ( i + 19 ) < n; i += 20 ) + { + // loading input y + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + y4v.v = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_pd( betav.v, y0v.v ); + y1v.v = _mm256_mul_pd( betav.v, y1v.v ); + y2v.v = _mm256_mul_pd( betav.v, y2v.v ); + y3v.v = _mm256_mul_pd( betav.v, y3v.v ); + y4v.v = _mm256_mul_pd( betav.v, y4v.v ); + + // y := y' + alpha * x + // := beta * y + alpha * x + y0v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + y0v.v + ); + y1v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + y1v.v + ); + y2v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), + y2v.v + ); + y3v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), + y3v.v + ); + y4v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 4*n_elem_per_reg ), + y4v.v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), y2v.v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), y3v.v ); + _mm256_storeu_pd( ( y0 + 4*n_elem_per_reg ), y4v.v ); + + x0 += n_elem_per_reg * 5; + y0 += n_elem_per_reg * 5; + } + + // Using 2 FMAs per loop + for ( ; ( i + 7 ) < n; i += 8 ) + { + // loading input y + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_pd( betav.v, y0v.v ); + y1v.v = _mm256_mul_pd( betav.v, y1v.v ); + + // y := y' + alpha * x + // := beta * y + alpha * x + y0v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + y0v.v + ); + y1v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + y1v.v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); + + x0 += n_elem_per_reg * 2; + y0 += n_elem_per_reg * 2; + } + + // Using 1 FMAs per loop + for ( ; ( i + 3 ) < n; i += 4 ) + { + // loading input y + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_pd( betav.v, y0v.v ); + + // y := y' + alpha * x + // := beta * y + alpha * x + y0v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + y0v.v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); + + x0 += n_elem_per_reg * 1; + y0 += n_elem_per_reg * 1; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). 
+ _mm256_zeroupper(); + + // if there are leftover iterations, perform them with scaler code + for ( ; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + else + { + // for non-unit increments, use scaler code + for ( i = 0; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} \ No newline at end of file diff --git a/kernels/zen/1/bli_axpyv_zen_int10.c b/kernels/zen/1/bli_axpyv_zen_int10.c index 6f953e6f4c..4ef6981cd7 100644 --- a/kernels/zen/1/bli_axpyv_zen_int10.c +++ b/kernels/zen/1/bli_axpyv_zen_int10.c @@ -75,9 +75,9 @@ void bli_saxpyv_zen_int10 float* restrict y0; __m256 alphav; - __m256 xv[10]; - __m256 yv[10]; - __m256 zv[10]; + __m256 xv[15]; + __m256 yv[15]; + __m256 zv[15]; // If the vector dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim1( n ) || PASTEMAC(s,eq0)( *alpha ) ) @@ -95,7 +95,78 @@ void bli_saxpyv_zen_int10 // Broadcast the alpha scalar to all elements of a vector register. alphav = _mm256_broadcast_ss( alpha ); - for ( i = 0; (i + 79) < n; i += 80 ) + for (i = 0; (i + 119) < n; i += 120) + { + // 120 elements will be processed per loop; 15 FMAs will run per loop. + xv[0] = _mm256_loadu_ps(x0 + 0 * n_elem_per_reg); + xv[1] = _mm256_loadu_ps(x0 + 1 * n_elem_per_reg); + xv[2] = _mm256_loadu_ps(x0 + 2 * n_elem_per_reg); + xv[3] = _mm256_loadu_ps(x0 + 3 * n_elem_per_reg); + xv[4] = _mm256_loadu_ps(x0 + 4 * n_elem_per_reg); + xv[5] = _mm256_loadu_ps(x0 + 5 * n_elem_per_reg); + xv[6] = _mm256_loadu_ps(x0 + 6 * n_elem_per_reg); + xv[7] = _mm256_loadu_ps(x0 + 7 * n_elem_per_reg); + xv[8] = _mm256_loadu_ps(x0 + 8 * n_elem_per_reg); + xv[9] = _mm256_loadu_ps(x0 + 9 * n_elem_per_reg); + xv[10] = _mm256_loadu_ps(x0 + 10 * n_elem_per_reg); + xv[11] = _mm256_loadu_ps(x0 + 11 * n_elem_per_reg); + xv[12] = _mm256_loadu_ps(x0 + 12 * n_elem_per_reg); + xv[13] = _mm256_loadu_ps(x0 + 13 * n_elem_per_reg); + xv[14] = _mm256_loadu_ps(x0 + 14 * n_elem_per_reg); + + yv[0] = _mm256_loadu_ps(y0 + 0 * n_elem_per_reg); + yv[1] = _mm256_loadu_ps(y0 + 1 * n_elem_per_reg); + yv[2] = _mm256_loadu_ps(y0 + 2 * n_elem_per_reg); + yv[3] = _mm256_loadu_ps(y0 + 3 * n_elem_per_reg); + yv[4] = _mm256_loadu_ps(y0 + 4 * n_elem_per_reg); + yv[5] = _mm256_loadu_ps(y0 + 5 * n_elem_per_reg); + yv[6] = _mm256_loadu_ps(y0 + 6 * n_elem_per_reg); + yv[7] = _mm256_loadu_ps(y0 + 7 * n_elem_per_reg); + yv[8] = _mm256_loadu_ps(y0 + 8 * n_elem_per_reg); + yv[9] = _mm256_loadu_ps(y0 + 9 * n_elem_per_reg); + yv[10] = _mm256_loadu_ps(y0 + 10 * n_elem_per_reg); + yv[11] = _mm256_loadu_ps(y0 + 11 * n_elem_per_reg); + yv[12] = _mm256_loadu_ps(y0 + 12 * n_elem_per_reg); + yv[13] = _mm256_loadu_ps(y0 + 13 * n_elem_per_reg); + yv[14] = _mm256_loadu_ps(y0 + 14 * n_elem_per_reg); + + zv[0] = _mm256_fmadd_ps(xv[0], alphav, yv[0]); + zv[1] = _mm256_fmadd_ps(xv[1], alphav, yv[1]); + zv[2] = _mm256_fmadd_ps(xv[2], alphav, yv[2]); + zv[3] = _mm256_fmadd_ps(xv[3], alphav, yv[3]); + zv[4] = _mm256_fmadd_ps(xv[4], alphav, yv[4]); + zv[5] = _mm256_fmadd_ps(xv[5], alphav, yv[5]); + zv[6] = _mm256_fmadd_ps(xv[6], alphav, yv[6]); + zv[7] = _mm256_fmadd_ps(xv[7], alphav, yv[7]); + zv[8] = _mm256_fmadd_ps(xv[8], alphav, yv[8]); + zv[9] = _mm256_fmadd_ps(xv[9], alphav, yv[9]); + zv[10] = _mm256_fmadd_ps(xv[10], alphav, yv[10]); + zv[11] = _mm256_fmadd_ps(xv[11], alphav, yv[11]); + zv[12] = _mm256_fmadd_ps(xv[12], alphav, yv[12]); + zv[13] = _mm256_fmadd_ps(xv[13], 
alphav, yv[13]); + zv[14] = _mm256_fmadd_ps(xv[14], alphav, yv[14]); + + _mm256_storeu_ps((y0 + 0 * n_elem_per_reg), zv[0]); + _mm256_storeu_ps((y0 + 1 * n_elem_per_reg), zv[1]); + _mm256_storeu_ps((y0 + 2 * n_elem_per_reg), zv[2]); + _mm256_storeu_ps((y0 + 3 * n_elem_per_reg), zv[3]); + _mm256_storeu_ps((y0 + 4 * n_elem_per_reg), zv[4]); + _mm256_storeu_ps((y0 + 5 * n_elem_per_reg), zv[5]); + _mm256_storeu_ps((y0 + 6 * n_elem_per_reg), zv[6]); + _mm256_storeu_ps((y0 + 7 * n_elem_per_reg), zv[7]); + _mm256_storeu_ps((y0 + 8 * n_elem_per_reg), zv[8]); + _mm256_storeu_ps((y0 + 9 * n_elem_per_reg), zv[9]); + _mm256_storeu_ps((y0 + 10 * n_elem_per_reg), zv[10]); + _mm256_storeu_ps((y0 + 11 * n_elem_per_reg), zv[11]); + _mm256_storeu_ps((y0 + 12 * n_elem_per_reg), zv[12]); + _mm256_storeu_ps((y0 + 13 * n_elem_per_reg), zv[13]); + _mm256_storeu_ps((y0 + 14 * n_elem_per_reg), zv[14]); + + x0 += 15 * n_elem_per_reg; + y0 += 15 * n_elem_per_reg; + } + + for (; (i + 79) < n; i += 80 ) { // 80 elements will be processed per loop; 10 FMAs will run per loop. xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); @@ -288,9 +359,9 @@ void bli_daxpyv_zen_int10 double* restrict y0 = y; __m256d alphav; - __m256d xv[10]; - __m256d yv[10]; - __m256d zv[10]; + __m256d xv[13]; + __m256d yv[13]; + __m256d zv[13]; // If the vector dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim1( n ) || PASTEMAC(d,eq0)( *alpha ) ) @@ -308,7 +379,70 @@ void bli_daxpyv_zen_int10 // Broadcast the alpha scalar to all elements of a vector register. alphav = _mm256_broadcast_sd( alpha ); - for ( i = 0; (i + 39) < n; i += 40 ) + for (i = 0; (i + 51) < n; i += 52) + { + // 52 elements will be processed per loop; 13 FMAs will run per loop. + xv[0] = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1] = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2] = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3] = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); + xv[4] = _mm256_loadu_pd(x0 + 4 * n_elem_per_reg); + xv[5] = _mm256_loadu_pd(x0 + 5 * n_elem_per_reg); + xv[6] = _mm256_loadu_pd(x0 + 6 * n_elem_per_reg); + xv[7] = _mm256_loadu_pd(x0 + 7 * n_elem_per_reg); + xv[8] = _mm256_loadu_pd(x0 + 8 * n_elem_per_reg); + xv[9] = _mm256_loadu_pd(x0 + 9 * n_elem_per_reg); + xv[10] = _mm256_loadu_pd(x0 + 10 * n_elem_per_reg); + xv[11] = _mm256_loadu_pd(x0 + 11 * n_elem_per_reg); + xv[12] = _mm256_loadu_pd(x0 + 12 * n_elem_per_reg); + + yv[0] = _mm256_loadu_pd(y0 + 0 * n_elem_per_reg); + yv[1] = _mm256_loadu_pd(y0 + 1 * n_elem_per_reg); + yv[2] = _mm256_loadu_pd(y0 + 2 * n_elem_per_reg); + yv[3] = _mm256_loadu_pd(y0 + 3 * n_elem_per_reg); + yv[4] = _mm256_loadu_pd(y0 + 4 * n_elem_per_reg); + yv[5] = _mm256_loadu_pd(y0 + 5 * n_elem_per_reg); + yv[6] = _mm256_loadu_pd(y0 + 6 * n_elem_per_reg); + yv[7] = _mm256_loadu_pd(y0 + 7 * n_elem_per_reg); + yv[8] = _mm256_loadu_pd(y0 + 8 * n_elem_per_reg); + yv[9] = _mm256_loadu_pd(y0 + 9 * n_elem_per_reg); + yv[10] = _mm256_loadu_pd(y0 + 10 * n_elem_per_reg); + yv[11] = _mm256_loadu_pd(y0 + 11 * n_elem_per_reg); + yv[12] = _mm256_loadu_pd(y0 + 12 * n_elem_per_reg); + + zv[0] = _mm256_fmadd_pd(xv[0], alphav, yv[0]); + zv[1] = _mm256_fmadd_pd(xv[1], alphav, yv[1]); + zv[2] = _mm256_fmadd_pd(xv[2], alphav, yv[2]); + zv[3] = _mm256_fmadd_pd(xv[3], alphav, yv[3]); + zv[4] = _mm256_fmadd_pd(xv[4], alphav, yv[4]); + zv[5] = _mm256_fmadd_pd(xv[5], alphav, yv[5]); + zv[6] = _mm256_fmadd_pd(xv[6], alphav, yv[6]); + zv[7] = _mm256_fmadd_pd(xv[7], alphav, yv[7]); + zv[8] = _mm256_fmadd_pd(xv[8], alphav, yv[8]); + 
zv[9] = _mm256_fmadd_pd(xv[9], alphav, yv[9]); + zv[10] = _mm256_fmadd_pd(xv[10], alphav, yv[10]); + zv[11] = _mm256_fmadd_pd(xv[11], alphav, yv[11]); + zv[12] = _mm256_fmadd_pd(xv[12], alphav, yv[12]); + + _mm256_storeu_pd((y0 + 0 * n_elem_per_reg), zv[0]); + _mm256_storeu_pd((y0 + 1 * n_elem_per_reg), zv[1]); + _mm256_storeu_pd((y0 + 2 * n_elem_per_reg), zv[2]); + _mm256_storeu_pd((y0 + 3 * n_elem_per_reg), zv[3]); + _mm256_storeu_pd((y0 + 4 * n_elem_per_reg), zv[4]); + _mm256_storeu_pd((y0 + 5 * n_elem_per_reg), zv[5]); + _mm256_storeu_pd((y0 + 6 * n_elem_per_reg), zv[6]); + _mm256_storeu_pd((y0 + 7 * n_elem_per_reg), zv[7]); + _mm256_storeu_pd((y0 + 8 * n_elem_per_reg), zv[8]); + _mm256_storeu_pd((y0 + 9 * n_elem_per_reg), zv[9]); + _mm256_storeu_pd((y0 + 10 * n_elem_per_reg), zv[10]); + _mm256_storeu_pd((y0 + 11 * n_elem_per_reg), zv[11]); + _mm256_storeu_pd((y0 + 12 * n_elem_per_reg), zv[12]); + + x0 += 13 * n_elem_per_reg; + y0 += 13 * n_elem_per_reg; + } + + for ( ; (i + 39) < n; i += 40 ) { // 40 elements will be processed per loop; 10 FMAs will run per loop. xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); diff --git a/kernels/zen/1/bli_dotxv_zen_int.c b/kernels/zen/1/bli_dotxv_zen_int.c index 99ea517104..c210eceff5 100644 --- a/kernels/zen/1/bli_dotxv_zen_int.c +++ b/kernels/zen/1/bli_dotxv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2022, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -36,6 +36,14 @@ #include "immintrin.h" #include "blis.h" +/* Union data structure to access AVX registers + One 128-bit AVX register holds 8 SP elements. */ +typedef union +{ + __m128 v; + float f[4] __attribute__((aligned(64))); +} v4sf_t; + /* Union data structure to access AVX registers One 256-bit AVX register holds 8 SP elements. */ typedef union @@ -44,6 +52,14 @@ typedef union float f[8] __attribute__((aligned(64))); } v8sf_t; +/* Union data structure to access AVX registers +* One 128-bit AVX register holds 4 DP elements. */ +typedef union +{ + __m128d v; + double d[2] __attribute__((aligned(64))); +} v2df_t; + /* Union data structure to access AVX registers * One 256-bit AVX register holds 4 DP elements. */ typedef union @@ -78,11 +94,7 @@ void bli_sdotxv_zen_int float* restrict y0; float rho0; - v8sf_t rho0v, rho1v, rho2v, rho3v; - v8sf_t x0v, y0v; - v8sf_t x1v, y1v; - v8sf_t x2v, y2v; - v8sf_t x3v, y3v; + v8sf_t rhov[4], xv[4], yv[4]; // If beta is zero, initialize rho1 to zero instead of scaling // rho by beta (in case rho contains NaN or Inf). @@ -117,45 +129,55 @@ void bli_sdotxv_zen_int y0 = y; // Initialize the unrolled iterations' rho vectors to zero. - rho0v.v = _mm256_setzero_ps(); - rho1v.v = _mm256_setzero_ps(); - rho2v.v = _mm256_setzero_ps(); - rho3v.v = _mm256_setzero_ps(); + rhov[0].v = _mm256_setzero_ps(); + rhov[1].v = _mm256_setzero_ps(); + rhov[2].v = _mm256_setzero_ps(); + rhov[3].v = _mm256_setzero_ps(); for ( i = 0; i < n_viter; ++i ) { // Load the x and y input vector elements. 
- x0v.v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - y0v.v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + xv[0].v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - x1v.v = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - y1v.v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + xv[1].v = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - x2v.v = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - y2v.v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + xv[2].v = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - x3v.v = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - y3v.v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + xv[3].v = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); // Compute the element-wise product of the x and y vectors, // storing in the corresponding rho vectors. - rho0v.v = _mm256_fmadd_ps( x0v.v, y0v.v, rho0v.v ); - rho1v.v = _mm256_fmadd_ps( x1v.v, y1v.v, rho1v.v ); - rho2v.v = _mm256_fmadd_ps( x2v.v, y2v.v, rho2v.v ); - rho3v.v = _mm256_fmadd_ps( x3v.v, y3v.v, rho3v.v ); + rhov[0].v = _mm256_fmadd_ps( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_ps( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3].v, yv[3].v, rhov[3].v ); x0 += ( n_elem_per_reg * n_iter_unroll ); y0 += ( n_elem_per_reg * n_iter_unroll ); } // Accumulate the unrolled rho vectors into a single vector. - rho0v.v += rho1v.v; - rho0v.v += rho2v.v; - rho0v.v += rho3v.v; + rhov[0].v = _mm256_add_ps(rhov[0].v,rhov[1].v); + rhov[0].v = _mm256_add_ps(rhov[0].v,rhov[2].v); + rhov[0].v = _mm256_add_ps(rhov[0].v,rhov[3].v); + + v4sf_t inter0, inter1; + + inter0.v = _mm256_extractf128_ps(rhov[0].v,0); + inter1.v = _mm256_extractf128_ps(rhov[0].v,1); + + inter0.v = _mm_add_ps(inter0.v, inter1.v); + + inter1.v = _mm_permute_ps(inter0.v, 14); + + inter0.v = _mm_add_ps(inter0.v,inter1.v); // Accumulate the final rho vector into a single scalar result. - rho0 = rho0v.f[0] + rho0v.f[1] + rho0v.f[2] + rho0v.f[3] + - rho0v.f[4] + rho0v.f[5] + rho0v.f[6] + rho0v.f[7]; + rho0 = inter0.f[0] + inter0.f[1]; // Issue vzeroupper instruction to clear upper lanes of ymm registers. // This avoids a performance penalty caused by false dependencies when @@ -206,12 +228,8 @@ void bli_ddotxv_zen_int double* restrict y0; double rho0; - v4df_t rho0v, rho1v, rho2v, rho3v; - v4df_t x0v, y0v; - v4df_t x1v, y1v; - v4df_t x2v, y2v; - v4df_t x3v, y3v; - + v4df_t rhov[4], xv[4], yv[4]; + // If beta is zero, initialize rho1 to zero instead of scaling // rho by beta (in case rho contains NaN or Inf). if ( PASTEMAC(d,eq0)( *beta ) ) @@ -245,44 +263,51 @@ void bli_ddotxv_zen_int y0 = y; // Initialize the unrolled iterations' rho vectors to zero. - rho0v.v = _mm256_setzero_pd(); - rho1v.v = _mm256_setzero_pd(); - rho2v.v = _mm256_setzero_pd(); - rho3v.v = _mm256_setzero_pd(); + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); for ( i = 0; i < n_viter; ++i ) { // Load the x and y input vector elements. 
- x0v.v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + xv[0].v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - x1v.v = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + xv[1].v = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - x2v.v = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + xv[2].v = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - x3v.v = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + xv[3].v = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); // Compute the element-wise product of the x and y vectors, // storing in the corresponding rho vectors. - rho0v.v = _mm256_fmadd_pd( x0v.v, y0v.v, rho0v.v ); - rho1v.v = _mm256_fmadd_pd( x1v.v, y1v.v, rho1v.v ); - rho2v.v = _mm256_fmadd_pd( x2v.v, y2v.v, rho2v.v ); - rho3v.v = _mm256_fmadd_pd( x3v.v, y3v.v, rho3v.v ); + rhov[0].v = _mm256_fmadd_pd( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_pd( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3].v, yv[3].v, rhov[3].v ); x0 += ( n_elem_per_reg * n_iter_unroll ); y0 += ( n_elem_per_reg * n_iter_unroll ); } // Accumulate the unrolled rho vectors into a single vector. - rho0v.v += rho1v.v; - rho0v.v += rho2v.v; - rho0v.v += rho3v.v; + rhov[0].v = _mm256_add_pd(rhov[1].v,rhov[0].v); + rhov[0].v = _mm256_add_pd(rhov[2].v,rhov[0].v); + rhov[0].v = _mm256_add_pd(rhov[3].v,rhov[0].v); + + v2df_t inter1, inter2; + + inter1.v = _mm256_extractf128_pd(rhov[0].v,1); + inter2.v = _mm256_extractf128_pd(rhov[0].v,0); + + inter1.v = _mm_add_pd(inter1.v, inter2.v); // Accumulate the final rho vector into a single scalar result. - rho0 = rho0v.d[0] + rho0v.d[1] + rho0v.d[2] + rho0v.d[3]; + rho0 = inter1.d[0] + inter1.d[1]; // Issue vzeroupper instruction to clear upper lanes of ymm registers. // This avoids a performance penalty caused by false dependencies when @@ -307,3 +332,502 @@ void bli_ddotxv_zen_int PASTEMAC(d,axpys)( *alpha, rho0, *rho ); } + + +void bli_zdotxv_zen_int + ( + conj_t conjx, + conj_t conjy, + dim_t n, + dcomplex* restrict alpha, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + dcomplex* restrict beta, + dcomplex* restrict rho, + cntx_t* restrict cntx + ) +{ + const dim_t n_elem_per_reg = 2; + const dim_t n_iter_unroll = 4; + + dim_t i; + dim_t n_viter; + dim_t n_left; + + dcomplex* restrict x0; + dcomplex* restrict y0; + dcomplex rho0; + + v4df_t rhov[8], xv[4], yv[8]; + + conj_t conjx_use = conjx; + if ( bli_is_conj( conjy ) ) + { + bli_toggle_conj( &conjx_use ); + } + // If beta is zero, initialize rho1 to zero instead of scaling + // rho by beta (in case rho contains NaN or Inf). + if ( PASTEMAC(z,eq0)( *beta ) ) + { + PASTEMAC(z,set0s)( *rho ); + } + else + { + PASTEMAC(z,scals)( *beta, *rho ); + } + + // If the vector dimension is zero, output rho and return early. + if ( bli_zero_dim1( n ) || PASTEMAC(z,eq0)( *alpha ) ) return; + + // Use the unrolling factor and the number of elements per register + // to compute the number of vectorized and leftover iterations. 
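+	// (Illustrative note.) Here n_elem_per_reg = 2 dcomplex elements per
+	// 256-bit register and n_iter_unroll = 4, so each vectorized iteration
+	// covers 8 elements; e.g. n = 23 gives n_viter = 2 and n_left = 7.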
+ n_viter = ( n ) / ( n_elem_per_reg * n_iter_unroll ); + n_left = ( n ) % ( n_elem_per_reg * n_iter_unroll ); + + // If there is anything that would interfere with our use of contiguous + // vector loads/stores, override n_viter and n_left to use scalar code + // for all iterations. + if ( incx != 1 || incy != 1 ) + { + n_viter = 0; + n_left = n; + } + + // Initialize local pointers. + x0 = x; + y0 = y; + + // Initialize the unrolled iterations' rho vectors to zero. + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); + + if ( bli_is_conj( conjx_use ) ) + { + __m256d conju = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for ( i = 0; i < n_viter; ++i ) + { + // Load the x and y input vector elements. + xv[0].v = _mm256_loadu_pd((double *) (x0 + 0*n_elem_per_reg) ); + yv[0].v = _mm256_loadu_pd((double *) (y0 + 0*n_elem_per_reg) ); + + xv[1].v = _mm256_loadu_pd((double *) (x0 + 1*n_elem_per_reg) ); + yv[1].v = _mm256_loadu_pd((double *) (y0 + 1*n_elem_per_reg) ); + + xv[2].v = _mm256_loadu_pd((double *) (x0 + 2*n_elem_per_reg) ); + yv[2].v = _mm256_loadu_pd((double *) (y0 + 2*n_elem_per_reg) ); + + xv[3].v = _mm256_loadu_pd((double *) (x0 + 3*n_elem_per_reg) ); + yv[3].v = _mm256_loadu_pd((double *) (y0 + 3*n_elem_per_reg) ); + + yv[0].v = _mm256_mul_pd(yv[0].v, conju); + yv[1].v = _mm256_mul_pd(yv[1].v, conju); + yv[2].v = _mm256_mul_pd(yv[2].v, conju); + yv[3].v = _mm256_mul_pd(yv[3].v, conju); + //yi0 yi0 yi1 yi1 + //xr0 xi0 xr1 xi1 + //after permute of vector registers + //yi0*xr0 yi0*xi0 yi1*xr1 yi1*xi1 + yv[4].v = _mm256_permute_pd( yv[0].v, 15 ); + yv[5].v = _mm256_permute_pd( yv[1].v, 15 ); + yv[6].v = _mm256_permute_pd( yv[2].v, 15 ); + yv[7].v = _mm256_permute_pd( yv[3].v, 15 ); + + //yr0 yr0 yr1 yr1 + //xr0 xi0 xr1 xi1 + //after permute of vector registers + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + yv[0].v = _mm256_permute_pd( yv[0].v, 0 ); + yv[1].v = _mm256_permute_pd( yv[1].v, 0 ); + yv[2].v = _mm256_permute_pd( yv[2].v, 0 ); + yv[3].v = _mm256_permute_pd( yv[3].v, 0 ); + + // Compute the element-wise product of the x and y vectors, + // storing in the corresponding rho vectors. + rhov[0].v = _mm256_fmadd_pd( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_pd( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3].v, yv[3].v, rhov[3].v ); + + rhov[4].v = _mm256_fmadd_pd( xv[0].v, yv[4].v, rhov[4].v ); + rhov[5].v = _mm256_fmadd_pd( xv[1].v, yv[5].v, rhov[5].v ); + rhov[6].v = _mm256_fmadd_pd( xv[2].v, yv[6].v, rhov[6].v ); + rhov[7].v = _mm256_fmadd_pd( xv[3].v, yv[7].v, rhov[7].v ); + + x0 += ( n_elem_per_reg * n_iter_unroll ); + y0 += ( n_elem_per_reg * n_iter_unroll ); + } + } + else + { + for ( i = 0; i < n_viter; ++i ) + { + // Load the x and y input vector elements. 
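+			// (Layout note.) Each __m256d holds two dcomplex elements as
+			// ( r0, i0, r1, i1 ). The permutes below duplicate the real
+			// parts of y into yv[0..3] and the imaginary parts into
+			// yv[4..7], so rhov[0..3] accumulate the yR*xR / yR*xI terms
+			// and rhov[4..7] accumulate the yI*xR / yI*xI terms; the two
+			// sets are recombined after the loop via a permute + addsub
+			// to form the complex products.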
+ xv[0].v = _mm256_loadu_pd((double *) (x0 + 0*n_elem_per_reg) ); + yv[0].v = _mm256_loadu_pd((double *) (y0 + 0*n_elem_per_reg) ); + + xv[1].v = _mm256_loadu_pd((double *) (x0 + 1*n_elem_per_reg) ); + yv[1].v = _mm256_loadu_pd((double *) (y0 + 1*n_elem_per_reg) ); + + xv[2].v = _mm256_loadu_pd((double *) (x0 + 2*n_elem_per_reg) ); + yv[2].v = _mm256_loadu_pd((double *) (y0 + 2*n_elem_per_reg) ); + + xv[3].v = _mm256_loadu_pd((double *) (x0 + 3*n_elem_per_reg) ); + yv[3].v = _mm256_loadu_pd((double *) (y0 + 3*n_elem_per_reg) ); + + //yi0 yi0 yi1 yi1 + //xr0 xi0 xr1 xi1 + //--------------- + //yi0*xr0 yi0*xi0 yi1*xr1 yi1*xi1 + yv[4].v = _mm256_permute_pd( yv[0].v, 15 ); + yv[5].v = _mm256_permute_pd( yv[1].v, 15 ); + yv[6].v = _mm256_permute_pd( yv[2].v, 15 ); + yv[7].v = _mm256_permute_pd( yv[3].v, 15 ); + + //yr0 yr0 yr1 yr1 + //xr0 xi0 xr1 xi1 + //---------------- + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + yv[0].v = _mm256_permute_pd( yv[0].v, 0 ); + yv[1].v = _mm256_permute_pd( yv[1].v, 0 ); + yv[2].v = _mm256_permute_pd( yv[2].v, 0 ); + yv[3].v = _mm256_permute_pd( yv[3].v, 0 ); + + // Compute the element-wise product of the x and y vectors, + // storing in the corresponding rho vectors. + rhov[0].v = _mm256_fmadd_pd( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_pd( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3].v, yv[3].v, rhov[3].v ); + + rhov[4].v = _mm256_fmadd_pd( xv[0].v, yv[4].v, rhov[4].v ); + rhov[5].v = _mm256_fmadd_pd( xv[1].v, yv[5].v, rhov[5].v ); + rhov[6].v = _mm256_fmadd_pd( xv[2].v, yv[6].v, rhov[6].v ); + rhov[7].v = _mm256_fmadd_pd( xv[3].v, yv[7].v, rhov[7].v ); + + x0 += ( n_elem_per_reg * n_iter_unroll ); + y0 += ( n_elem_per_reg * n_iter_unroll ); + } + } + + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + // - + - + + //yi0*xi0 yi0*xr0 yi1*xi1 yi1*xr1 + rhov[4].v = _mm256_permute_pd(rhov[4].v, 0x05); + rhov[5].v = _mm256_permute_pd(rhov[5].v, 0x05); + rhov[6].v = _mm256_permute_pd(rhov[6].v, 0x05); + rhov[7].v = _mm256_permute_pd(rhov[7].v, 0x05); + + rhov[0].v = _mm256_addsub_pd(rhov[0].v, rhov[4].v); + rhov[1].v = _mm256_addsub_pd(rhov[1].v, rhov[5].v); + rhov[2].v = _mm256_addsub_pd(rhov[2].v, rhov[6].v); + rhov[3].v = _mm256_addsub_pd(rhov[3].v, rhov[7].v); + + // Accumulate the unrolled rho vectors into a single vector. + rhov[0].v = _mm256_add_pd(rhov[1].v,rhov[0].v); + rhov[0].v = _mm256_add_pd(rhov[2].v,rhov[0].v); + rhov[0].v = _mm256_add_pd(rhov[3].v,rhov[0].v); + + v2df_t inter1, inter2; + + inter1.v = _mm256_extractf128_pd(rhov[0].v,1); + inter2.v = _mm256_extractf128_pd(rhov[0].v,0); + + inter1.v = _mm_add_pd(inter1.v, inter2.v); + + // Accumulate the final rho vector into a single scalar result. + rho0.real = inter1.d[0]; + rho0.imag = inter1.d[1]; + + /* Negate sign of imaginary value when vector y is conjugate */ + if ( bli_is_conj(conjx_use)) + rho0.imag = -rho0.imag; + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + // If there are leftover iterations, perform them with scalar code. 
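+	// (Scalar sketch, assuming the usual BLIS level-0 semantics in which
+	// dots accumulates x*y and dotjs accumulates conj(x)*y into rho0.)
+	// For the non-conjugated case each leftover step is effectively:
+	//   rho0.real += x0->real * y0->real - x0->imag * y0->imag;
+	//   rho0.imag += x0->real * y0->imag + x0->imag * y0->real;
+	// and the conjugated case flips the sign of the terms involving x0->imag.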
+ if ( bli_is_conj( conjx_use ) ) + { + for ( i = 0; i < n_left; ++i ) + { + PASTEMAC(z,dotjs)( *x0, *y0, rho0 ); + x0 += incx; + y0 += incy; + } + } + else + { + for ( i = 0; i < n_left; ++i ) + { + PASTEMAC(z,dots)( *x0, *y0, rho0 ); + x0 += incx; + y0 += incy; + } + } + + if ( bli_is_conj( conjy ) ) + PASTEMAC(z,conjs)( rho0 ); + + // Accumulate the final result into the output variable. + PASTEMAC(z,axpys)( *alpha, rho0, *rho ); +} + +void bli_cdotxv_zen_int + ( + conj_t conjx, + conj_t conjy, + dim_t n, + scomplex* restrict alpha, + scomplex* restrict x, inc_t incx, + scomplex* restrict y, inc_t incy, + scomplex* restrict beta, + scomplex* restrict rho, + cntx_t* restrict cntx + ) +{ + const dim_t n_elem_per_reg = 4; + const dim_t n_iter_unroll = 4; + + dim_t i; + dim_t n_viter; + dim_t n_left; + + scomplex* restrict x0; + scomplex* restrict y0; + scomplex rho0; + + v8sf_t rhov[8], xv[4], yv[8]; + + conj_t conjx_use = conjx; + if ( bli_is_conj( conjy ) ) + { + bli_toggle_conj( &conjx_use ); + } + // If beta is zero, initialize rho1 to zero instead of scaling + // rho by beta (in case rho contains NaN or Inf). + if ( PASTEMAC(c,eq0)( *beta ) ) + { + PASTEMAC(c,set0s)( *rho ); + } + else + { + PASTEMAC(c,scals)( *beta, *rho ); + } + + // If the vector dimension is zero, output rho and return early. + if ( bli_zero_dim1( n ) || PASTEMAC(c,eq0)( *alpha ) ) return; + + // Use the unrolling factor and the number of elements per register + // to compute the number of vectorized and leftover iterations. + n_viter = ( n ) / ( n_elem_per_reg * n_iter_unroll ); + n_left = ( n ) % ( n_elem_per_reg * n_iter_unroll ); + + // If there is anything that would interfere with our use of contiguous + // vector loads/stores, override n_viter and n_left to use scalar code + // for all iterations. + if ( incx != 1 || incy != 1 ) + { + n_viter = 0; + n_left = n; + } + + // Initialize local pointers. + x0 = x; + y0 = y; + + // Initialize the unrolled iterations' rho vectors to zero. + rhov[0].v = _mm256_setzero_ps(); + rhov[1].v = _mm256_setzero_ps(); + rhov[2].v = _mm256_setzero_ps(); + rhov[3].v = _mm256_setzero_ps(); + + rhov[4].v = _mm256_setzero_ps(); + rhov[5].v = _mm256_setzero_ps(); + rhov[6].v = _mm256_setzero_ps(); + rhov[7].v = _mm256_setzero_ps(); + + if ( bli_is_conj( conjx_use ) ) + { + __m256 conju = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + for ( i = 0; i < n_viter; ++i ) + { + // Load the x and y input vector elements. 
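+			// (Note.) After the loads below, y is multiplied by
+			// conju = ( 1, -1, 1, -1, ... ), which negates the imaginary
+			// lanes of y; the register then holds conj(y) and the rest of
+			// the arithmetic matches the non-conjugated path.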
+ xv[0].v = _mm256_loadu_ps((float *) (x0 + 0*n_elem_per_reg) ); + yv[0].v = _mm256_loadu_ps((float *) (y0 + 0*n_elem_per_reg) ); + + xv[1].v = _mm256_loadu_ps((float *) (x0 + 1*n_elem_per_reg) ); + yv[1].v = _mm256_loadu_ps((float *) (y0 + 1*n_elem_per_reg) ); + + xv[2].v = _mm256_loadu_ps((float *) (x0 + 2*n_elem_per_reg) ); + yv[2].v = _mm256_loadu_ps((float *) (y0 + 2*n_elem_per_reg) ); + + xv[3].v = _mm256_loadu_ps((float *) (x0 + 3*n_elem_per_reg) ); + yv[3].v = _mm256_loadu_ps((float *) (y0 + 3*n_elem_per_reg) ); + + yv[0].v = _mm256_mul_ps(yv[0].v, conju); + yv[1].v = _mm256_mul_ps(yv[1].v, conju); + yv[2].v = _mm256_mul_ps(yv[2].v, conju); + yv[3].v = _mm256_mul_ps(yv[3].v, conju); + //yi0 yi0 yi1 yi1 + //xr0 xi0 xr1 xi1 + //after permute of vector registers + //yi0*xr0 yi0*xi0 yi1*xr1 yi1*xi1 + yv[4].v = _mm256_permute_ps( yv[0].v, 0xf5 ); + yv[5].v = _mm256_permute_ps( yv[1].v, 0xf5 ); + yv[6].v = _mm256_permute_ps( yv[2].v, 0xf5 ); + yv[7].v = _mm256_permute_ps( yv[3].v, 0xf5 ); + + //yr0 yr0 yr1 yr1 + //xr0 xi0 xr1 xi1 + //after permute of vector registers + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + yv[0].v = _mm256_permute_ps( yv[0].v, 0xa0 ); + yv[1].v = _mm256_permute_ps( yv[1].v, 0xa0 ); + yv[2].v = _mm256_permute_ps( yv[2].v, 0xa0 ); + yv[3].v = _mm256_permute_ps( yv[3].v, 0xa0 ); + + // Compute the element-wise product of the x and y vectors, + // storing in the corresponding rho vectors. + rhov[0].v = _mm256_fmadd_ps( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_ps( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3].v, yv[3].v, rhov[3].v ); + + rhov[4].v = _mm256_fmadd_ps( xv[0].v, yv[4].v, rhov[4].v ); + rhov[5].v = _mm256_fmadd_ps( xv[1].v, yv[5].v, rhov[5].v ); + rhov[6].v = _mm256_fmadd_ps( xv[2].v, yv[6].v, rhov[6].v ); + rhov[7].v = _mm256_fmadd_ps( xv[3].v, yv[7].v, rhov[7].v ); + + x0 += ( n_elem_per_reg * n_iter_unroll ); + y0 += ( n_elem_per_reg * n_iter_unroll ); + } + } + else + { + for ( i = 0; i < n_viter; ++i ) + { + // Load the x and y input vector elements. + xv[0].v = _mm256_loadu_ps((float *) (x0 + 0*n_elem_per_reg) ); + yv[0].v = _mm256_loadu_ps((float *) (y0 + 0*n_elem_per_reg) ); + + xv[1].v = _mm256_loadu_ps((float *) (x0 + 1*n_elem_per_reg) ); + yv[1].v = _mm256_loadu_ps((float *) (y0 + 1*n_elem_per_reg) ); + + xv[2].v = _mm256_loadu_ps((float *) (x0 + 2*n_elem_per_reg) ); + yv[2].v = _mm256_loadu_ps((float *) (y0 + 2*n_elem_per_reg) ); + + xv[3].v = _mm256_loadu_ps((float *) (x0 + 3*n_elem_per_reg) ); + yv[3].v = _mm256_loadu_ps((float *) (y0 + 3*n_elem_per_reg) ); + + //yi0 yi0 yi1 yi1 + //xr0 xi0 xr1 xi1 + //--------------- + //yi0*xr0 yi0*xi0 yi1*xr1 yi1*xi1 + yv[4].v = _mm256_permute_ps( yv[0].v, 0xf5 ); + yv[5].v = _mm256_permute_ps( yv[1].v, 0xf5 ); + yv[6].v = _mm256_permute_ps( yv[2].v, 0xf5 ); + yv[7].v = _mm256_permute_ps( yv[3].v, 0xf5 ); + + //yr0 yr0 yr1 yr1 + //xr0 xi0 xr1 xi1 + //---------------- + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + yv[0].v = _mm256_permute_ps( yv[0].v, 0xa0 ); + yv[1].v = _mm256_permute_ps( yv[1].v, 0xa0 ); + yv[2].v = _mm256_permute_ps( yv[2].v, 0xa0 ); + yv[3].v = _mm256_permute_ps( yv[3].v, 0xa0 ); + + // Compute the element-wise product of the x and y vectors, + // storing in the corresponding rho vectors. 
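+			// (Note.) At this point yv[0..3] hold the duplicated real
+			// parts of y and yv[4..7] the duplicated imaginary parts, so
+			// the FMAs below build the yR*x partial sums in rhov[0..3]
+			// and the yI*x partial sums in rhov[4..7].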
+ rhov[0].v = _mm256_fmadd_ps( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_ps( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3].v, yv[3].v, rhov[3].v ); + + rhov[4].v = _mm256_fmadd_ps( xv[0].v, yv[4].v, rhov[4].v ); + rhov[5].v = _mm256_fmadd_ps( xv[1].v, yv[5].v, rhov[5].v ); + rhov[6].v = _mm256_fmadd_ps( xv[2].v, yv[6].v, rhov[6].v ); + rhov[7].v = _mm256_fmadd_ps( xv[3].v, yv[7].v, rhov[7].v ); + + x0 += ( n_elem_per_reg * n_iter_unroll ); + y0 += ( n_elem_per_reg * n_iter_unroll ); + } + } + + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + // - + - + + //yi0*xi0 yi0*xr0 yi1*xi1 yi1*xr1 + rhov[4].v = _mm256_permute_ps(rhov[4].v, 0xb1); + rhov[5].v = _mm256_permute_ps(rhov[5].v, 0xb1); + rhov[6].v = _mm256_permute_ps(rhov[6].v, 0xb1); + rhov[7].v = _mm256_permute_ps(rhov[7].v, 0xb1); + + rhov[0].v = _mm256_addsub_ps(rhov[0].v, rhov[4].v); + rhov[1].v = _mm256_addsub_ps(rhov[1].v, rhov[5].v); + rhov[2].v = _mm256_addsub_ps(rhov[2].v, rhov[6].v); + rhov[3].v = _mm256_addsub_ps(rhov[3].v, rhov[7].v); + + // Accumulate the unrolled rho vectors into a single vector. + rhov[0].v = _mm256_add_ps(rhov[1].v,rhov[0].v); + rhov[0].v = _mm256_add_ps(rhov[2].v,rhov[0].v); + rhov[0].v = _mm256_add_ps(rhov[3].v,rhov[0].v); + + v4sf_t inter1, inter2; + + inter1.v = _mm256_extractf128_ps(rhov[0].v,1); + inter2.v = _mm256_extractf128_ps(rhov[0].v,0); + + inter1.v = _mm_add_ps(inter1.v, inter2.v); + + // Accumulate the final rho vector into a single scalar result. + rho0.real = inter1.f[0] + inter1.f[2]; + rho0.imag = inter1.f[1] + inter1.f[3]; + + /* Negate sign of imaginary value when vector y is conjugate */ + if ( bli_is_conj(conjx_use)) + rho0.imag = -rho0.imag; + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + // If there are leftover iterations, perform them with scalar code. + if ( bli_is_conj( conjx_use ) ) + { + for ( i = 0; i < n_left; ++i ) + { + PASTEMAC(c,dotjs)( *x0, *y0, rho0 ); + x0 += incx; + y0 += incy; + } + } + else + { + for ( i = 0; i < n_left; ++i ) + { + PASTEMAC(c,dots)( *x0, *y0, rho0 ); + x0 += incx; + y0 += incy; + } + } + + if ( bli_is_conj( conjy ) ) + PASTEMAC(c,conjs)( rho0 ); + + // Accumulate the final result into the output variable. + PASTEMAC(c,axpys)( *alpha, rho0, *rho ); +} diff --git a/kernels/zen/1/bli_scalv_zen_int10.c b/kernels/zen/1/bli_scalv_zen_int10.c index 6c7f52e161..7146e86879 100644 --- a/kernels/zen/1/bli_scalv_zen_int10.c +++ b/kernels/zen/1/bli_scalv_zen_int10.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017 - 2022, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -36,23 +36,6 @@ #include "immintrin.h" #include "blis.h" - -/* Union data structure to access AVX registers - One 256-bit AVX register holds 8 SP elements. */ -typedef union -{ - __m256 v; - float f[8] __attribute__((aligned(64))); -} v8sf_t; - -/* Union data structure to access AVX registers -* One 256-bit AVX register holds 4 DP elements. 
*/ -typedef union -{ - __m256d v; - double d[4] __attribute__((aligned(64))); -} v4df_t; - // ----------------------------------------------------------------------------- void bli_sscalv_zen_int10 @@ -66,13 +49,13 @@ void bli_sscalv_zen_int10 { const dim_t n_elem_per_reg = 8; - dim_t i; + dim_t i = 0; float* restrict x0; __m256 alphav; - __m256 xv[10]; - __m256 zv[10]; + __m256 xv[16]; + __m256 zv[16]; // If the vector dimension is zero, or if alpha is unit, return early. if ( bli_zero_dim1( n ) || PASTEMAC(s,eq1)( *alpha ) ) return; @@ -81,16 +64,7 @@ void bli_sscalv_zen_int10 if ( PASTEMAC(s,eq0)( *alpha ) ) { float* zero = bli_s0; -#ifdef BLIS_CONFIG_EPYC - bli_ssetv_zen_int - ( - BLIS_NO_CONJUGATE, - n, - zero, - x, incx, - cntx - ); -#else + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); ssetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_SETV_KER, cntx ); f ( @@ -100,7 +74,7 @@ void bli_sscalv_zen_int10 x, incx, cntx ); -#endif + return; } @@ -111,140 +85,218 @@ void bli_sscalv_zen_int10 { // Broadcast the alpha scalar to all elements of a vector register. alphav = _mm256_broadcast_ss( alpha ); + dim_t option; - for ( i = 0; (i + 79) < n; i += 80 ) + // Unroll and the loop used is picked based on the input size. + if( n < 300) { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); - xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); - xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); - xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); - xv[8] = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); - xv[9] = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_ps( alphav, xv[0] ); - zv[1] = _mm256_mul_ps( alphav, xv[1] ); - zv[2] = _mm256_mul_ps( alphav, xv[2] ); - zv[3] = _mm256_mul_ps( alphav, xv[3] ); - zv[4] = _mm256_mul_ps( alphav, xv[4] ); - zv[5] = _mm256_mul_ps( alphav, xv[5] ); - zv[6] = _mm256_mul_ps( alphav, xv[6] ); - zv[7] = _mm256_mul_ps( alphav, xv[7] ); - zv[8] = _mm256_mul_ps( alphav, xv[8] ); - zv[9] = _mm256_mul_ps( alphav, xv[9] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), zv[4] ); - _mm256_storeu_ps( (x0 + 5*n_elem_per_reg), zv[5] ); - _mm256_storeu_ps( (x0 + 6*n_elem_per_reg), zv[6] ); - _mm256_storeu_ps( (x0 + 7*n_elem_per_reg), zv[7] ); - _mm256_storeu_ps( (x0 + 8*n_elem_per_reg), zv[8] ); - _mm256_storeu_ps( (x0 + 9*n_elem_per_reg), zv[9] ); - - x0 += 10*n_elem_per_reg; + option = 2; } - - for ( ; (i + 39) < n; i += 40 ) + else if( n < 500) { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_ps( alphav, xv[0] ); - zv[1] = _mm256_mul_ps( alphav, xv[1] ); - zv[2] = _mm256_mul_ps( alphav, xv[2] ); - zv[3] = _mm256_mul_ps( alphav, xv[3] ); - zv[4] = _mm256_mul_ps( alphav, xv[4] ); - - // Store the output. 
- _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), zv[4] ); - - x0 += 5*n_elem_per_reg; + option = 1; } - - for ( ; (i + 31) < n; i += 32 ) - { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_ps( alphav, xv[0] ); - zv[1] = _mm256_mul_ps( alphav, xv[1] ); - zv[2] = _mm256_mul_ps( alphav, xv[2] ); - zv[3] = _mm256_mul_ps( alphav, xv[3] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); - - x0 += 4*n_elem_per_reg; - } - - for ( ; (i + 15) < n; i += 16 ) - { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_ps( alphav, xv[0] ); - zv[1] = _mm256_mul_ps( alphav, xv[1] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); - - x0 += 2*n_elem_per_reg; - } - - for ( ; (i + 7) < n; i += 8 ) + else { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_ps( alphav, xv[0] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); - - x0 += 1*n_elem_per_reg; + option = 0; } - for ( ; (i + 0) < n; i += 1 ) + switch(option) { - *x0 *= *alpha; - - x0 += 1; + case 0: + + for ( ; (i + 127) < n; i += 128 ) + { + //Load the input values + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + + // Perform : x := alpha * x; + zv[0] = _mm256_mul_ps( alphav, xv[0] ); + zv[1] = _mm256_mul_ps( alphav, xv[1] ); + zv[2] = _mm256_mul_ps( alphav, xv[2] ); + zv[3] = _mm256_mul_ps( alphav, xv[3] ); + + // Store the result + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); + + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); + + zv[4] = _mm256_mul_ps( alphav, xv[4] ); + zv[5] = _mm256_mul_ps( alphav, xv[5] ); + zv[6] = _mm256_mul_ps( alphav, xv[6] ); + zv[7] = _mm256_mul_ps( alphav, xv[7] ); + + _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_ps( (x0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_ps( (x0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_ps( (x0 + 7*n_elem_per_reg), zv[7] ); + + xv[8] = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); + xv[10] = _mm256_loadu_ps( x0 + 10*n_elem_per_reg ); + xv[11] = _mm256_loadu_ps( x0 + 11*n_elem_per_reg ); + + zv[8] = _mm256_mul_ps( alphav, xv[8] ); + zv[9] = _mm256_mul_ps( alphav, xv[9] ); + zv[10] 
= _mm256_mul_ps( alphav, xv[10] ); + zv[11] = _mm256_mul_ps( alphav, xv[11] ); + + _mm256_storeu_ps( (x0 + 8*n_elem_per_reg), zv[8] ); + _mm256_storeu_ps( (x0 + 9*n_elem_per_reg), zv[9] ); + _mm256_storeu_ps( (x0 + 10*n_elem_per_reg), zv[10] ); + _mm256_storeu_ps( (x0 + 11*n_elem_per_reg), zv[11] ); + + xv[12] = _mm256_loadu_ps( x0 + 12*n_elem_per_reg ); + xv[13] = _mm256_loadu_ps( x0 + 13*n_elem_per_reg ); + xv[14] = _mm256_loadu_ps( x0 + 14*n_elem_per_reg ); + xv[15] = _mm256_loadu_ps( x0 + 15*n_elem_per_reg ); + + zv[12] = _mm256_mul_ps( alphav, xv[12] ); + zv[13] = _mm256_mul_ps( alphav, xv[13] ); + zv[14] = _mm256_mul_ps( alphav, xv[14] ); + zv[15] = _mm256_mul_ps( alphav, xv[15] ); + + _mm256_storeu_ps( (x0 + 12*n_elem_per_reg), zv[12] ); + _mm256_storeu_ps( (x0 + 13*n_elem_per_reg), zv[13] ); + _mm256_storeu_ps( (x0 + 14*n_elem_per_reg), zv[14] ); + _mm256_storeu_ps( (x0 + 15*n_elem_per_reg), zv[15] ); + + x0 += 16*n_elem_per_reg; + } + + case 1 : + + for ( ; (i + 95) < n; i += 96 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + + zv[0] = _mm256_mul_ps( alphav, xv[0] ); + zv[1] = _mm256_mul_ps( alphav, xv[1] ); + zv[2] = _mm256_mul_ps( alphav, xv[2] ); + zv[3] = _mm256_mul_ps( alphav, xv[3] ); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); + + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); + + zv[4] = _mm256_mul_ps( alphav, xv[4] ); + zv[5] = _mm256_mul_ps( alphav, xv[5] ); + zv[6] = _mm256_mul_ps( alphav, xv[6] ); + zv[7] = _mm256_mul_ps( alphav, xv[7] ); + + _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_ps( (x0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_ps( (x0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_ps( (x0 + 7*n_elem_per_reg), zv[7] ); + + xv[8] = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); + xv[10] = _mm256_loadu_ps( x0 + 10*n_elem_per_reg ); + xv[11] = _mm256_loadu_ps( x0 + 11*n_elem_per_reg ); + + zv[8] = _mm256_mul_ps( alphav, xv[8] ); + zv[9] = _mm256_mul_ps( alphav, xv[9] ); + zv[10] = _mm256_mul_ps( alphav, xv[10] ); + zv[11] = _mm256_mul_ps( alphav, xv[11] ); + + _mm256_storeu_ps( (x0 + 8*n_elem_per_reg), zv[8] ); + _mm256_storeu_ps( (x0 + 9*n_elem_per_reg), zv[9] ); + _mm256_storeu_ps( (x0 + 10*n_elem_per_reg), zv[10] ); + _mm256_storeu_ps( (x0 + 11*n_elem_per_reg), zv[11] ); + + x0 += 12*n_elem_per_reg; + } + + case 2: + + for ( ; (i + 47) < n; i += 48 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + + zv[0] = _mm256_mul_ps( alphav, xv[0] ); + zv[1] = _mm256_mul_ps( alphav, xv[1] ); + zv[2] = _mm256_mul_ps( alphav, xv[2] ); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); + + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); + + zv[3] = _mm256_mul_ps( alphav, xv[3] ); + 
zv[4] = _mm256_mul_ps( alphav, xv[4] ); + zv[5] = _mm256_mul_ps( alphav, xv[5] ); + + _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); + _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_ps( (x0 + 5*n_elem_per_reg), zv[5] ); + + x0 += 6*n_elem_per_reg; + } + + for ( ; (i + 23) < n; i += 24 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + + zv[0] = _mm256_mul_ps( alphav, xv[0] ); + zv[1] = _mm256_mul_ps( alphav, xv[1] ); + zv[2] = _mm256_mul_ps( alphav, xv[2] ); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); + + x0 += 3*n_elem_per_reg; + } + + for ( ; (i + 7) < n; i += 8 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + + zv[0] = _mm256_mul_ps( alphav, xv[0] ); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); + + x0 += 1*n_elem_per_reg; + } + + for ( ; (i + 0) < n; i += 1 ) + { + *x0 *= *alpha; + + x0 += 1; + } } } else { const float alphac = *alpha; - for ( i = 0; i < n; ++i ) + for ( ; i < n; ++i ) { *x0 *= alphac; @@ -266,13 +318,13 @@ void bli_dscalv_zen_int10 { const dim_t n_elem_per_reg = 4; - dim_t i; + dim_t i = 0; double* restrict x0; __m256d alphav; - __m256d xv[10]; - __m256d zv[10]; + __m256d xv[16]; + __m256d zv[16]; // If the vector dimension is zero, or if alpha is unit, return early. if ( bli_zero_dim1( n ) || PASTEMAC(d,eq1)( *alpha ) ) return; @@ -281,16 +333,7 @@ void bli_dscalv_zen_int10 if ( PASTEMAC(d,eq0)( *alpha ) ) { double* zero = bli_d0; -#ifdef BLIS_CONFIG_EPYC - bli_dsetv_zen_int - ( - BLIS_NO_CONJUGATE, - n, - zero, - x, incx, - cntx - ); -#else + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); dsetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_SETV_KER, cntx ); f @@ -301,7 +344,7 @@ void bli_dscalv_zen_int10 x, incx, cntx ); -#endif + return; } @@ -312,140 +355,221 @@ void bli_dscalv_zen_int10 { // Broadcast the alpha scalar to all elements of a vector register. alphav = _mm256_broadcast_sd( alpha ); + dim_t option; - for ( i = 0; (i + 39) < n; i += 40 ) + // Unroll and the loop used is picked based on the input size. + if(n < 200) { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); - xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); - xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); - xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); - xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_pd( alphav, xv[0] ); - zv[1] = _mm256_mul_pd( alphav, xv[1] ); - zv[2] = _mm256_mul_pd( alphav, xv[2] ); - zv[3] = _mm256_mul_pd( alphav, xv[3] ); - zv[4] = _mm256_mul_pd( alphav, xv[4] ); - zv[5] = _mm256_mul_pd( alphav, xv[5] ); - zv[6] = _mm256_mul_pd( alphav, xv[6] ); - zv[7] = _mm256_mul_pd( alphav, xv[7] ); - zv[8] = _mm256_mul_pd( alphav, xv[8] ); - zv[9] = _mm256_mul_pd( alphav, xv[9] ); - - // Store the output. 
- _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), zv[4] ); - _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), zv[5] ); - _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), zv[6] ); - _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), zv[7] ); - _mm256_storeu_pd( (x0 + 8*n_elem_per_reg), zv[8] ); - _mm256_storeu_pd( (x0 + 9*n_elem_per_reg), zv[9] ); - - x0 += 10*n_elem_per_reg; + option = 2; } - - for ( ; (i + 19) < n; i += 20 ) + else if(n < 500) { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_pd( alphav, xv[0] ); - zv[1] = _mm256_mul_pd( alphav, xv[1] ); - zv[2] = _mm256_mul_pd( alphav, xv[2] ); - zv[3] = _mm256_mul_pd( alphav, xv[3] ); - zv[4] = _mm256_mul_pd( alphav, xv[4] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), zv[4] ); - - x0 += 5*n_elem_per_reg; + option = 1; } - - for ( ; (i + 15) < n; i += 16 ) + else { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_pd( alphav, xv[0] ); - zv[1] = _mm256_mul_pd( alphav, xv[1] ); - zv[2] = _mm256_mul_pd( alphav, xv[2] ); - zv[3] = _mm256_mul_pd( alphav, xv[3] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); - - x0 += 4*n_elem_per_reg; + option = 0; } - for ( ; (i + 7) < n; i += 8 ) + switch(option) { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_pd( alphav, xv[0] ); - zv[1] = _mm256_mul_pd( alphav, xv[1] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); - - x0 += 2*n_elem_per_reg; - } - - for ( ; (i + 3) < n; i += 4 ) - { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_pd( alphav, xv[0] ); - - // Store the output. 
- _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); - - x0 += 1*n_elem_per_reg; - } - - for ( ; (i + 0) < n; i += 1 ) - { - *x0 *= *alpha; - - x0 += 1; + case 0: + + for (; (i + 63) < n; i += 64 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + zv[0] = _mm256_mul_pd( alphav, xv[0] ); + zv[1] = _mm256_mul_pd( alphav, xv[1] ); + zv[2] = _mm256_mul_pd( alphav, xv[2] ); + zv[3] = _mm256_mul_pd( alphav, xv[3] ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); + + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); + + zv[4] = _mm256_mul_pd( alphav, xv[4] ); + zv[5] = _mm256_mul_pd( alphav, xv[5] ); + zv[6] = _mm256_mul_pd( alphav, xv[6] ); + zv[7] = _mm256_mul_pd( alphav, xv[7] ); + + _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), zv[7] ); + + xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); + xv[10] = _mm256_loadu_pd( x0 + 10*n_elem_per_reg ); + xv[11] = _mm256_loadu_pd( x0 + 11*n_elem_per_reg ); + + zv[8] = _mm256_mul_pd( alphav, xv[8] ); + zv[9] = _mm256_mul_pd( alphav, xv[9] ); + zv[10] = _mm256_mul_pd( alphav, xv[10] ); + zv[11] = _mm256_mul_pd( alphav, xv[11] ); + + _mm256_storeu_pd( (x0 + 8*n_elem_per_reg), zv[8] ); + _mm256_storeu_pd( (x0 + 9*n_elem_per_reg), zv[9] ); + _mm256_storeu_pd( (x0 + 10*n_elem_per_reg), zv[10] ); + _mm256_storeu_pd( (x0 + 11*n_elem_per_reg), zv[11] ); + + xv[12] = _mm256_loadu_pd( x0 + 12*n_elem_per_reg ); + xv[13] = _mm256_loadu_pd( x0 + 13*n_elem_per_reg ); + xv[14] = _mm256_loadu_pd( x0 + 14*n_elem_per_reg ); + xv[15] = _mm256_loadu_pd( x0 + 15*n_elem_per_reg ); + + zv[12] = _mm256_mul_pd( alphav, xv[12] ); + zv[13] = _mm256_mul_pd( alphav, xv[13] ); + zv[14] = _mm256_mul_pd( alphav, xv[14] ); + zv[15] = _mm256_mul_pd( alphav, xv[15] ); + + _mm256_storeu_pd( (x0 + 12*n_elem_per_reg), zv[12] ); + _mm256_storeu_pd( (x0 + 13*n_elem_per_reg), zv[13] ); + _mm256_storeu_pd( (x0 + 14*n_elem_per_reg), zv[14] ); + _mm256_storeu_pd( (x0 + 15*n_elem_per_reg), zv[15] ); + + x0 += 16*n_elem_per_reg; + } + + for (; (i + 47) < n; i += 48 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + zv[0] = _mm256_mul_pd( alphav, xv[0] ); + zv[1] = _mm256_mul_pd( alphav, xv[1] ); + zv[2] = _mm256_mul_pd( alphav, xv[2] ); + zv[3] = _mm256_mul_pd( alphav, xv[3] ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); + + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); + + zv[4] = _mm256_mul_pd( alphav, xv[4] ); + 
zv[5] = _mm256_mul_pd( alphav, xv[5] ); + zv[6] = _mm256_mul_pd( alphav, xv[6] ); + zv[7] = _mm256_mul_pd( alphav, xv[7] ); + + _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), zv[7] ); + + xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); + xv[10] = _mm256_loadu_pd( x0 + 10*n_elem_per_reg ); + xv[11] = _mm256_loadu_pd( x0 + 11*n_elem_per_reg ); + + zv[8] = _mm256_mul_pd( alphav, xv[8] ); + zv[9] = _mm256_mul_pd( alphav, xv[9] ); + zv[10] = _mm256_mul_pd( alphav, xv[10] ); + zv[11] = _mm256_mul_pd( alphav, xv[11] ); + + _mm256_storeu_pd( (x0 + 8*n_elem_per_reg), zv[8] ); + _mm256_storeu_pd( (x0 + 9*n_elem_per_reg), zv[9] ); + _mm256_storeu_pd( (x0 + 10*n_elem_per_reg), zv[10] ); + _mm256_storeu_pd( (x0 + 11*n_elem_per_reg), zv[11] ); + + x0 += 12*n_elem_per_reg; + } + + case 1: + + for (; (i + 31) < n; i += 32 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + zv[0] = _mm256_mul_pd( alphav, xv[0] ); + zv[1] = _mm256_mul_pd( alphav, xv[1] ); + zv[2] = _mm256_mul_pd( alphav, xv[2] ); + zv[3] = _mm256_mul_pd( alphav, xv[3] ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); + + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); + + zv[4] = _mm256_mul_pd( alphav, xv[4] ); + zv[5] = _mm256_mul_pd( alphav, xv[5] ); + zv[6] = _mm256_mul_pd( alphav, xv[6] ); + zv[7] = _mm256_mul_pd( alphav, xv[7] ); + + _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), zv[7] ); + + x0 += 8*n_elem_per_reg; + } + + case 2: + + for ( ; (i + 11) < n; i += 12 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + + zv[0] = _mm256_mul_pd( alphav, xv[0] ); + zv[1] = _mm256_mul_pd( alphav, xv[1] ); + zv[2] = _mm256_mul_pd( alphav, xv[2] ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); + + x0 += 3*n_elem_per_reg; + } + + for ( ; (i + 3) < n; i += 4 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + zv[0] = _mm256_mul_pd( alphav, xv[0] ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); + + x0 += 1*n_elem_per_reg; + } + + for ( ; (i + 0) < n; i += 1 ) + { + *x0 *= *alpha; + + x0 += 1; + } } } else { const double alphac = *alpha; - for ( i = 0; i < n; ++i ) + for ( ; i < n; ++i ) { *x0 *= alphac; diff --git a/kernels/zen/1f/CMakeLists.txt b/kernels/zen/1f/CMakeLists.txt index d2bf13822d..3a77f69ef1 100644 --- a/kernels/zen/1f/CMakeLists.txt +++ b/kernels/zen/1f/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2020-2022, Advanced Micro Devices, Inc. 
All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -7,4 +7,6 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_5.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_4.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_6.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpy2v_zen_int.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotxaxpyf_zen_int_8.c ) diff --git a/kernels/zen/1f/bli_axpy2v_zen_int.c b/kernels/zen/1f/bli_axpy2v_zen_int.c new file mode 100644 index 0000000000..cba0141376 --- /dev/null +++ b/kernels/zen/1f/bli_axpy2v_zen_int.c @@ -0,0 +1,721 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2018, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" +#include "immintrin.h" + + +/** + * daxpy2v kernel performs axpy2v operation. + * z := y + alphax * conjx(x) + alphay * conjy(y) + * where x, y, and z are vectors of length n. 
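+ * For double-precision real data the conjugations are no-ops, so
+ * elementwise this amounts to (illustrative):
+ *   z[i] += (*alphax) * x[i] + (*alphay) * y[i],  for i = 0, ..., n-1.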
+ */ +void bli_daxpy2v_zen_int + ( + conj_t conjx, + conj_t conjy, + dim_t n, + double* restrict alphax, + double* restrict alphay, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + double* restrict z, inc_t incz, + cntx_t* restrict cntx + ) +{ + if ( bli_zero_dim1( n ) ) return; + + if ( incz == 1 && incx == 1 && incy == 1 ) + { + dim_t i = 0; + dim_t rem = n%4; + const dim_t n_elem_per_reg = 4; + __m256d xv[4], yv[4], zv[4]; + __m256d alphaxv, alphayv; + + alphaxv = _mm256_broadcast_sd((double const*) alphax); + alphayv = _mm256_broadcast_sd((double const*) alphay); + + for( ; (i + 15) < n; i+= 16 ) + { + xv[0] = _mm256_loadu_pd( x + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x + 3*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y + 3*n_elem_per_reg ); + + zv[0] = _mm256_loadu_pd( z + 0*n_elem_per_reg ); + zv[1] = _mm256_loadu_pd( z + 1*n_elem_per_reg ); + zv[2] = _mm256_loadu_pd( z + 2*n_elem_per_reg ); + zv[3] = _mm256_loadu_pd( z + 3*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_pd(xv[0], alphaxv, zv[0]); + zv[1] = _mm256_fmadd_pd(xv[1], alphaxv, zv[1]); + zv[2] = _mm256_fmadd_pd(xv[2], alphaxv, zv[2]); + zv[3] = _mm256_fmadd_pd(xv[3], alphaxv, zv[3]); + + zv[0] = _mm256_fmadd_pd(yv[0], alphayv, zv[0]); + zv[1] = _mm256_fmadd_pd(yv[1], alphayv, zv[1]); + zv[2] = _mm256_fmadd_pd(yv[2], alphayv, zv[2]); + zv[3] = _mm256_fmadd_pd(yv[3], alphayv, zv[3]); + + _mm256_storeu_pd((z + 0*n_elem_per_reg), zv[0]); + _mm256_storeu_pd((z + 1*n_elem_per_reg), zv[1]); + _mm256_storeu_pd((z + 2*n_elem_per_reg), zv[2]); + _mm256_storeu_pd((z + 3*n_elem_per_reg), zv[3]); + + z += 4*n_elem_per_reg; + x += 4*n_elem_per_reg; + y += 4*n_elem_per_reg; + } + + for( ; (i + 7) < n; i+= 8 ) + { + xv[0] = _mm256_loadu_pd( x + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x + 1*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y + 1*n_elem_per_reg ); + + zv[0] = _mm256_loadu_pd( z + 0*n_elem_per_reg ); + zv[1] = _mm256_loadu_pd( z + 1*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_pd(xv[0], alphaxv, zv[0]); + zv[1] = _mm256_fmadd_pd(xv[1], alphaxv, zv[1]); + + zv[0] = _mm256_fmadd_pd(yv[0], alphayv, zv[0]); + zv[1] = _mm256_fmadd_pd(yv[1], alphayv, zv[1]); + + _mm256_storeu_pd((z + 0*n_elem_per_reg), zv[0]); + _mm256_storeu_pd((z + 1*n_elem_per_reg), zv[1]); + + z += 2*n_elem_per_reg; + x += 2*n_elem_per_reg; + y += 2*n_elem_per_reg; + } + + for( ; (i + 3) < n; i+= 4 ) + { + xv[0] = _mm256_loadu_pd( x + 0*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y + 0*n_elem_per_reg ); + + zv[0] = _mm256_loadu_pd( z + 0*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_pd(xv[0], alphaxv, zv[0]); + + zv[0] = _mm256_fmadd_pd(yv[0], alphayv, zv[0]); + + _mm256_storeu_pd((z + 0*n_elem_per_reg), zv[0]); + + z += n_elem_per_reg; + x += n_elem_per_reg; + y += n_elem_per_reg; + } + if(rem) + { + PRAGMA_SIMD + for ( i = 0; i < rem; ++i ) + { + PASTEMAC(d,axpys)( *alphax, x[i], z[i] ); + PASTEMAC(d,axpys)( *alphay, y[i], z[i] ); + } + } + } + else + { + /* Query the context for the kernel function pointer. 
*/ + const num_t dt = PASTEMAC(d,type); + PASTECH(d,axpyv_ker_ft) kfp_av + = + bli_cntx_get_l1v_ker_dt( dt, BLIS_AXPYV_KER, cntx ); + + kfp_av + ( + conjx, + n, + alphax, + x, incx, + z, incz, + cntx + ); + + kfp_av + ( + conjy, + n, + alphay, + y, incy, + z, incz, + cntx + ); + } +} + +/** + * zaxpy2v kernel performs axpy2v operation. + * z := z + alphax * conjx(x) + alphay * conjy(y) + * where, + * x, y & z are double complex vectors of length n. + * alpha & beta are complex scalers. + */ +void bli_zaxpy2v_zen_int + ( + conj_t conjx, + conj_t conjy, + dim_t n, + dcomplex* restrict alphax, + dcomplex* restrict alphay, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + dcomplex* restrict z, inc_t incz, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + + // If the vectors are empty or if both alpha are zero, return early + if ( ( bli_zero_dim1( n ) ) || + ( PASTEMAC(z,eq0)( *alphax ) && PASTEMAC(z,eq0)( *alphay ) ) ) { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + const dim_t n_elem_per_reg = 4; // Number of elements per register + + dim_t i = 0; // Iterator + + double* restrict x0; + double* restrict y0; + double* restrict z0; + double* restrict alphax0; + double* restrict alphay0; + + // Initialize local pointers. + x0 = (double*) x; + y0 = (double*) y; + z0 = (double*) z; + alphax0 = (double*) alphax; + alphay0 = (double*) alphay; + + if ( incx == 1 && incy == 1 && incz == 1 ) + { + //---------- Scalar algorithm BLIS_NO_CONJUGATE ------------- + // + // z = z + alphax * x + alphay * y + // z = ( zR + izI ) + + // ( axR + iaxI ) * ( xR + ixI ) + + // ( ayR + iayI ) * ( yR + iyI ) + // z = ( zR + izI ) + + // ( axR.xR + iaxR.xI + iaxI.xR - axI.xI ) + + // ( xyR.yR + iayR.yI + iayI.yR - ayI.yI ) + // z = ( zR + izI ) + + // ( ( axR.xR - axI.xI ) + i( axR.xI + axI.xR ) ) + + // ( ( ayR.yR - ayI.yI ) + i( ayR.yI + ayI.yR ) ) + // z = ( zR + axR.xR - axI.xI + ayR.yR - ayI.yI ) + + // i( zI + axR.xI + axI.xR + ayR.yI + ayI.yR ) + // + // SIMD Algorithm BLIS_NO_CONJUGATE + // xv = xR0 xI0 xR1 xI1 + // xv' = xI0 xR0 xI1 xR1 + // yv = yR0 yI0 yR1 yI1 + // yv' = yI0 yR0 yI1 yR1 + // zv = zR0 zI0 zR1 zI1 + // zv' = zI0 zR0 zI1 zR1 + // axrv = axR axR axR axR + // axiv = -axI axI -axI axI + // ayrv = ayR ayR ayR ayR + // ayiv = -ayI ayI -ayI ayI + // + // step 1: FMA zv = zv + axrv * xv + // step 2: shuffle xv -> xv' + // step 3: FMA zv = zv + axiv * xv' + // step 4: FMA zv = zv + ayrv * yv + // step 5: shuffle yv -> xyv' + // step 6: FMA zv = zv + ayiv * yv' + + //---------- Scalar algorithm BLIS_CONJUGATE ------------- + // + // z = z + alphax * x + alphay * y + // z = ( zR + izI ) + + // ( axR + iaxI ) * ( xR - ixI ) + + // ( ayR + iayI ) * ( yR - iyI ) + // z = ( zR + izI ) + + // ( axR.xR - iaxR.xI + iaxI.xR + axI.xI ) + + // ( xyR.yR - iayR.yI + iayI.yR + ayI.yI ) + // z = ( zR + izI ) + + // ( ( axR.xR + axI.xI ) + i( -axR.xI + axI.xR ) ) + + // ( ( ayR.yR + ayI.yI ) + i( -ayR.yI + ayI.yR ) ) + // z = ( zR + axR.xR + axI.xI + ayR.yR + ayI.yI ) + + // i( zI - axR.xI + axI.xR - ayR.yI + ayI.yR ) + // + // SIMD Algorithm BLIS_CONJUGATE + // xv = xR0 xI0 xR1 xI1 + // xv' = xI0 xR0 xI1 xR1 + // yv = yR0 yI0 yR1 yI1 + // yv' = yI0 yR0 yI1 yR1 + // zv = zR0 zI0 zR1 zI1 + // zv' = zI0 zR0 zI1 zR1 + // axrv = axR -axR axR -axR + // axiv = axI axI axI axI + // ayrv = ayR -ayR ayR -ayR + // ayiv = ayI ayI ayI ayI + // + // step 1: FMA zv = zv + axrv * xv + // step 2: shuffle xv -> xv' + // step 3: FMA zv = zv + axiv * xv' + // step 4: 
FMA zv = zv + ayrv * yv + // step 5: shuffle yv -> xyv' + // step 6: FMA zv = zv + ayiv * yv' + + __m256d alphaxRv; + __m256d alphaxIv; + __m256d alphayRv; + __m256d alphayIv; + __m256d xv[4]; + __m256d yv[4]; + __m256d zv[4]; + + double alphaxR, alphaxI; + double alphayR, alphayI; + + alphaxR = alphax->real; + alphaxI = alphax->imag; + alphayR = alphay->real; + alphayI = alphay->imag; + + // Broadcast alphax & alphay to respective vector registers + if ( !bli_is_conj( conjx ) ) // If not x conjugate + { + // alphaxRv = axR axR axR axR + // alphaxIv = -axI axI -axI axI + alphaxRv = _mm256_broadcast_sd( &alphaxR ); + alphaxIv = _mm256_set_pd( alphaxI, -alphaxI, alphaxI, -alphaxI ); + } + else + { + // alphaxRv = axR -axR axR -axR + // alphaxIv = axI axI axI axI + alphaxRv = _mm256_set_pd( -alphaxR, alphaxR, -alphaxR, alphaxR ); + alphaxIv = _mm256_broadcast_sd( &alphaxI ); + } + + if ( !bli_is_conj( conjy ) ) // If not y conjugate + { + // alphayRv = ayR ayR ayR ayR + // alphayIv = -ayI ayI -ayI ayI + alphayRv = _mm256_broadcast_sd( &alphayR ); + alphayIv = _mm256_set_pd( alphayI, -alphayI, alphayI, -alphayI ); + } + else + { + // alphayRv = ayR -ayR ayR -ayR + // alphayIv = ayI ayI ayI ayI + alphayRv = _mm256_set_pd( -alphayR, alphayR, -alphayR, alphayR ); + alphayIv = _mm256_broadcast_sd( &alphayI ); + } + + // Processing 8 elements per loop, 16 FMAs + for ( ; ( i + 7 ) < n; i += 8 ) + { + // Loading x vector + // xv = xR0 xI0 xR1 xI1 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + // Loading y vector + // yv = yR0 yI0 yR1 yI1 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + // Loading z vector + // zv = zR0 zI0 zR1 zI1 + zv[0] = _mm256_loadu_pd( z0 + 0*n_elem_per_reg ); + zv[1] = _mm256_loadu_pd( z0 + 1*n_elem_per_reg ); + zv[2] = _mm256_loadu_pd( z0 + 2*n_elem_per_reg ); + zv[3] = _mm256_loadu_pd( z0 + 3*n_elem_per_reg ); + + // zv = zv + alphaxRv * xv + // zv = zR0 + axR.xR0, zI0 + axR.xI0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxRv, zv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphaxRv, zv[1] ); + zv[2] = _mm256_fmadd_pd( xv[2], alphaxRv, zv[2] ); + zv[3] = _mm256_fmadd_pd( xv[3], alphaxRv, zv[3] ); + + // Shuffling xv + // xv = xI0 xR0 xI1 xR1 + xv[0] = _mm256_permute_pd( xv[0], 5 ); + xv[1] = _mm256_permute_pd( xv[1], 5 ); + xv[2] = _mm256_permute_pd( xv[2], 5 ); + xv[3] = _mm256_permute_pd( xv[3], 5 ); + + // zv = zv + alphaxIv * xv + // zv = zR0 + axR.xR0 - axI.xI0, zI0 + axR.xI0 + axI.xR0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxIv, zv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphaxIv, zv[1] ); + zv[2] = _mm256_fmadd_pd( xv[2], alphaxIv, zv[2] ); + zv[3] = _mm256_fmadd_pd( xv[3], alphaxIv, zv[3] ); + + // zv = zv + alphayRv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0, ... 
+ zv[0] = _mm256_fmadd_pd( yv[0], alphayRv, zv[0] ); + zv[1] = _mm256_fmadd_pd( yv[1], alphayRv, zv[1] ); + zv[2] = _mm256_fmadd_pd( yv[2], alphayRv, zv[2] ); + zv[3] = _mm256_fmadd_pd( yv[3], alphayRv, zv[3] ); + + // Shuffling yv + // yv = yI0 yR0 yI1 yR1 + yv[0] = _mm256_permute_pd( yv[0], 5 ); + yv[1] = _mm256_permute_pd( yv[1], 5 ); + yv[2] = _mm256_permute_pd( yv[2], 5 ); + yv[3] = _mm256_permute_pd( yv[3], 5 ); + + // zv = zv + alphayIv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0 - ayI.yI0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0 + ayI.yR0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayIv, zv[0] ); + zv[1] = _mm256_fmadd_pd( yv[1], alphayIv, zv[1] ); + zv[2] = _mm256_fmadd_pd( yv[2], alphayIv, zv[2] ); + zv[3] = _mm256_fmadd_pd( yv[3], alphayIv, zv[3] ); + + // Storing results from zv + _mm256_storeu_pd( (z0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (z0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (z0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_pd( (z0 + 3*n_elem_per_reg), zv[3] ); + + x0 += 4*n_elem_per_reg; + y0 += 4*n_elem_per_reg; + z0 += 4*n_elem_per_reg; + } + + // Processing 4 elements per loop, 8 FMAs + for ( ; ( i + 3 ) < n; i += 4 ) + { + // Loading x vector + // xv = xR0 xI0 xR1 xI1 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + + // Loading y vector + // yv = yR0 yI0 yR1 yI1 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + // Loading z vector + // zv = zR0 zI0 zR1 zI1 + zv[0] = _mm256_loadu_pd( z0 + 0*n_elem_per_reg ); + zv[1] = _mm256_loadu_pd( z0 + 1*n_elem_per_reg ); + + // zv = zv + alphaxRv * xv + // zv = zR0 + axR.xR0, zI0 + axR.xI0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxRv, zv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphaxRv, zv[1] ); + + // Shuffling xv + // xv = xI0 xR0 xI1 xR1 + xv[0] = _mm256_permute_pd( xv[0], 5 ); + xv[1] = _mm256_permute_pd( xv[1], 5 ); + + // zv = zv + alphaxIv * xv + // zv = zR0 + axR.xR0 - axI.xI0, zI0 + axR.xI0 + axI.xR0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxIv, zv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphaxIv, zv[1] ); + + // zv = zv + alphayRv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayRv, zv[0] ); + zv[1] = _mm256_fmadd_pd( yv[1], alphayRv, zv[1] ); + + // Shuffling yv + // yv = yI0 yR0 yI1 yR1 + yv[0] = _mm256_permute_pd( yv[0], 5 ); + yv[1] = _mm256_permute_pd( yv[1], 5 ); + + // zv = zv + alphayIv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0 - ayI.yI0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0 + ayI.yR0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayIv, zv[0] ); + zv[1] = _mm256_fmadd_pd( yv[1], alphayIv, zv[1] ); + + // Storing results from zv + _mm256_storeu_pd( (z0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (z0 + 1*n_elem_per_reg), zv[1] ); + + x0 += 2*n_elem_per_reg; + y0 += 2*n_elem_per_reg; + z0 += 2*n_elem_per_reg; + } + + // Processing 2 elements per loop, 4FMAs + for ( ; ( i + 1 ) < n; i += 2 ) + { + // Loading x vector + // xv = xR0 xI0 xR1 xI1 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + // Loading y vector + // yv = yR0 yI0 yR1 yI1 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + // Loading z vector + // zv = zR0 zI0 zR1 zI1 + zv[0] = _mm256_loadu_pd( z0 + 0*n_elem_per_reg ); + + // zv = zv + alphaxRv * xv + // zv = zR0 + axR.xR0, zI0 + axR.xI0, ... 
+ zv[0] = _mm256_fmadd_pd( xv[0], alphaxRv, zv[0] ); + + // Shuffling xv + // xv = xI0 xR0 xI1 xR1 + xv[0] = _mm256_permute_pd( xv[0], 5 ); + + // zv = zv + alphaxIv * xv + // zv = zR0 + axR.xR0 - axI.xI0, zI0 + axR.xI0 + axI.xR0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxIv, zv[0] ); + + // zv = zv + alphayRv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayRv, zv[0] ); + + // Shuffling yv + // yv = yI0 yR0 yI1 yR1 + yv[0] = _mm256_permute_pd( yv[0], 5 ); + + // zv = zv + alphayIv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0 - ayI.yI0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0 + ayI.yR0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayIv, zv[0] ); + + // Storing results from zv + _mm256_storeu_pd( (z0 + 0*n_elem_per_reg), zv[0] ); + + x0 += 1*n_elem_per_reg; + y0 += 1*n_elem_per_reg; + z0 += 1*n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + if ( !bli_is_conj( conjx ) && !bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR - axI.xI + ayR.yR - ayI.yI ) + *z0 += (*alphax0) * (*x0) - + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) - + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI + axI.xR + ayR.yI + ayI.yR ) + *(z0 + 1) += (*alphax0) * (*(x0 + 1)) + + (*(alphax0 + 1)) * (*x0) + + (*alphay0) * (*(y0 + 1)) + + (*(alphay0 + 1)) * (*y0); + + x0 += 2; + y0 += 2; + z0 += 2; + } + } + else if ( !bli_is_conj( conjx ) && bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR - axI.xI + ayR.yR + ayI.yI ) + *z0 += (*alphax0) * (*x0) - + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) + + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI + axI.xR + ayR.yI - ayI.yR ) + *(z0 + 1) += (*alphax0) * (*(x0 + 1)) + + (*(alphax0 + 1)) * (*x0) + + (*(alphay0 + 1)) * (*y0) - + (*alphay0) * (*(y0 + 1)); + + x0 += 2; + y0 += 2; + z0 += 2; + } + } + else if ( bli_is_conj( conjx ) && !bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR + axI.xI + ayR.yR - ayI.yI ) + *z0 += (*alphax0) * (*x0) + + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) - + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI - axI.xR + ayR.yI + ayI.yR ) + *(z0 + 1) += (*(alphax0 + 1)) * (*x0) - + (*alphax0) * (*(x0 + 1)) + + (*alphay0) * (*(y0 + 1)) + + (*(alphay0 + 1)) * (*y0); + + x0 += 2; + y0 += 2; + z0 += 2; + } + } + else + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR + axI.xI + ayR.yR + ayI.yI ) + *z0 += (*alphax0) * (*x0) + + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) + + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI - axI.xR + ayR.yI - ayI.yR ) + *(z0 + 1) += (*(alphax0 + 1)) * (*x0) - + (*alphax0) * (*(x0 + 1)) + + (*(alphay0 + 1)) * (*y0) - + (*alphay0) * (*(y0 + 1)); + + x0 += 2; + y0 += 2; + z0 += 2; + } + } + } + else + { + // Using scalar code for non-unit increments + if ( !bli_is_conj( conjx ) && !bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR - axI.xI + ayR.yR - ayI.yI ) + *z0 += (*alphax0) * (*x0) - + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) - + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI + axI.xR + ayR.yI + ayI.yR ) + *(z0 + 1) += (*alphax0) * (*(x0 + 1)) + + (*(alphax0 + 1)) 
* (*x0) + + (*alphay0) * (*(y0 + 1)) + + (*(alphay0 + 1)) * (*y0); + + x0 += 2 * incx; + y0 += 2 * incy; + z0 += 2 * incz; + } + } + else if ( !bli_is_conj( conjx ) && bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR - axI.xI + ayR.yR + ayI.yI ) + *z0 += (*alphax0) * (*x0) - + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) + + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI + axI.xR + ayR.yI - ayI.yR ) + *(z0 + 1) += (*alphax0) * (*(x0 + 1)) + + (*(alphax0 + 1)) * (*x0) + + (*(alphay0 + 1)) * (*y0) - + (*alphay0) * (*(y0 + 1)); + + x0 += 2 * incx; + y0 += 2 * incy; + z0 += 2 * incz; + } + } + else if ( bli_is_conj( conjx ) && !bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR + axI.xI + ayR.yR - ayI.yI ) + *z0 += (*alphax0) * (*x0) + + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) - + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI - axI.xR + ayR.yI + ayI.yR ) + *(z0 + 1) += (*(alphax0 + 1)) * (*x0) - + (*alphax0) * (*(x0 + 1)) + + (*alphay0) * (*(y0 + 1)) + + (*(alphay0 + 1)) * (*y0); + + x0 += 2 * incx; + y0 += 2 * incy; + z0 += 2 * incz; + } + } + else + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR + axI.xI + ayR.yR + ayI.yI ) + *z0 += (*alphax0) * (*x0) + + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) + + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI - axI.xR + ayR.yI - ayI.yR ) + *(z0 + 1) += (*(alphax0 + 1)) * (*x0) - + (*alphax0) * (*(x0 + 1)) + + (*(alphay0 + 1)) * (*y0) - + (*alphay0) * (*(y0 + 1)); + + x0 += 2 * incx; + y0 += 2 * incy; + z0 += 2 * incz; + } + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} \ No newline at end of file diff --git a/kernels/zen/1f/bli_axpyf_zen_int_4.c b/kernels/zen/1f/bli_axpyf_zen_int_4.c index f5a043db84..bb24e6c52f 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_4.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_4.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -95,29 +95,6 @@ void bli_caxpyf_zen_int_4 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - scomplex* a1 = a + (0 )*inca + (i )*lda; - scomplex* chi1 = x + (i )*incx; - scomplex* y1 = y + (0 )*incy; - scomplex alpha_chi1; - - bli_ccopycjs( conjx, *chi1, alpha_chi1 ); - bli_cscals( *alpha, alpha_chi1 ); - - bli_caxpyv_zen_int5 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else caxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -141,7 +118,6 @@ void bli_caxpyf_zen_int_4 ); } -#endif return; } @@ -357,28 +333,6 @@ void bli_zaxpyf_zen_int_4 // operation as a loop over axpyv. 
if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - dcomplex* a1 = a + (0 )*inca + (i )*lda; - dcomplex* chi1 = x + (i )*incx; - dcomplex* y1 = y + (0 )*incy; - dcomplex alpha_chi1; - - bli_zcopycjs( conjx, *chi1, alpha_chi1 ); - bli_zscals( *alpha, alpha_chi1 ); - - bli_zaxpyv_zen_int5 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } -#else zaxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -402,7 +356,6 @@ void bli_zaxpyf_zen_int_4 ); } -#endif return; } diff --git a/kernels/zen/1f/bli_axpyf_zen_int_5.c b/kernels/zen/1f/bli_axpyf_zen_int_5.c index f770389196..8fea5f6498 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_5.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_5.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -108,29 +108,6 @@ void bli_saxpyf_zen_int_5 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - float* a1 = a + (0 )*inca + (i )*lda; - float* chi1 = x + (i )*incx; - float* y1 = y + (0 )*incy; - float alpha_chi1; - - bli_scopycjs( conjx, *chi1, alpha_chi1 ); - bli_sscals( *alpha, alpha_chi1 ); - - bli_saxpyv_zen_int10 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else saxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -154,7 +131,6 @@ void bli_saxpyf_zen_int_5 ); } -#endif return; } @@ -349,31 +325,16 @@ void bli_daxpyf_zen_int_5 const dim_t fuse_fac = 5; const dim_t n_elem_per_reg = 4; - const dim_t n_iter_unroll = 2; dim_t i; - double* restrict a0; - double* restrict a1; - double* restrict a2; - double* restrict a3; - double* restrict a4; + double* restrict av[5] __attribute__((aligned(64))); double* restrict y0; - v4df_t chi0v, chi1v, chi2v, chi3v; - v4df_t chi4v; - - v4df_t a00v, a01v, a02v, a03v; - v4df_t a04v; - - v4df_t a10v, a11v, a12v, a13v; - v4df_t a14v; - - v4df_t y0v, y1v; - - double chi0, chi1, chi2, chi3; - double chi4; + v4df_t chiv[5], a_vec[20], yv[4]; + + double chi[5]; // If either dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) return; @@ -382,29 +343,6 @@ void bli_daxpyf_zen_int_5 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - double* a1 = a + (0 )*inca + (i )*lda; - double* chi1 = x + (i )*incx; - double* y1 = y + (0 )*incy; - double alpha_chi1; - - bli_dcopycjs( conjx, *chi1, alpha_chi1 ); - bli_dscals( *alpha, alpha_chi1 ); - - bli_daxpyv_zen_int10 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -428,122 +366,245 @@ void bli_daxpyf_zen_int_5 ); } -#endif return; } // At this point, we know that b_n is exactly equal to the fusing factor. 
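For orientation only, and not part of this patch: the fused operation that bli_daxpyf_zen_int_5 implements is y := y + alpha * A * x for an m x b_n panel of A (conjugation is a no-op for real data), which is exactly what the axpyv fallback above computes one column at a time. A minimal scalar sketch follows; the name ref_daxpyf and the size_t indexing are illustrative assumptions, with inca the row stride and lda the column stride as in the kernel.

// Illustrative scalar reference for the real-double axpyf operation shown in
// this hunk: y := y + alpha * A * x, accumulating one column of A at a time.
// This sketch is for readability only; it is not code from the patch.
#include <stddef.h>

static void ref_daxpyf( size_t m, size_t b_n, double alpha,
                        const double* a, size_t inca, size_t lda,
                        const double* x, size_t incx,
                        double* y, size_t incy )
{
    for ( size_t j = 0; j < b_n; ++j )            // one axpyv per column of A
    {
        const double chi = alpha * x[ j * incx ]; // alpha * chi_j, as in the kernel
        for ( size_t i = 0; i < m; ++i )
            y[ i * incy ] += chi * a[ i * inca + j * lda ];
    }
}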
- - a0 = a + 0*lda; - a1 = a + 1*lda; - a2 = a + 2*lda; - a3 = a + 3*lda; - a4 = a + 4*lda; + // av points to the 5 columns under consideration + av[0] = a + 0*lda; + av[1] = a + 1*lda; + av[2] = a + 2*lda; + av[3] = a + 3*lda; + av[4] = a + 4*lda; y0 = y; - chi0 = *( x + 0*incx ); - chi1 = *( x + 1*incx ); - chi2 = *( x + 2*incx ); - chi3 = *( x + 3*incx ); - chi4 = *( x + 4*incx ); + chi[0] = *( x + 0*incx ); + chi[1] = *( x + 1*incx ); + chi[2] = *( x + 2*incx ); + chi[3] = *( x + 3*incx ); + chi[4] = *( x + 4*incx ); // Scale each chi scalar by alpha. - bli_dscals( *alpha, chi0 ); - bli_dscals( *alpha, chi1 ); - bli_dscals( *alpha, chi2 ); - bli_dscals( *alpha, chi3 ); - bli_dscals( *alpha, chi4 ); + bli_dscals( *alpha, chi[0] ); + bli_dscals( *alpha, chi[1] ); + bli_dscals( *alpha, chi[2] ); + bli_dscals( *alpha, chi[3] ); + bli_dscals( *alpha, chi[4] ); // Broadcast the (alpha*chi?) scalars to all elements of vector registers. - chi0v.v = _mm256_broadcast_sd( &chi0 ); - chi1v.v = _mm256_broadcast_sd( &chi1 ); - chi2v.v = _mm256_broadcast_sd( &chi2 ); - chi3v.v = _mm256_broadcast_sd( &chi3 ); - chi4v.v = _mm256_broadcast_sd( &chi4 ); + chiv[0].v = _mm256_broadcast_sd( &chi[0] ); + chiv[1].v = _mm256_broadcast_sd( &chi[1] ); + chiv[2].v = _mm256_broadcast_sd( &chi[2] ); + chiv[3].v = _mm256_broadcast_sd( &chi[3] ); + chiv[4].v = _mm256_broadcast_sd( &chi[4] ); // If there are vectorized iterations, perform them with vector // instructions. if ( inca == 1 && incy == 1 ) { - for ( i = 0; (i + 7) < m; i += 8 ) + // 16 elements of the result are computed per iteration + for ( i = 0; (i + 15) < m; i += 16 ) { // Load the input values. - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + + a_vec[5].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[8].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + + a_vec[10].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); + + a_vec[15].v = _mm256_loadu_pd( av[0] + 3*n_elem_per_reg ); + a_vec[16].v = _mm256_loadu_pd( av[1] + 3*n_elem_per_reg ); + a_vec[17].v = _mm256_loadu_pd( av[2] + 3*n_elem_per_reg ); + a_vec[18].v = _mm256_loadu_pd( av[3] + 3*n_elem_per_reg ); + a_vec[19].v = _mm256_loadu_pd( av[4] + 3*n_elem_per_reg ); - a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); - a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); - - a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); - a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); - - a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); - a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg ); + // perform : y += alpha * x; + 
yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[5].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[6].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[7].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[4].v, yv[1].v ); + + yv[2].v = _mm256_fmadd_pd( a_vec[10].v, chiv[0].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[11].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[12].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[13].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[14].v, chiv[4].v, yv[2].v ); + + yv[3].v = _mm256_fmadd_pd( a_vec[15].v, chiv[0].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[16].v, chiv[1].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[17].v, chiv[2].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[18].v, chiv[3].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[19].v, chiv[4].v, yv[3].v ); - a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); - a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg ); + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), yv[3].v ); + + y0 += n_elem_per_reg * 4; + av[0] += n_elem_per_reg * 4; + av[1] += n_elem_per_reg * 4; + av[2] += n_elem_per_reg * 4; + av[3] += n_elem_per_reg * 4; + av[4] += n_elem_per_reg * 4; + } - a04v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); - a14v.v = _mm256_loadu_pd( a4 + 1*n_elem_per_reg ); + // 12 elements of the result are computed per iteration + for ( ; (i + 11) < m; i += 12 ) + { + // Load the input values. 
+ yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + + a_vec[5].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[8].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + + a_vec[10].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[5].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[6].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[7].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[4].v, yv[1].v ); + + yv[2].v = _mm256_fmadd_pd( a_vec[10].v, chiv[0].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[11].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[12].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[13].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[14].v, chiv[4].v, yv[2].v ); - y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + + y0 += n_elem_per_reg * 3; + av[0] += n_elem_per_reg * 3; + av[1] += n_elem_per_reg * 3; + av[2] += n_elem_per_reg * 3; + av[3] += n_elem_per_reg * 3; + av[4] += n_elem_per_reg * 3; + } - y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v ); + // 8 elements of the result are computed per iteration + for (; (i + 7) < m; i += 8 ) + { + // Load the input values. 
+ yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v ); + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); - y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a14v.v, chi4v.v, y1v.v ); + a_vec[5].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[8].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[5].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[6].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[7].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[4].v, yv[1].v ); // Store the output. - _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); - _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); - - y0 += n_iter_unroll * n_elem_per_reg; - a0 += n_iter_unroll * n_elem_per_reg; - a1 += n_iter_unroll * n_elem_per_reg; - a2 += n_iter_unroll * n_elem_per_reg; - a3 += n_iter_unroll * n_elem_per_reg; - a4 += n_iter_unroll * n_elem_per_reg; + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + + y0 += n_elem_per_reg * 2; + av[0] += n_elem_per_reg * 2; + av[1] += n_elem_per_reg * 2; + av[2] += n_elem_per_reg * 2; + av[3] += n_elem_per_reg * 2; + av[4] += n_elem_per_reg * 2; } + // 4 elements of the result are computed per iteration for( ; (i + 3) < m; i += 4 ) { // Load the input values. 
- y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - - a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); - a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); - a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); - a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); - a04v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); // Store the output. - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); y0 += n_elem_per_reg; - a0 += n_elem_per_reg; - a1 += n_elem_per_reg; - a2 += n_elem_per_reg; - a3 += n_elem_per_reg; - a4 += n_elem_per_reg; + av[0] += n_elem_per_reg; + av[1] += n_elem_per_reg; + av[2] += n_elem_per_reg; + av[3] += n_elem_per_reg; + av[4] += n_elem_per_reg; } // If there are leftover iterations, perform them with scalar code. @@ -551,25 +612,25 @@ void bli_daxpyf_zen_int_5 { double y0c = *y0; - const double a0c = *a0; - const double a1c = *a1; - const double a2c = *a2; - const double a3c = *a3; - const double a4c = *a4; + const double a0c = *av[0]; + const double a1c = *av[1]; + const double a2c = *av[2]; + const double a3c = *av[3]; + const double a4c = *av[4]; - y0c += chi0 * a0c; - y0c += chi1 * a1c; - y0c += chi2 * a2c; - y0c += chi3 * a3c; - y0c += chi4 * a4c; + y0c += chi[0] * a0c; + y0c += chi[1] * a1c; + y0c += chi[2] * a2c; + y0c += chi[3] * a3c; + y0c += chi[4] * a4c; *y0 = y0c; - a0 += 1; - a1 += 1; - a2 += 1; - a3 += 1; - a4 += 1; + av[0] += 1; + av[1] += 1; + av[2] += 1; + av[3] += 1; + av[4] += 1; y0 += 1; } } @@ -579,25 +640,25 @@ void bli_daxpyf_zen_int_5 { double y0c = *y0; - const double a0c = *a0; - const double a1c = *a1; - const double a2c = *a2; - const double a3c = *a3; - const double a4c = *a4; + const double a0c = *av[0]; + const double a1c = *av[1]; + const double a2c = *av[2]; + const double a3c = *av[3]; + const double a4c = *av[4]; - y0c += chi0 * a0c; - y0c += chi1 * a1c; - y0c += chi2 * a2c; - y0c += chi3 * a3c; - y0c += chi4 * a4c; + y0c += chi[0] * a0c; + y0c += chi[1] * a1c; + y0c += chi[2] * a2c; + y0c += chi[3] * a3c; + y0c += chi[4] * a4c; *y0 = y0c; - a0 += inca; - a1 += inca; - a2 += inca; - a3 += inca; - a4 += inca; + av[0] += inca; + av[1] += inca; + av[2] += inca; + av[3] += inca; + av[4] += inca; y0 += incy; } @@ -655,29 +716,6 @@ static void bli_daxpyf_zen_int_16x2 // operation as a loop over axpyv. 
if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - double* a1 = a + (0 )*inca + (i )*lda; - double* chi1 = x + (i )*incx; - double* y1 = y + (0 )*incy; - double alpha_chi1; - - bli_dcopycjs( conjx, *chi1, alpha_chi1 ); - bli_dscals( *alpha, alpha_chi1 ); - - bli_daxpyv_zen_int10 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -701,7 +739,6 @@ static void bli_daxpyf_zen_int_16x2 ); } -#endif return; } @@ -966,43 +1003,21 @@ void bli_daxpyf_zen_int_16x4 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - if(b_n & 2) - { - bli_daxpyf_zen_int_16x2( conja, - conjx, - m, 2, - alpha, a, inca, lda, - x, incx, - y, incy, - cntx - ); - b_n -= 2; - a += 2*lda; - x += 2 * incx; - } - for ( i = 0; i < b_n; ++i ) - { - double* a1 = a + (0 )*inca + (i )*lda; - double* chi1 = x + (i )*incx; - double* y1 = y + (0 )*incy; - double alpha_chi1; - - bli_dcopycjs( conjx, *chi1, alpha_chi1 ); - bli_dscals( *alpha, alpha_chi1 ); + if (b_n & 2) + { + bli_daxpyf_zen_int_16x2( conja, + conjx, + m, 2, + alpha, a, inca, lda, + x, incx, + y, incy, + cntx + ); + b_n -= 2; + a += 2*lda; + x += 2 * incx; + } - bli_daxpyv_zen_int10 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -1026,7 +1041,6 @@ void bli_daxpyf_zen_int_16x4 ); } -#endif return; } @@ -1248,7 +1262,7 @@ void bli_daxpyf_zen_int_16x4 a2 += n_elem_per_reg; a3 += n_elem_per_reg; } -#if 1 + for ( ; (i + 1) < m; i += 2) { @@ -1281,7 +1295,7 @@ void bli_daxpyf_zen_int_16x4 a2 += 2; a3 += 2; } -#endif + // If there are leftover iterations, perform them with scalar code. for ( ; (i + 0) < m ; ++i ) { @@ -1396,29 +1410,6 @@ void bli_caxpyf_zen_int_5 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - scomplex* a1 = a + (0 )*inca + (i )*lda; - scomplex* chi1 = x + (i )*incx; - scomplex* y1 = y + (0 )*incy; - scomplex alpha_chi1; - - bli_ccopycjs( conjx, *chi1, alpha_chi1 ); - bli_cscals( *alpha, alpha_chi1 ); - - bli_caxpyv_zen_int5 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else caxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -1442,7 +1433,6 @@ void bli_caxpyf_zen_int_5 ); } -#endif return; } @@ -1747,8 +1737,17 @@ void bli_caxpyf_zen_int_5 } -// ----------------------------------------------------------------------------- - +//------------------------------------------------------------------------------ +/** + * The following kernel performs the axpyf operation on dcomplex data. + * It operates on 5 columns of the matrix at a time and marches through + * the rows in steps of 4 or 2. + * For optimal performance, it separates out the imaginary and real + * components of the chis and broadcasts them into separate ymm vector + * registers. + * By doing so it avoids the need for a permute operation to get the + * final result of the dcomplex multiplication. 
+ */ void bli_zaxpyf_zen_int_5 ( conj_t conja, @@ -1762,391 +1761,523 @@ void bli_zaxpyf_zen_int_5 cntx_t* restrict cntx ) { - const dim_t fuse_fac = 5; - - const dim_t n_elem_per_reg = 2; - const dim_t n_iter_unroll = 2; - - dim_t i = 0; - dim_t setPlusOne = 1; - - v4df_t chi0v, chi1v, chi2v, chi3v, chi4v; - v4df_t chi5v, chi6v, chi7v, chi8v, chi9v; - - v4df_t a00v, a01v, a02v, a03v, a04v; - v4df_t a05v, a06v, a07v, a08v, a09v; - - v4df_t a10v, a11v, a12v, a13v, a14v; - v4df_t a15v, a16v, a17v, a18v, a19v; - - v4df_t y0v, y1v; - v4df_t setMinus, setPlus; - - dcomplex chi0, chi1, chi2, chi3, chi4; - dcomplex* restrict a0; - dcomplex* restrict a1; - dcomplex* restrict a2; - dcomplex* restrict a3; - dcomplex* restrict a4; - - dcomplex* restrict y0; - - - if ( bli_is_conj(conja) ){ - setPlusOne = -1; - } - - // If either dimension is zero, or if alpha is zero, return early. - if ( bli_zero_dim2( m, b_n ) || bli_zeq0( *alpha ) ) return; - - // If b_n is not equal to the fusing factor, then perform the entire - // operation as a loop over axpyv. - if ( b_n != fuse_fac ) - { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - dcomplex* a1 = a + (0 )*inca + (i )*lda; - dcomplex* chi1 = x + (i )*incx; - dcomplex* y1 = y + (0 )*incy; - dcomplex alpha_chi1; - - bli_zcopycjs( conjx, *chi1, alpha_chi1 ); - bli_zscals( *alpha, alpha_chi1 ); - - bli_zaxpyv_zen_int5 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else - zaxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx ); - - for ( i = 0; i < b_n; ++i ) - { - dcomplex* a1 = a + (0 )*inca + (i )*lda; - dcomplex* chi1 = x + (i )*incx; - dcomplex* y1 = y + (0 )*incy; - dcomplex alpha_chi1; - - bli_zcopycjs( conjx, *chi1, alpha_chi1 ); - bli_zscals( *alpha, alpha_chi1 ); - - f - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#endif - return; - } - + const dim_t fuse_fac = 5; + + const dim_t n_elem_per_reg = 2; + const dim_t n_iter_unroll = 2; + + dim_t i = 0; + dim_t setPlusOne = 1; + + v4df_t chi0v, chi1v, chi2v, chi3v, chi4v; + v4df_t chi5v, chi6v, chi7v, chi8v, chi9v; + + v4df_t a00v, a01v, a02v, a03v, a04v; + + v4df_t a10v, a11v, a12v, a13v, a14v; + + v4df_t y0v, y1v, y2v, y3v; + v4df_t r0v, r1v, conjv; + + dcomplex chi0, chi1, chi2, chi3, chi4; + dcomplex* restrict a0; + dcomplex* restrict a1; + dcomplex* restrict a2; + dcomplex* restrict a3; + dcomplex* restrict a4; + + dcomplex* restrict y0; + + + if ( bli_is_conj(conja) ){ + setPlusOne = -1; + } + + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_zeq0( *alpha ) ) return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. + if ( b_n != fuse_fac ) + { + zaxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx ); + + for ( i = 0; i < b_n; ++i ) + { + dcomplex* a1 = a + (0 )*inca + (i )*lda; + dcomplex* chi1 = x + (i )*incx; + dcomplex* y1 = y + (0 )*incy; + dcomplex alpha_chi1; + + bli_zcopycjs( conjx, *chi1, alpha_chi1 ); + bli_zscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + + return; + } + + + // At this point, we know that b_n is exactly equal to the fusing factor. 
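As a reader aid, and not part of this patch: the docstring above describes accumulating each column's contribution by broadcasting the real and imaginary parts of alpha*chi separately and combining the two products with a permute and an addsub. A minimal sketch of that complex multiply-accumulate on one ymm register (two dcomplex elements) is shown below; the helper name ref_zmac_2 is an illustrative assumption.

// Illustrative sketch: r += a * chi for the two dcomplex elements held in one
// ymm register, using broadcast real/imag parts of chi plus permute + addsub.
// This mirrors the scheme described above but is not code from the patch.
#include <immintrin.h>

static inline __m256d ref_zmac_2( __m256d r, __m256d a, double chi_r, double chi_i )
{
    __m256d cr = _mm256_broadcast_sd( &chi_r );   // chi_R chi_R chi_R chi_R
    __m256d ci = _mm256_broadcast_sd( &chi_i );   // chi_I chi_I chi_I chi_I

    __m256d t0 = _mm256_mul_pd( a, cr );          // aR*cR   aI*cR   ...
    __m256d t1 = _mm256_mul_pd( a, ci );          // aR*cI   aI*cI   ...

    t1 = _mm256_permute_pd( t1, 0x5 );            // aI*cI   aR*cI   ...
    t0 = _mm256_addsub_pd( t0, t1 );              // aR*cR - aI*cI   aI*cR + aR*cI   ...

    return _mm256_add_pd( r, t0 );                // accumulate into the running sum
}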
+ + a0 = a + 0*lda; + a1 = a + 1*lda; + a2 = a + 2*lda; + a3 = a + 3*lda; + a4 = a + 4*lda; + y0 = y; + + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + chi4 = *( x + 4*incx ); + + dcomplex *pchi0 = x + 0*incx ; + dcomplex *pchi1 = x + 1*incx ; + dcomplex *pchi2 = x + 2*incx ; + dcomplex *pchi3 = x + 3*incx ; + dcomplex *pchi4 = x + 4*incx ; + + bli_zcopycjs( conjx, *pchi0, chi0 ); + bli_zcopycjs( conjx, *pchi1, chi1 ); + bli_zcopycjs( conjx, *pchi2, chi2 ); + bli_zcopycjs( conjx, *pchi3, chi3 ); + bli_zcopycjs( conjx, *pchi4, chi4 ); + + // Scale each chi scalar by alpha. + bli_zscals( *alpha, chi0 ); + bli_zscals( *alpha, chi1 ); + bli_zscals( *alpha, chi2 ); + bli_zscals( *alpha, chi3 ); + bli_zscals( *alpha, chi4 ); + + // Broadcast the (alpha*chi?) scalars to all elements of vector registers. + chi0v.v = _mm256_broadcast_sd( &chi0.real ); + chi1v.v = _mm256_broadcast_sd( &chi1.real ); + chi2v.v = _mm256_broadcast_sd( &chi2.real ); + chi3v.v = _mm256_broadcast_sd( &chi3.real ); + chi4v.v = _mm256_broadcast_sd( &chi4.real ); + + chi5v.v = _mm256_broadcast_sd( &chi0.imag ); + chi6v.v = _mm256_broadcast_sd( &chi1.imag ); + chi7v.v = _mm256_broadcast_sd( &chi2.imag ); + chi8v.v = _mm256_broadcast_sd( &chi3.imag ); + chi9v.v = _mm256_broadcast_sd( &chi4.imag ); + + // If there are vectorized iterations, perform them with vector + // instructions. + if ( inca == 1 && incy == 1 ) + { + // March through vectors in multiple of 4. + for( i = 0; (i + 3) < m; i += 4 ) + { + // Load the input values. + r0v.v = _mm256_loadu_pd( (double*) (y0 + 0*n_elem_per_reg )); + r1v.v = _mm256_loadu_pd( (double*) (y0 + 1*n_elem_per_reg )); + + y0v.v = _mm256_setzero_pd(); + y1v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + y3v.v = _mm256_setzero_pd(); + + if ( bli_is_conj(conja) ){ + /** + * For conjugate cases imaginary part + * is negated. 
+ */ + conjv.v = _mm256_set_pd( -1, 1, -1, 1 ); + a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); + a10v.v = _mm256_loadu_pd( (double*) (a0 + 1*n_elem_per_reg )); + + a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); + a11v.v = _mm256_loadu_pd( (double*) (a1 + 1*n_elem_per_reg )); + + a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); + a12v.v = _mm256_loadu_pd( (double*) (a2 + 1*n_elem_per_reg )); + + a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); + a13v.v = _mm256_loadu_pd( (double*) (a3 + 1*n_elem_per_reg )); + + a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); + a14v.v = _mm256_loadu_pd( (double*) (a4 + 1*n_elem_per_reg )); + + a00v.v = _mm256_mul_pd(a00v.v, conjv.v); + a10v.v = _mm256_mul_pd(a10v.v, conjv.v); + a01v.v = _mm256_mul_pd(a01v.v, conjv.v); + a11v.v = _mm256_mul_pd(a11v.v, conjv.v); + a02v.v = _mm256_mul_pd(a02v.v, conjv.v); + a12v.v = _mm256_mul_pd(a12v.v, conjv.v); + a03v.v = _mm256_mul_pd(a03v.v, conjv.v); + a13v.v = _mm256_mul_pd(a13v.v, conjv.v); + a04v.v = _mm256_mul_pd(a04v.v, conjv.v); + a14v.v = _mm256_mul_pd(a14v.v, conjv.v); + } + else + { + a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); + a10v.v = _mm256_loadu_pd( (double*) (a0 + 1*n_elem_per_reg )); + + a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); + a11v.v = _mm256_loadu_pd( (double*) (a1 + 1*n_elem_per_reg )); + + a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); + a12v.v = _mm256_loadu_pd( (double*) (a2 + 1*n_elem_per_reg )); + + a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); + a13v.v = _mm256_loadu_pd( (double*) (a3 + 1*n_elem_per_reg )); + + a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); + a14v.v = _mm256_loadu_pd( (double*) (a4 + 1*n_elem_per_reg )); + + } + + // perform : y += alpha * x; + /** + * chi[x]v.v holds real part of chi. + * chi[x]v.v holds imag part of chi. + * ys holds following computation: + * + * a[xx]v.v R1 I1 R2 I2 + * chi[x]v.v chi_R chi_R chi_R chi_R + * chi[x]v.v chi_I chi_I chi_I chi_I + * y[x]v.v R1*chi_R I1*chi_R R2*chi_R I2*chiR (compute with chi-real part) + * y[x]v.v R1*chi_I I1*chi_I R2*chi_I I2*chiI (compute with chi-imag part) + * + */ + y0v.v = _mm256_mul_pd( a00v.v, chi0v.v); + y1v.v = _mm256_mul_pd( a10v.v, chi0v.v); + + y2v.v = _mm256_mul_pd( a00v.v, chi5v.v); + y3v.v = _mm256_mul_pd( a10v.v, chi5v.v); + + /** + * y0v.v & y1v.v holds computation with real part of chi. + * y2v.v & y3v.v holds computaion with imag part of chi. + * Permute will swap the positions of elements in y2v.v & y3v.v + * as we need to perform: [ R*R + I*I & R*I + I*R]. + * Once dcomplex multiplication is done add the result into r0v.v + * r1v.v which holds axpy result of current tile which is being + * computed. + */ + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y3v.v = _mm256_permute_pd(y3v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + y1v.v = _mm256_addsub_pd(y1v.v, y3v.v); + + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + r1v.v = _mm256_add_pd(y1v.v, r1v.v); + + y0v.v = _mm256_setzero_pd(); + y1v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + y3v.v = _mm256_setzero_pd(); + + /** + * Repeat the same computation as above + * for remaining tile. 
+ */ + y0v.v = _mm256_mul_pd( a01v.v, chi1v.v ); + y1v.v = _mm256_mul_pd( a11v.v, chi1v.v ); + + y2v.v = _mm256_mul_pd( a01v.v, chi6v.v ); + y3v.v = _mm256_mul_pd( a11v.v, chi6v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y3v.v = _mm256_permute_pd(y3v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + y1v.v = _mm256_addsub_pd(y1v.v, y3v.v); + + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + r1v.v = _mm256_add_pd(y1v.v, r1v.v); + + y0v.v = _mm256_setzero_pd(); + y1v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + y3v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a02v.v, chi2v.v); + y1v.v = _mm256_mul_pd( a12v.v, chi2v.v); + + y2v.v = _mm256_mul_pd( a02v.v, chi7v.v ); + y3v.v = _mm256_mul_pd( a12v.v, chi7v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y3v.v = _mm256_permute_pd(y3v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + y1v.v = _mm256_addsub_pd(y1v.v, y3v.v); + + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + r1v.v = _mm256_add_pd(y1v.v, r1v.v); + + y0v.v = _mm256_setzero_pd(); + y1v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + y3v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a03v.v, chi3v.v ); + y1v.v = _mm256_mul_pd( a13v.v, chi3v.v ); + + y2v.v = _mm256_mul_pd( a03v.v, chi8v.v ); + y3v.v = _mm256_mul_pd( a13v.v, chi8v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y3v.v = _mm256_permute_pd(y3v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + y1v.v = _mm256_addsub_pd(y1v.v, y3v.v); + + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + r1v.v = _mm256_add_pd(y1v.v, r1v.v); + + y0v.v = _mm256_setzero_pd(); + y1v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + y3v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a04v.v, chi4v.v ); + y1v.v = _mm256_mul_pd( a14v.v, chi4v.v ); + + y2v.v = _mm256_mul_pd( a04v.v, chi9v.v ); + y3v.v = _mm256_mul_pd( a14v.v, chi9v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y3v.v = _mm256_permute_pd(y3v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + y1v.v = _mm256_addsub_pd(y1v.v, y3v.v); + + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + r1v.v = _mm256_add_pd(y1v.v, r1v.v); + + /** + * Final axpy compuation is available in r0v.v + * and r1v.v registers. + * Store it back into y vector. + */ + _mm256_storeu_pd( (double*) (y0 + 0*n_elem_per_reg), r0v.v ); + _mm256_storeu_pd( (double*) (y0 + 1*n_elem_per_reg), r1v.v ); + + /** + * Set the pointers next vectors elements to be + * computed based on unroll factor. + */ + y0 += n_elem_per_reg * n_iter_unroll; + a0 += n_elem_per_reg * n_iter_unroll; + a1 += n_elem_per_reg * n_iter_unroll; + a2 += n_elem_per_reg * n_iter_unroll; + a3 += n_elem_per_reg * n_iter_unroll; + a4 += n_elem_per_reg * n_iter_unroll; + } + // March through vectors in multiple of 2. + for( ; (i + 1) < m; i += 2 ) + { + r0v.v = _mm256_loadu_pd( (double*) (y0 + 0*n_elem_per_reg )); + + y0v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); - // At this point, we know that b_n is exactly equal to the fusing factor. 
- - a0 = a + 0*lda; - a1 = a + 1*lda; - a2 = a + 2*lda; - a3 = a + 3*lda; - a4 = a + 4*lda; - y0 = y; - - chi0 = *( x + 0*incx ); - chi1 = *( x + 1*incx ); - chi2 = *( x + 2*incx ); - chi3 = *( x + 3*incx ); - chi4 = *( x + 4*incx ); - - dcomplex *pchi0 = x + 0*incx ; - dcomplex *pchi1 = x + 1*incx ; - dcomplex *pchi2 = x + 2*incx ; - dcomplex *pchi3 = x + 3*incx ; - dcomplex *pchi4 = x + 4*incx ; - - bli_zcopycjs( conjx, *pchi0, chi0 ); - bli_zcopycjs( conjx, *pchi1, chi1 ); - bli_zcopycjs( conjx, *pchi2, chi2 ); - bli_zcopycjs( conjx, *pchi3, chi3 ); - bli_zcopycjs( conjx, *pchi4, chi4 ); - - // Scale each chi scalar by alpha. - bli_zscals( *alpha, chi0 ); - bli_zscals( *alpha, chi1 ); - bli_zscals( *alpha, chi2 ); - bli_zscals( *alpha, chi3 ); - bli_zscals( *alpha, chi4 ); - - // Broadcast the (alpha*chi?) scalars to all elements of vector registers. - chi0v.v = _mm256_broadcast_sd( &chi0.real ); - chi1v.v = _mm256_broadcast_sd( &chi1.real ); - chi2v.v = _mm256_broadcast_sd( &chi2.real ); - chi3v.v = _mm256_broadcast_sd( &chi3.real ); - chi4v.v = _mm256_broadcast_sd( &chi4.real ); - - chi5v.v = _mm256_broadcast_sd( &chi0.imag ); - chi6v.v = _mm256_broadcast_sd( &chi1.imag ); - chi7v.v = _mm256_broadcast_sd( &chi2.imag ); - chi8v.v = _mm256_broadcast_sd( &chi3.imag ); - chi9v.v = _mm256_broadcast_sd( &chi4.imag ); - - // If there are vectorized iterations, perform them with vector - // instructions. - if ( inca == 1 && incy == 1 ) - { - setMinus.v = _mm256_set_pd( -1, 1, -1, 1 ); + if ( bli_is_conj(conja) ){ + conjv.v = _mm256_set_pd( -1, 1, -1, 1 ); + a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); - setPlus.v = _mm256_set1_pd( 1 ); - if ( bli_is_conj(conja) ){ - setPlus.v = _mm256_set_pd( -1, 1, -1, 1 ); - } + a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); - /* - y := y + alpha * conja(A) * conjx(x) + a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); - nn - (ar + ai) (xr + xi) - ar * xr - ai * xi - ar * xi + ai * xr + a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); - cc : (ar - ai) (xr - xi) - ar * xr - ai * xi - -(ar * xi + ai * xr) + a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); - nc : (ar + ai) (xr - xi) - ar * xr + ai * xi - -(ar * xi - ai * xr) + a00v.v = _mm256_mul_pd(a00v.v, conjv.v); + a01v.v = _mm256_mul_pd(a01v.v, conjv.v); + a02v.v = _mm256_mul_pd(a02v.v, conjv.v); + a03v.v = _mm256_mul_pd(a03v.v, conjv.v); + a04v.v = _mm256_mul_pd(a04v.v, conjv.v); + } + else + { + a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); - cn : (ar - ai) (xr + xi) - ar * xr + ai * xi - ar * xi - ai * xr + a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); - */ + a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); + + a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); + + a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); - for( i = 0; (i + 3) < m; i += 4 ) - { - // Load the input values. 
- y0v.v = _mm256_loadu_pd( (double*) (y0 + 0*n_elem_per_reg )); - y1v.v = _mm256_loadu_pd( (double*) (y0 + 1*n_elem_per_reg )); - - a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); - a10v.v = _mm256_loadu_pd( (double*) (a0 + 1*n_elem_per_reg )); - - a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); - a11v.v = _mm256_loadu_pd( (double*) (a1 + 1*n_elem_per_reg )); - - a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); - a12v.v = _mm256_loadu_pd( (double*) (a2 + 1*n_elem_per_reg )); - - a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); - a13v.v = _mm256_loadu_pd( (double*) (a3 + 1*n_elem_per_reg )); - - a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); - a14v.v = _mm256_loadu_pd( (double*) (a4 + 1*n_elem_per_reg )); - - a00v.v = _mm256_mul_pd( a00v.v, setPlus.v ); - a01v.v = _mm256_mul_pd( a01v.v, setPlus.v ); - a02v.v = _mm256_mul_pd( a02v.v, setPlus.v ); - a03v.v = _mm256_mul_pd( a03v.v, setPlus.v ); - a04v.v = _mm256_mul_pd( a04v.v, setPlus.v ); - - a05v.v = _mm256_mul_pd( a00v.v, setMinus.v ); - a06v.v = _mm256_mul_pd( a01v.v, setMinus.v ); - a07v.v = _mm256_mul_pd( a02v.v, setMinus.v ); - a08v.v = _mm256_mul_pd( a03v.v, setMinus.v ); - a09v.v = _mm256_mul_pd( a04v.v, setMinus.v ); - - a05v.v = _mm256_permute_pd( a05v.v, 5 ); - a06v.v = _mm256_permute_pd( a06v.v, 5 ); - a07v.v = _mm256_permute_pd( a07v.v, 5 ); - a08v.v = _mm256_permute_pd( a08v.v, 5 ); - a09v.v = _mm256_permute_pd( a09v.v, 5 ); - - a10v.v = _mm256_mul_pd( a10v.v, setPlus.v ); - a11v.v = _mm256_mul_pd( a11v.v, setPlus.v ); - a12v.v = _mm256_mul_pd( a12v.v, setPlus.v ); - a13v.v = _mm256_mul_pd( a13v.v, setPlus.v ); - a14v.v = _mm256_mul_pd( a14v.v, setPlus.v ); - - a15v.v = _mm256_mul_pd( a10v.v, setMinus.v ); - a16v.v = _mm256_mul_pd( a11v.v, setMinus.v ); - a17v.v = _mm256_mul_pd( a12v.v, setMinus.v ); - a18v.v = _mm256_mul_pd( a13v.v, setMinus.v ); - a19v.v = _mm256_mul_pd( a14v.v, setMinus.v ); - - a15v.v = _mm256_permute_pd( a15v.v, 5 ); - a16v.v = _mm256_permute_pd( a16v.v, 5 ); - a17v.v = _mm256_permute_pd( a17v.v, 5 ); - a18v.v = _mm256_permute_pd( a18v.v, 5 ); - a19v.v = _mm256_permute_pd( a19v.v, 5 ); - - // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); - - y0v.v = _mm256_fmadd_pd( a05v.v, chi5v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a06v.v, chi6v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a07v.v, chi7v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a08v.v, chi8v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a09v.v, chi9v.v, y0v.v ); - - // For next 4 elements perform : y += alpha * x; - y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a14v.v, chi4v.v, y1v.v ); - - y1v.v = _mm256_fmadd_pd( a15v.v, chi5v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a16v.v, chi6v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a17v.v, chi7v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a18v.v, chi8v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a19v.v, chi9v.v, y1v.v ); - - // Store the output. 
- _mm256_storeu_pd( (double*) (y0 + 0*n_elem_per_reg), y0v.v ); - _mm256_storeu_pd( (double*) (y0 + 1*n_elem_per_reg), y1v.v ); - - y0 += n_elem_per_reg * n_iter_unroll; - a0 += n_elem_per_reg * n_iter_unroll; - a1 += n_elem_per_reg * n_iter_unroll; - a2 += n_elem_per_reg * n_iter_unroll; - a3 += n_elem_per_reg * n_iter_unroll; - a4 += n_elem_per_reg * n_iter_unroll; - } - for( ; (i + 1) < m; i += 2 ) - { - // Load the input values. - y0v.v = _mm256_loadu_pd( (double*) (y0 + 0*n_elem_per_reg )); - - a00v.v = _mm256_loadu_pd( (double*)(a0 + 0*n_elem_per_reg) ); - a01v.v = _mm256_loadu_pd( (double*)(a1 + 0*n_elem_per_reg) ); - a02v.v = _mm256_loadu_pd( (double*)(a2 + 0*n_elem_per_reg) ); - a03v.v = _mm256_loadu_pd( (double*)(a3 + 0*n_elem_per_reg) ); - a04v.v = _mm256_loadu_pd( (double*)(a4 + 0*n_elem_per_reg) ); - - a00v.v = _mm256_mul_pd( a00v.v, setPlus.v ); - a01v.v = _mm256_mul_pd( a01v.v, setPlus.v ); - a02v.v = _mm256_mul_pd( a02v.v, setPlus.v ); - a03v.v = _mm256_mul_pd( a03v.v, setPlus.v ); - a04v.v = _mm256_mul_pd( a04v.v, setPlus.v ); - - a05v.v = _mm256_mul_pd( a00v.v, setMinus.v ); - a06v.v = _mm256_mul_pd( a01v.v, setMinus.v ); - a07v.v = _mm256_mul_pd( a02v.v, setMinus.v ); - a08v.v = _mm256_mul_pd( a03v.v, setMinus.v ); - a09v.v = _mm256_mul_pd( a04v.v, setMinus.v ); - - a05v.v = _mm256_permute_pd( a05v.v, 5 ); - a06v.v = _mm256_permute_pd( a06v.v, 5 ); - a07v.v = _mm256_permute_pd( a07v.v, 5 ); - a08v.v = _mm256_permute_pd( a08v.v, 5 ); - a09v.v = _mm256_permute_pd( a09v.v, 5 ); - - // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); - - y0v.v = _mm256_fmadd_pd( a05v.v, chi5v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a06v.v, chi6v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a07v.v, chi7v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a08v.v, chi8v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a09v.v, chi9v.v, y0v.v ); - - // Store the output. - _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); - - y0 += n_elem_per_reg ; - a0 += n_elem_per_reg ; - a1 += n_elem_per_reg ; - a2 += n_elem_per_reg ; - a3 += n_elem_per_reg ; - a4 += n_elem_per_reg ; - } - // If there are leftover iterations, perform them with scalar code. 
- for ( ; (i + 0) < m ; ++i ) - { - dcomplex y0c = *y0; - - const dcomplex a0c = *a0; - const dcomplex a1c = *a1; - const dcomplex a2c = *a2; - const dcomplex a3c = *a3; - const dcomplex a4c = *a4; - - y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; - y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; - y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; - y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; - y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; - - y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; - y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; - y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; - y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; - y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; - - *y0 = y0c; - - a0 += 1; - a1 += 1; - a2 += 1; - a3 += 1; - a4 += 1; - y0 += 1; - } - } - else - { - for ( ; (i + 0) < m ; ++i ) - { - dcomplex y0c = *y0; - - const dcomplex a0c = *a0; - const dcomplex a1c = *a1; - const dcomplex a2c = *a2; - const dcomplex a3c = *a3; - const dcomplex a4c = *a4; - - y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; - y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; - y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; - y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; - y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; - - y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; - y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; - y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; - y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; - y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; - - *y0 = y0c; - - a0 += inca; - a1 += inca; - a2 += inca; - a3 += inca; - a4 += inca; - y0 += incy; - } - - } + } + + // perform : y += alpha * x; + /** + * chi[x]v.v holds real part of chi. + * chi[x]v.v holds imag part of chi. + * ys holds following computation: + * + * a[xx]v.v R1 I1 R2 I2 + * chi[x]v.v chi_R chi_R chi_R chi_R + * chi[x]v.v chi_I chi_I chi_I chi_I + * y[x]v.v R1*chi_R I1*chi_R R2*chi_R I2*chiR (compute with chi-real part) + * y[x]v.v R1*chi_I I1*chi_I R2*chi_I I2*chiI (compute with chi-imag part) + * + */ + y0v.v = _mm256_mul_pd( a00v.v, chi0v.v ); + y2v.v = _mm256_mul_pd( a00v.v, chi5v.v ); + + /** + * y0v.v holds computation with real part of chi. + * y2v.v holds computaion with imag part of chi. + * Permute will swap the positions of elements in y2v.v. + * as we need to perform: [ R*R + I*I & R*I + I*R]. + * Once dcomplex multiplication is done add the result into r0v.v + * which holds axpy result of current tile which is being + * computed. + */ + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + + y0v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + + /** + * Repeat the same computation as above + * for remaining tile. 
+ */ + y0v.v = _mm256_mul_pd( a01v.v, chi1v.v ); + y2v.v = _mm256_mul_pd( a01v.v, chi6v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + + y0v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a02v.v, chi2v.v ); + y2v.v = _mm256_mul_pd( a02v.v, chi7v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + + y0v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a03v.v, chi3v.v ); + y2v.v = _mm256_mul_pd( a03v.v, chi8v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + + y0v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a04v.v, chi4v.v ); + y2v.v = _mm256_mul_pd( a04v.v, chi9v.v ); + + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + + /** + * Final axpy compuation is available in r0v.v + * Store it back into y vector. + */ + _mm256_storeu_pd( (double*) (y0 + 0*n_elem_per_reg), r0v.v ); + + y0 += n_iter_unroll; + a0 += n_iter_unroll; + a1 += n_iter_unroll; + a2 += n_iter_unroll; + a3 += n_iter_unroll; + a4 += n_iter_unroll; + + } + + // If there are leftover iterations, perform them with scalar code. + for ( ; (i + 0) < m ; ++i ) + { + dcomplex y0c = *y0; + + const dcomplex a0c = *a0; + const dcomplex a1c = *a1; + const dcomplex a2c = *a2; + const dcomplex a3c = *a3; + const dcomplex a4c = *a4; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; + + *y0 = y0c; + + a0 += 1; + a1 += 1; + a2 += 1; + a3 += 1; + a4 += 1; + y0 += 1; + } + } + else + { + for ( ; (i + 0) < m ; ++i ) + { + dcomplex y0c = *y0; + + const dcomplex a0c = *a0; + const dcomplex a1c = *a1; + const dcomplex a2c = *a2; + const dcomplex a3c = *a3; + const dcomplex a4c = *a4; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; + + *y0 = y0c; + + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + a4 += inca; + y0 += incy; + } + + } } diff --git a/kernels/zen/1f/bli_axpyf_zen_int_6.c 
b/kernels/zen/1f/bli_axpyf_zen_int_6.c index 99b544db15..cf7dbd1732 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_6.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_6.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -97,28 +97,6 @@ void bli_saxpyf_zen_int_6 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - float* a1 = a + (0 )*inca + (i )*lda; - float* chi1 = x + (i )*incx; - float* y1 = y + (0 )*incy; - float alpha_chi1; - - bli_scopycjs( conjx, *chi1, alpha_chi1 ); - bli_sscals( *alpha, alpha_chi1 ); - - bli_saxpyv_zen_int10 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } -#else saxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -141,7 +119,7 @@ void bli_saxpyf_zen_int_6 cntx ); } -#endif + return; } diff --git a/kernels/zen/1f/bli_axpyf_zen_int_8.c b/kernels/zen/1f/bli_axpyf_zen_int_8.c index b958600ce6..27dafb28fc 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_8.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin - Copyright (C) 2016 - 2018, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -279,32 +279,19 @@ void bli_daxpyf_zen_int_8 const dim_t fuse_fac = 8; const dim_t n_elem_per_reg = 4; - const dim_t n_iter_unroll = 1; + const dim_t n_iter_unroll[4] = {4, 3, 2, 1}; dim_t i; - dim_t m_viter; - dim_t m_left; + dim_t m_viter[4]; + dim_t m_left = m; - double* restrict a0; - double* restrict a1; - double* restrict a2; - double* restrict a3; - double* restrict a4; - double* restrict a5; - double* restrict a6; - double* restrict a7; + double* restrict av[8] __attribute__((aligned(64))); double* restrict y0; - v4df_t chi0v, chi1v, chi2v, chi3v; - v4df_t chi4v, chi5v, chi6v, chi7v; + v4df_t chiv[8], a_vec[32], yv[4]; - v4df_t a0v, a1v, a2v, a3v; - v4df_t a4v, a5v, a6v, a7v; - v4df_t y0v; - - double chi0, chi1, chi2, chi3; - double chi4, chi5, chi6, chi7; + double chi[8] __attribute__((aligned(64))); // If either dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim2( m, b_n ) || PASTEMAC(d,eq0)( *alpha ) ) return; @@ -343,94 +330,343 @@ void bli_daxpyf_zen_int_8 // Use the unrolling factor and the number of elements per register // to compute the number of vectorized and leftover iterations. 
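Editorial aside, not part of the patch: the hunk below replaces the single m_viter/m_left split with a four-tier split, so that as much of m as possible is handled by the widest (16-element) loop and the scalar tail never exceeds 3 elements. A stand-alone sketch of the bookkeeping, using long instead of dim_t; split_iters is a name local to this example.

#include <stdio.h>

static void split_iters( long m, long viter[4], long *left )
{
    const long n_elem_per_reg = 4;              /* doubles per 256-bit register  */
    const long unroll[4] = { 4, 3, 2, 1 };      /* y registers updated per pass  */
    long rem = m;
    for ( int t = 0; t < 4; ++t )
    {
        viter[t] = rem / ( n_elem_per_reg * unroll[t] );
        rem      = rem % ( n_elem_per_reg * unroll[t] );
    }
    *left = rem;                                /* handled by the scalar cleanup */
}

int main( void )
{
    long v[4], left;
    split_iters( 57, v, &left );                /* 57 = 3*16 + 0*12 + 1*8 + 0*4 + 1 */
    printf( "m_viter = %ld %ld %ld %ld, m_left = %ld\n",
            v[0], v[1], v[2], v[3], left );
    return 0;
}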
- m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll ); - m_left = ( m ) % ( n_elem_per_reg * n_iter_unroll ); + m_viter[0] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[0] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[0] ); + + m_viter[1] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[1] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[1] ); + + m_viter[2] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[2] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[2] ); + + m_viter[3] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[3] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[3] ); // If there is anything that would interfere with our use of contiguous // vector loads/stores, override m_viter and m_left to use scalar code // for all iterations. if ( inca != 1 || incy != 1 ) { - m_viter = 0; + m_viter[0] = m_viter[1] = m_viter[2] = m_viter[3] = 0; m_left = m; } - a0 = a + 0*lda; - a1 = a + 1*lda; - a2 = a + 2*lda; - a3 = a + 3*lda; - a4 = a + 4*lda; - a5 = a + 5*lda; - a6 = a + 6*lda; - a7 = a + 7*lda; + // av points to the 8 columns under consideration + av[0] = a + 0*lda; + av[1] = a + 1*lda; + av[2] = a + 2*lda; + av[3] = a + 3*lda; + av[4] = a + 4*lda; + av[5] = a + 5*lda; + av[6] = a + 6*lda; + av[7] = a + 7*lda; y0 = y; - chi0 = *( x + 0*incx ); - chi1 = *( x + 1*incx ); - chi2 = *( x + 2*incx ); - chi3 = *( x + 3*incx ); - chi4 = *( x + 4*incx ); - chi5 = *( x + 5*incx ); - chi6 = *( x + 6*incx ); - chi7 = *( x + 7*incx ); + chi[0] = *( x + 0*incx ); + chi[1] = *( x + 1*incx ); + chi[2] = *( x + 2*incx ); + chi[3] = *( x + 3*incx ); + chi[4] = *( x + 4*incx ); + chi[5] = *( x + 5*incx ); + chi[6] = *( x + 6*incx ); + chi[7] = *( x + 7*incx ); // Scale each chi scalar by alpha. - PASTEMAC(d,scals)( *alpha, chi0 ); - PASTEMAC(d,scals)( *alpha, chi1 ); - PASTEMAC(d,scals)( *alpha, chi2 ); - PASTEMAC(d,scals)( *alpha, chi3 ); - PASTEMAC(d,scals)( *alpha, chi4 ); - PASTEMAC(d,scals)( *alpha, chi5 ); - PASTEMAC(d,scals)( *alpha, chi6 ); - PASTEMAC(d,scals)( *alpha, chi7 ); + PASTEMAC(d,scals)( *alpha, chi[0] ); + PASTEMAC(d,scals)( *alpha, chi[1] ); + PASTEMAC(d,scals)( *alpha, chi[2] ); + PASTEMAC(d,scals)( *alpha, chi[3] ); + PASTEMAC(d,scals)( *alpha, chi[4] ); + PASTEMAC(d,scals)( *alpha, chi[5] ); + PASTEMAC(d,scals)( *alpha, chi[6] ); + PASTEMAC(d,scals)( *alpha, chi[7] ); // Broadcast the (alpha*chi?) scalars to all elements of vector registers. - chi0v.v = _mm256_broadcast_sd( &chi0 ); - chi1v.v = _mm256_broadcast_sd( &chi1 ); - chi2v.v = _mm256_broadcast_sd( &chi2 ); - chi3v.v = _mm256_broadcast_sd( &chi3 ); - chi4v.v = _mm256_broadcast_sd( &chi4 ); - chi5v.v = _mm256_broadcast_sd( &chi5 ); - chi6v.v = _mm256_broadcast_sd( &chi6 ); - chi7v.v = _mm256_broadcast_sd( &chi7 ); + chiv[0].v = _mm256_broadcast_sd( &chi[0] ); + chiv[1].v = _mm256_broadcast_sd( &chi[1] ); + chiv[2].v = _mm256_broadcast_sd( &chi[2] ); + chiv[3].v = _mm256_broadcast_sd( &chi[3] ); + chiv[4].v = _mm256_broadcast_sd( &chi[4] ); + chiv[5].v = _mm256_broadcast_sd( &chi[5] ); + chiv[6].v = _mm256_broadcast_sd( &chi[6] ); + chiv[7].v = _mm256_broadcast_sd( &chi[7] ); // If there are vectorized iterations, perform them with vector // instructions. - for ( i = 0; i < m_viter; ++i ) + // 16 elements of the result are computed per iteration + for ( i = 0; i < m_viter[0]; ++i ) { // Load the input values. 
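Editorial aside, not part of the patch: with chi[] already scaled by alpha, this loop and the narrower tiers that follow all compute the same fused update, differing only in how many elements of y are produced per iteration. A plain-C, unit-stride reference of that update; daxpyf8_ref is a name local to this example.

/* y[0:m] += chi[0]*A(:,0) + chi[1]*A(:,1) + ... + chi[7]*A(:,7),
 * with column j of A starting at a + j*lda (unit row stride assumed). */
static void daxpyf8_ref( long m, const double *a, long lda,
                         const double chi[8], double *y )
{
    for ( long i = 0; i < m; ++i )
        for ( long j = 0; j < 8; ++j )
            y[ i ] += chi[ j ] * a[ i + j * lda ];
}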
- y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - a0v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); - a1v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); - a2v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); - a3v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); - a4v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); - a5v.v = _mm256_loadu_pd( a5 + 0*n_elem_per_reg ); - a6v.v = _mm256_loadu_pd( a6 + 0*n_elem_per_reg ); - a7v.v = _mm256_loadu_pd( a7 + 0*n_elem_per_reg ); + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + a_vec[8].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[10].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[5] + 1*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[6] + 1*n_elem_per_reg ); + a_vec[15].v = _mm256_loadu_pd( av[7] + 1*n_elem_per_reg ); + + a_vec[16].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[17].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[18].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[19].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[20].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); + a_vec[21].v = _mm256_loadu_pd( av[5] + 2*n_elem_per_reg ); + a_vec[22].v = _mm256_loadu_pd( av[6] + 2*n_elem_per_reg ); + a_vec[23].v = _mm256_loadu_pd( av[7] + 2*n_elem_per_reg ); + + a_vec[24].v = _mm256_loadu_pd( av[0] + 3*n_elem_per_reg ); + a_vec[25].v = _mm256_loadu_pd( av[1] + 3*n_elem_per_reg ); + a_vec[26].v = _mm256_loadu_pd( av[2] + 3*n_elem_per_reg ); + a_vec[27].v = _mm256_loadu_pd( av[3] + 3*n_elem_per_reg ); + a_vec[28].v = _mm256_loadu_pd( av[4] + 3*n_elem_per_reg ); + a_vec[29].v = _mm256_loadu_pd( av[5] + 3*n_elem_per_reg ); + a_vec[30].v = _mm256_loadu_pd( av[6] + 3*n_elem_per_reg ); + a_vec[31].v = _mm256_loadu_pd( av[7] + 3*n_elem_per_reg ); // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a0v.v, chi0v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a1v.v, chi1v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a2v.v, chi2v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a3v.v, chi3v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a4v.v, chi4v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a5v.v, chi5v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a6v.v, chi6v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a7v.v, chi7v.v, y0v.v ); + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( 
a_vec[7].v, chiv[7].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[10].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[11].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[12].v, chiv[4].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[13].v, chiv[5].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[14].v, chiv[6].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[15].v, chiv[7].v, yv[1].v ); + + yv[2].v = _mm256_fmadd_pd( a_vec[16].v, chiv[0].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[17].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[18].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[19].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[20].v, chiv[4].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[21].v, chiv[5].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[22].v, chiv[6].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[23].v, chiv[7].v, yv[2].v ); + + yv[3].v = _mm256_fmadd_pd( a_vec[24].v, chiv[0].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[25].v, chiv[1].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[26].v, chiv[2].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[27].v, chiv[3].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[28].v, chiv[4].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[29].v, chiv[5].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[30].v, chiv[6].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[31].v, chiv[7].v, yv[3].v ); // Store the output. - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), yv[3].v ); + + y0 += n_elem_per_reg * n_iter_unroll[0]; + av[0] += n_elem_per_reg * n_iter_unroll[0]; + av[1] += n_elem_per_reg * n_iter_unroll[0]; + av[2] += n_elem_per_reg * n_iter_unroll[0]; + av[3] += n_elem_per_reg * n_iter_unroll[0]; + av[4] += n_elem_per_reg * n_iter_unroll[0]; + av[5] += n_elem_per_reg * n_iter_unroll[0]; + av[6] += n_elem_per_reg * n_iter_unroll[0]; + av[7] += n_elem_per_reg * n_iter_unroll[0]; + } + + // 12 elements of the result are computed per iteration + for ( i = 0; i < m_viter[1]; ++i ) + { + // Load the input values. 
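Editorial aside, not part of the patch: within each tier iteration, every yv[k] carries its own 8-term FMA chain over the eight scaled columns, and because the yv[k] target different 4-element blocks of y those chains are independent and can overlap in the FMA pipelines. One such chain written as a helper (assumes AVX2/FMA, e.g. compiled with -mavx2 -mfma); axpyf8_block is a name local to this example.

#include <immintrin.h>

/* Accumulate one 4-element block of y: yv += sum_j chiv[j] * a[j][0:4]. */
static __m256d axpyf8_block( const double *const a[8],
                             const __m256d chiv[8], __m256d yv )
{
    for ( int j = 0; j < 8; ++j )
        yv = _mm256_fmadd_pd( _mm256_loadu_pd( a[j] ), chiv[j], yv );
    return yv;
}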
+ yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + a_vec[8].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[10].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[5] + 1*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[6] + 1*n_elem_per_reg ); + a_vec[15].v = _mm256_loadu_pd( av[7] + 1*n_elem_per_reg ); + + a_vec[16].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[17].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[18].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[19].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[20].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); + a_vec[21].v = _mm256_loadu_pd( av[5] + 2*n_elem_per_reg ); + a_vec[22].v = _mm256_loadu_pd( av[6] + 2*n_elem_per_reg ); + a_vec[23].v = _mm256_loadu_pd( av[7] + 2*n_elem_per_reg ); + + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[10].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[11].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[12].v, chiv[4].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[13].v, chiv[5].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[14].v, chiv[6].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[15].v, chiv[7].v, yv[1].v ); + + yv[2].v = _mm256_fmadd_pd( a_vec[16].v, chiv[0].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[17].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[18].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[19].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[20].v, chiv[4].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[21].v, chiv[5].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[22].v, chiv[6].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[23].v, chiv[7].v, yv[2].v ); + + // Store the output. 
+ _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + + y0 += n_elem_per_reg * n_iter_unroll[1]; + av[0] += n_elem_per_reg * n_iter_unroll[1]; + av[1] += n_elem_per_reg * n_iter_unroll[1]; + av[2] += n_elem_per_reg * n_iter_unroll[1]; + av[3] += n_elem_per_reg * n_iter_unroll[1]; + av[4] += n_elem_per_reg * n_iter_unroll[1]; + av[5] += n_elem_per_reg * n_iter_unroll[1]; + av[6] += n_elem_per_reg * n_iter_unroll[1]; + av[7] += n_elem_per_reg * n_iter_unroll[1]; + } + + // 8 elements of the result are computed per iteration + for ( i = 0; i < m_viter[2]; ++i ) + { + // Load the input values. + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + a_vec[8].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[10].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[5] + 1*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[6] + 1*n_elem_per_reg ); + a_vec[15].v = _mm256_loadu_pd( av[7] + 1*n_elem_per_reg ); + + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[10].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[11].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[12].v, chiv[4].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[13].v, chiv[5].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[14].v, chiv[6].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[15].v, chiv[7].v, yv[1].v ); + + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + + y0 += n_elem_per_reg * n_iter_unroll[2]; + av[0] += n_elem_per_reg * n_iter_unroll[2]; + av[1] += n_elem_per_reg * n_iter_unroll[2]; + av[2] += n_elem_per_reg * n_iter_unroll[2]; + av[3] += n_elem_per_reg * n_iter_unroll[2]; + av[4] += n_elem_per_reg * n_iter_unroll[2]; + av[5] += n_elem_per_reg * n_iter_unroll[2]; + av[6] += n_elem_per_reg * n_iter_unroll[2]; + av[7] += n_elem_per_reg * n_iter_unroll[2]; + } + + // 4 elements of the result are computed per iteration + for ( i = 0; i < m_viter[3]; ++i ) + { + // Load the input values. 
+ yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v ); + + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); y0 += n_elem_per_reg; - a0 += n_elem_per_reg; - a1 += n_elem_per_reg; - a2 += n_elem_per_reg; - a3 += n_elem_per_reg; - a4 += n_elem_per_reg; - a5 += n_elem_per_reg; - a6 += n_elem_per_reg; - a7 += n_elem_per_reg; + av[0] += n_elem_per_reg; + av[1] += n_elem_per_reg; + av[2] += n_elem_per_reg; + av[3] += n_elem_per_reg; + av[4] += n_elem_per_reg; + av[5] += n_elem_per_reg; + av[6] += n_elem_per_reg; + av[7] += n_elem_per_reg; } // If there are leftover iterations, perform them with scalar code. @@ -438,34 +674,34 @@ void bli_daxpyf_zen_int_8 { double y0c = *y0; - const double a0c = *a0; - const double a1c = *a1; - const double a2c = *a2; - const double a3c = *a3; - const double a4c = *a4; - const double a5c = *a5; - const double a6c = *a6; - const double a7c = *a7; - - y0c += chi0 * a0c; - y0c += chi1 * a1c; - y0c += chi2 * a2c; - y0c += chi3 * a3c; - y0c += chi4 * a4c; - y0c += chi5 * a5c; - y0c += chi6 * a6c; - y0c += chi7 * a7c; + const double a0c = *av[0]; + const double a1c = *av[1]; + const double a2c = *av[2]; + const double a3c = *av[3]; + const double a4c = *av[4]; + const double a5c = *av[5]; + const double a6c = *av[6]; + const double a7c = *av[7]; + + y0c += chi[0] * a0c; + y0c += chi[1] * a1c; + y0c += chi[2] * a2c; + y0c += chi[3] * a3c; + y0c += chi[4] * a4c; + y0c += chi[5] * a5c; + y0c += chi[6] * a6c; + y0c += chi[7] * a7c; *y0 = y0c; - a0 += inca; - a1 += inca; - a2 += inca; - a3 += inca; - a4 += inca; - a5 += inca; - a6 += inca; - a7 += inca; + av[0] += inca; + av[1] += inca; + av[2] += inca; + av[3] += inca; + av[4] += inca; + av[5] += inca; + av[6] += inca; + av[7] += inca; y0 += incy; } } diff --git a/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c b/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c new file mode 100644 index 0000000000..1be9975ecf --- /dev/null +++ b/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c @@ -0,0 +1,1561 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+ - Neither the name(s) of the copyright holder(s) nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include "blis.h"
+#include "immintrin.h"
+
+typedef union{
+    __m256d v;
+    double d[4] __attribute__((aligned(64)));
+}vec;
+
+typedef union
+{
+    __m256 v;
+    float f[8] __attribute__((aligned(64)));
+} v8sf_t;
+
+/**
+ * bli_pre_hemv_8x8 is a helper function which computes
+ * "y = y + alpha * a * x"
+ * (fused dotxf and axpyf of a triangular matrix with a vector)
+ * for the lower-triangular case.
+ * It computes 8 elements of the y vector via dot products of
+ * 8 elements of the x vector with an 8x8 tile of the A matrix,
+ * and an axpy update of each x element with each column of
+ * the 8x8 A tile.
+ */
+void bli_pre_hemv_8x8(double *a, double *x, double *y, double *alpha,
+                      dim_t cs_a, dim_t rs_a)
+{
+    __m256d ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9;
+    __m256d ymm10, ymm11, ymm12;
+    double alpha_chi[8] = {0};
+    /*Broadcast alpha*/
+    ymm9 = _mm256_broadcast_sd(alpha);
+
+    /**
+     * Scale vector x by alpha and gather the
+     * resulting alpha_chi elements
+     * into one buffer.
+ */ + ymm10 = _mm256_loadu_pd(x); + ymm11 = _mm256_loadu_pd(x + 4); + ymm10 = _mm256_mul_pd(ymm9, ymm10); + ymm11 = _mm256_mul_pd(ymm9, ymm11); + _mm256_storeu_pd(alpha_chi, ymm10); + _mm256_storeu_pd(alpha_chi + 4, ymm11); + + /*Load y vector*/ + ymm10 = _mm256_loadu_pd(y); + ymm11 = _mm256_loadu_pd(y + 4); + + //Col 0 computation + /*Broadcasts chi and multiplies with alpha to get alpha chi*/ + ymm12 = _mm256_broadcast_sd(alpha_chi); + /*Load first column of A matrix*/ + ymm0 = _mm256_loadu_pd(a); + ymm1 = _mm256_loadu_pd(a + 4); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 1 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 1); + /** + * pack the data in following manner into ymm register + * Since it is computing 2nd column, packing to be done + * as shown below for ymm0: + * col-0 col-1 + * --- --- + x x + --- x + --- x + */ + ymm3 = _mm256_broadcast_sd(a + 1); + ymm0 = _mm256_loadu_pd(a + cs_a * 1); + ymm0 = _mm256_blend_pd(ymm0, ymm3, 0x1); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 1); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 2 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 2); + /** + * pack the data in following manner into ymm register + * Since it is computing 3rd column, packing to be done + * as shown below for ymm0: + * col-0 col-1 col-2 + * --- --- --- + x x --- + --- --- x + --- --- x + */ + ymm3 = _mm256_broadcast_sd(a + 2); + ymm4 = _mm256_broadcast_sd(a + 2 + cs_a); + ymm0 = _mm256_loadu_pd(a + cs_a * 2); + ymm0 = _mm256_blend_pd(ymm0, ymm3, 0x1); + ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x2); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 2); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 3 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 3); + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-0 col-1 col-2 col-3 + * --- --- --- --- + x x x --- + --- --- --- x + */ + ymm3 = _mm256_broadcast_sd(a + 3); + ymm4 = _mm256_broadcast_sd(a + 3 + cs_a); + ymm5 = _mm256_broadcast_sd(a + 3 + cs_a * 2); + ymm0 = _mm256_loadu_pd(a + cs_a * 3); + ymm0 = _mm256_blend_pd(ymm0, ymm3, 0x1); + ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x2); + ymm0 = _mm256_blend_pd(ymm0, ymm5, 0x4); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 3); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + /** + * Transpose 4x4 tile of matrix A, + * for remainder column computation. 
+ */ + ymm0 = _mm256_loadu_pd(a+4 + cs_a * 0); + ymm1 = _mm256_loadu_pd(a+4 + cs_a * 1); + ymm2 = _mm256_loadu_pd(a+4 + cs_a * 2); + ymm3 = _mm256_loadu_pd(a+4 + cs_a * 3); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //Transposed col 1 + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //Transposed col 3 + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //Transposed col 2 + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //Transposed col 4 + + //Col 4 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 4); + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-0 col-1 col-2 col-3 col-4 + * --- --- --- --- --- + x x x x --- + --- --- --- --- --- + --- --- --- --- --- + */ + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 4); + ymm10 = _mm256_fmadd_pd(ymm12, ymm6, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 5 computation + /** + * Packs the data in similar manner as shown + * for col 0-4 computation, along with + * packing all 5th elements from col 0 - 4 + * in other ymm register. + * col-4 col-5 + * --- --- + x x + --- x + --- x + + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 5); + ymm3 = _mm256_broadcast_sd(a + 5 + cs_a * 4); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 5); + ymm1 = _mm256_blend_pd(ymm1, ymm3, 0x1); + ymm10 = _mm256_fmadd_pd(ymm12, ymm7, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 6 computation + /** + * Packs the data in similar manner as shown + * for col 0-4 computation, along with + * packing all 6th elements from col 0 - 4 + * in other ymm register. + * col-4 col-5 col-6 + * --- --- --- + x x --- + --- --- x + --- --- x + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 6); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 6); + ymm3 = _mm256_broadcast_sd(a + 6 + cs_a * 4); + ymm4 = _mm256_broadcast_sd(a + 6 + cs_a * 5); + ymm1 = _mm256_blend_pd(ymm1, ymm3, 0x1); + ymm1 = _mm256_blend_pd(ymm1, ymm4, 0x2); + ymm10 = _mm256_fmadd_pd(ymm12, ymm8, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 7 computation + /** + * Packs the data in similar manner as shown + * for col 0-4 computation, along with + * packing all 7th elements from col 0 - 4 + * in other ymm register. + * col-4 col-5 col-6 col-7 + * --- --- --- --- + x x x --- + --- --- --- x + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 7); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 7); + ymm3 = _mm256_broadcast_sd(a + 7 + cs_a * 4); + ymm4 = _mm256_broadcast_sd(a + 7 + cs_a * 5); + ymm5 = _mm256_broadcast_sd(a + 7 + cs_a * 6); + ymm1 = _mm256_blend_pd(ymm1, ymm3, 0x1); + ymm1 = _mm256_blend_pd(ymm1, ymm4, 0x2); + ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x4); + ymm10 = _mm256_fmadd_pd(ymm12, ymm9, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + /** + * Computed result of vector y is available in ymm10, ymm11. + * Storing the result back from ymm register into y vector for + * further computaion. + */ + _mm256_storeu_pd(y, ymm10); + _mm256_storeu_pd(y + 4, ymm11); +} + + +/** + * bli_post_hemv_lower_8x8 is a helper function which computes + * "y = y + alpha * a * x" + * dotxf and axpyf of triangular matrix with vector + * for upper triangular matrix cases. 
+ * Computes 8 elements of Y vector by dot product + * of 8 elements of x vector with 8x8 tile of A matrix + * and axpy computation of each x vector elements with + * each column of 8x8 A matrix tile. +*/ +void bli_post_hemv_8x8(double *a, double *x, double *y, double *alpha, + dim_t cs_a, dim_t rs_a) +{ + __m256d ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9; + __m256d ymm10, ymm11, ymm12; + double alpha_chi[8] = {0}; + + ymm9 = _mm256_broadcast_sd(alpha); + + ymm10 = _mm256_loadu_pd(x); + ymm11 = _mm256_loadu_pd(x + 4); + ymm10 = _mm256_mul_pd(ymm9, ymm10); + ymm11 = _mm256_mul_pd(ymm9, ymm11); + _mm256_storeu_pd(alpha_chi, ymm10); + _mm256_storeu_pd(alpha_chi + 4, ymm11); + + ymm10 = _mm256_loadu_pd(y); + ymm11 = _mm256_loadu_pd(y + 4); + + ymm0 = _mm256_loadu_pd(a + cs_a * 4); + ymm1 = _mm256_loadu_pd(a + cs_a * 5); + ymm2 = _mm256_loadu_pd(a + cs_a * 6); + ymm3 = _mm256_loadu_pd(a + cs_a * 7); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + //Col 0 computation + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-0 col-1 col-2 col-3 + * x x x x + --- + --- + --- + */ + ymm12 = _mm256_broadcast_sd(alpha_chi); + ymm0 = _mm256_loadu_pd(a); + ymm1 = _mm256_broadcast_sd(a + cs_a * 1); + ymm2 = _mm256_broadcast_sd(a + cs_a * 2); + ymm3 = _mm256_broadcast_sd(a + cs_a * 3); + ymm0 = _mm256_blend_pd(ymm0, ymm1, 0x2); + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x4); + ymm0 = _mm256_blend_pd(ymm0, ymm3, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm6, ymm11); + + //Col 1 computation + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-1 col-2 col-3 + * x x x + x + --- + --- + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 1); + ymm0 = _mm256_loadu_pd(a + cs_a * 1); + ymm2 = _mm256_broadcast_sd(a + cs_a * 2 + 1); + ymm3 = _mm256_broadcast_sd(a + cs_a * 3 + 1); + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x4); + ymm0 = _mm256_blend_pd(ymm0, ymm3, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm7, ymm11); + + //Col 2 computation + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-2 col-3 + * x x + x + x + --- + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 2); + ymm0 = _mm256_loadu_pd(a + cs_a * 2); + ymm2 = _mm256_broadcast_sd(a + cs_a * 3 + 2); + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm8, ymm11); + + //Col 3 computation + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-3 + * x + x + x + x + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 3); + ymm0 = _mm256_loadu_pd(a + cs_a * 3); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm9, ymm11); + + //Col 4 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 4); + ymm0 = _mm256_loadu_pd(a + cs_a * 4); + ymm1 = _mm256_loadu_pd(a + cs_a * 4 
+ 4); + ymm4 = _mm256_broadcast_sd(a + cs_a * 5 + 4); + ymm5 = _mm256_broadcast_sd(a + cs_a * 6 + 4); + ymm6 = _mm256_broadcast_sd(a + cs_a * 7 + 4); + ymm1 = _mm256_blend_pd(ymm1, ymm4, 0x2); + ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x4); + ymm1 = _mm256_blend_pd(ymm1, ymm6, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 5 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 5); + ymm0 = _mm256_loadu_pd(a + cs_a * 5); + ymm1 = _mm256_loadu_pd(a + cs_a * 5 + 4); + ymm5 = _mm256_broadcast_sd(a + cs_a * 6 + 5); + ymm6 = _mm256_broadcast_sd(a + cs_a * 7 + 5); + ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x4); + ymm1 = _mm256_blend_pd(ymm1, ymm6, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 6 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 6); + ymm0 = _mm256_loadu_pd(a + cs_a * 6); + ymm1 = _mm256_loadu_pd(a + cs_a * 6 + 4); + ymm6 = _mm256_broadcast_sd(a + cs_a * 7 + 6); + ymm1 = _mm256_blend_pd(ymm1, ymm6, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 7 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 7); + ymm0 = _mm256_loadu_pd(a + cs_a * 7); + ymm1 = _mm256_loadu_pd(a + cs_a * 7 + 4); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + /** + * Computed result of vector y is available in ymm10, ymm11. + * Storing the result back from ymm register into y vector for + * further computaion. + */ + _mm256_storeu_pd(y, ymm10); + _mm256_storeu_pd(y + 4, ymm11); +} + + +/** + * ddotxaxpyf kernel performs dot and apxy function all togather + * on a tile of 4x8 size. + * x_trsv holds 4 elements of vector x, a_tile[0-7] holds + * 4x8 tile of A matrix. + * Following equations are solved in a way represented + * y1 = y1 + alpha * A21' * x2; (dotxf) + y2 = y2 + alpha * A21 * x1; (axpyf) + + * B1 B2 B3 B4 B5 B6 B7 B8 + * (broadcast elements of [x*alpha] vector) + * tile 0 1 2 3 4 5 6 7 + * x_trsv[0] A00 A01 A02 A03 => rho0 | A04 A05 A06 A07 => rho4 + * x_trsv[1] A10 A11 A12 A13 => rho1 | A14 A15 A16 A17 => rho5 + * x_trsv[2] A20 A21 A22 A23 => rho2 | A24 A25 A26 A27 => rho6 + * x_trsv[3] A30 A31 A32 A33 => rho3 | A34 A35 A36 A37 => rho7 + || || || || || || || || + \/ \/ \/ \/ \/ \/ \/ \/ + += += += += += += += += + z_vec z_vec z_vec z_vec z_vec z_vec z_vec z_vec + * + * + */ +void bli_ddotxaxpyf_zen_int_8 +( + conj_t conjat, + conj_t conja, + conj_t conjw, + conj_t conjx, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict w, inc_t incw, + double* restrict x, inc_t incx, + double* restrict beta, + double* restrict y, inc_t incy, + double* restrict z, inc_t incz, + cntx_t* restrict cntx + ) +{ + /* A is m x n. */ + /* y = beta * y + alpha * A^T w; */ + /* z = z + alpha * A x; */ + if ( ( bli_cpuid_is_avx_supported() == TRUE ) && + (inca == 1) && (incw == 1) && (incx == 1) + && (incy == 1) && (incz == 1) && (b_n == 8) ) + { + __m256d r0, r1; + r0 = _mm256_setzero_pd(); + r1 = _mm256_setzero_pd(); + + /* If beta is zero, clear y. Otherwise, scale by beta. 
*/ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + for ( dim_t i = 0; i < 8; ++i ) + { + PASTEMAC(d,set0s)( y[i] ); + } + } + else + { + for ( dim_t i = 0; i < 8; ++i ) + { + PASTEMAC(d,scals)( *beta, y[i] ); + } + } + + /* If the vectors are empty or if alpha is zero, return early*/ + if ( bli_zero_dim1( m ) || PASTEMAC(d,eq0)( *alpha ) ) return; + + dim_t row = 0; + dim_t iter = m/4; + dim_t rem = m%4; + if(iter) + { + vec x_trsv, x_hemvB1, x_hemvB2, x_hemvB3, x_hemvB4; + vec x_hemvB5, x_hemvB6, x_hemvB7, x_hemvB8; + + vec a_tile0, a_tile1, a_tile2, a_tile3; + vec a_tile4, a_tile5, a_tile6, a_tile7; + + vec rho0, rho1, rho2, rho3; + vec rho4, rho5, rho6, rho7; + + __m256d z_vec; + + /** + * Load [x vector * alpha], broadcast each element into + * different ymm registers. To perform axpyf operation + * with 4x8 tile of A matrix. + */ + + x_hemvB1.v = _mm256_set1_pd(x[0*incx] * (*alpha)); + x_hemvB2.v = _mm256_set1_pd(x[1*incx] * (*alpha)); + x_hemvB3.v = _mm256_set1_pd(x[2*incx] * (*alpha)); + x_hemvB4.v = _mm256_set1_pd(x[3*incx] * (*alpha)); + + x_hemvB5.v = _mm256_set1_pd(x[4*incx] * (*alpha)); + x_hemvB6.v = _mm256_set1_pd(x[5*incx] * (*alpha)); + x_hemvB7.v = _mm256_set1_pd(x[6*incx] * (*alpha)); + x_hemvB8.v = _mm256_set1_pd(x[7*incx] * (*alpha)); + + /** + * clear rho register which holds result of + * fmadds for dotxf operation. + * Once micro tile is computed, horizontal addition + * of all rho's will provide us with the result of + * dotxf opereation. + */ + rho0.v = _mm256_setzero_pd(); + rho1.v = _mm256_setzero_pd(); + rho2.v = _mm256_setzero_pd(); + rho3.v = _mm256_setzero_pd(); + rho4.v = _mm256_setzero_pd(); + rho5.v = _mm256_setzero_pd(); + rho6.v = _mm256_setzero_pd(); + rho7.v = _mm256_setzero_pd(); + + for(; (row + 3) < m; row+= 4) + { + a_tile0.v = _mm256_loadu_pd((double *) + &a[row + 0 * lda] ); + a_tile1.v = _mm256_loadu_pd((double *) + &a[row + 1 * lda] ); + a_tile2.v = _mm256_loadu_pd((double *) + &a[row + 2 * lda] ); + a_tile3.v = _mm256_loadu_pd((double *) + &a[row + 3 * lda] ); + a_tile4.v = _mm256_loadu_pd((double *) + &a[row + 4 * lda] ); + a_tile5.v = _mm256_loadu_pd((double *) + &a[row + 5 * lda] ); + a_tile6.v = _mm256_loadu_pd((double *) + &a[row + 6 * lda] ); + a_tile7.v = _mm256_loadu_pd((double *) + &a[row + 7 * lda] ); + + x_trsv.v = _mm256_loadu_pd((double *) &w[row]); + z_vec = _mm256_loadu_pd((double *) &z[row] ); + + //dot product operation + rho0.v = _mm256_fmadd_pd(a_tile0.v, + x_trsv.v, rho0.v); + rho4.v = _mm256_fmadd_pd(a_tile4.v, + x_trsv.v, rho4.v); + + rho1.v = _mm256_fmadd_pd(a_tile1.v, + x_trsv.v, rho1.v); + rho5.v = _mm256_fmadd_pd(a_tile5.v, + x_trsv.v, rho5.v); + + rho2.v = _mm256_fmadd_pd(a_tile2.v, + x_trsv.v, rho2.v); + rho6.v = _mm256_fmadd_pd(a_tile6.v, + x_trsv.v, rho6.v); + + rho3.v = _mm256_fmadd_pd(a_tile3.v, + x_trsv.v, rho3.v); + rho7.v = _mm256_fmadd_pd(a_tile7.v, + x_trsv.v, rho7.v); + + //axpy operation + z_vec = _mm256_fmadd_pd(a_tile0.v, + x_hemvB1.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile1.v, + x_hemvB2.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile2.v, + x_hemvB3.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile3.v, + x_hemvB4.v, z_vec); + + z_vec = _mm256_fmadd_pd(a_tile4.v, + x_hemvB5.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile5.v, + x_hemvB6.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile6.v, + x_hemvB7.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile7.v, + x_hemvB8.v, z_vec); + + _mm256_storeu_pd((double *)&z[row], z_vec); + } + /*Horizontal addition of rho's elements to compute + * the final dotxf result. 
+ */ + rho0.v = _mm256_hadd_pd( rho0.v, rho1.v ); + rho2.v = _mm256_hadd_pd( rho2.v, rho3.v ); + rho4.v = _mm256_hadd_pd( rho4.v, rho5.v ); + rho6.v = _mm256_hadd_pd( rho6.v, rho7.v ); + + { + __m128d xmm0, xmm1; + + xmm0 = _mm256_extractf128_pd(rho0.v, 0); + xmm1 = _mm256_extractf128_pd(rho0.v, 1); + xmm0 = _mm_add_pd(xmm0, xmm1); + r0 = _mm256_insertf128_pd(r0, xmm0, 0); + + xmm0 = _mm256_extractf128_pd(rho2.v, 0); + xmm1 = _mm256_extractf128_pd(rho2.v, 1); + xmm0 = _mm_add_pd(xmm0, xmm1); + r0 = _mm256_insertf128_pd(r0, xmm0, 1); + + + xmm0 = _mm256_extractf128_pd(rho4.v, 0); + xmm1 = _mm256_extractf128_pd(rho4.v, 1); + xmm0 = _mm_add_pd(xmm0, xmm1); + r1 = _mm256_insertf128_pd(r1, xmm0, 0); + + xmm0 = _mm256_extractf128_pd(rho6.v, 0); + xmm1 = _mm256_extractf128_pd(rho6.v, 1); + xmm0 = _mm_add_pd(xmm0, xmm1); + r1 = _mm256_insertf128_pd(r1, xmm0, 1); + } + } + if(rem) + { + double r[ 8 ]; + double ax[ 8 ]; + /** + * Computed dot product computation needs + * to be brought into the r buffer for + * corner cases, so that remainder computation + * can be updated in r buffer. + */ + _mm256_storeu_pd((double *)r, r0); + _mm256_storeu_pd( (double *)(r + 4), r1); + + PRAGMA_SIMD + for ( dim_t i = 0; i < 8; ++i ) + { + PASTEMAC(d,scal2s) + ( *alpha, x[i], ax[i] ); + } + + PRAGMA_SIMD + for ( dim_t p = row; p < m; ++p ) + { + for ( dim_t i = 0; i < 8; ++i ) + { + PASTEMAC(d,axpys) + ( a[p + i*lda], + w[p], r[i] ); + PASTEMAC(d,axpyjs) + ( ax[i], + a[p + i*lda], z[p] ); + } + } + /** + * Final dot product computation needs be + * loaded into registers, for getting + * scaled by Alpha and finally be stored + * back into output vector. + */ + r0 = _mm256_loadu_pd((double const *)r); + r1 = _mm256_loadu_pd((double const *)(r + 4)); + } + + /** + * Storing the computed result after being + * scaled by Alpha into output vector. + */ + { + __m256d y0, y1, Alpha; + y0 = _mm256_loadu_pd(y); + y1 = _mm256_loadu_pd(y + 4); + Alpha = _mm256_broadcast_sd(alpha); + y0 = _mm256_fmadd_pd(Alpha, r0, y0); + y1 = _mm256_fmadd_pd(Alpha, r1, y1); + _mm256_storeu_pd(y, y0); + _mm256_storeu_pd(y+4, y1); + } + } + else + { + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(d,type); + PASTECH(d,dotxf_ker_ft) kfp_df = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + PASTECH(d,axpyf_ker_ft) kfp_af = + bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + + kfp_df + ( + conjat, + conjw, + m, + b_n, + alpha, + a, inca, lda, + w, incw, + beta, + y, incy, + cntx + ); + + kfp_af + ( + conja, + conjx, + m, + b_n, + alpha, + a, inca, lda, + x, incx, + z, incz, + cntx + ); + } +} + +/** + * zdotxaxpyf kernel performs dot and apxy function together. + * y := conj(beta) * y + conj(alpha) * conj(A)^t * conj(w) (dotxf) + * z := z + alpha * conj(A) * conj(x) (axpyf) + * where, + * A is an m x b matrix. + * w, z are vectors of length m. + * x, y are vectors of length b. 
+ * alpha, beta are scalars + */ +void bli_zdotxaxpyf_zen_int_8 +( + conj_t conjat, + conj_t conja, + conj_t conjw, + conj_t conjx, + dim_t m, + dim_t b_n, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t inca, inc_t lda, + dcomplex* restrict w, inc_t incw, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict beta, + dcomplex* restrict y, inc_t incy, + dcomplex* restrict z, inc_t incz, + cntx_t* restrict cntx + ) +{ + // A: m x b + // w, z: m + // x, y: b + // + // y = beta * y + alpha * A^T w; + // z = z + alpha * A x; + if ( ( bli_cpuid_is_avx_supported() == TRUE ) && + ( inca == 1 ) && ( incw == 1 ) && ( incx == 1 ) + && ( incy == 1 ) && ( incz == 1 ) && ( b_n == 4 ) ) + { + // Temporary rho buffer holds computed dot product result + dcomplex rho[ 4 ]; + + // chi? variables to hold scaled scaler values from x vector + dcomplex chi0; + dcomplex chi1; + dcomplex chi2; + dcomplex chi3; + + // If beta is zero, clear y + // Else, scale by beta + if ( PASTEMAC(z,eq0)( *beta ) ) + { + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(z,set0s)( y[i] ); + } + } + else + { + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(z,scals)( *beta, y[i] ); + } + } + + // If the vectors are empty or if alpha is zero, return early + if ( bli_zero_dim1( m ) || PASTEMAC(z,eq0)( *alpha ) ) return; + + // Initialize rho vector to 0 + for ( dim_t i = 0; i < 4; ++i ) PASTEMAC(z,set0s)( rho[i] ); + + // Set conj use variable for dot operation + conj_t conjdot_use = conjw; + if ( bli_is_conj( conjat ) ) + { + bli_toggle_conj( &conjdot_use ); + } + + // Set conj use variable for dotxf operation, scalar + dim_t conjdotxf = 1; + if ( bli_is_conj( conjdot_use ) ) + { + conjdotxf = -1; + } + + // Set conj use variable for axpyf operation, scalar + dim_t conjaxpyf = 1; + if ( bli_is_conj( conja ) ) + { + conjaxpyf = -1; + } + + // Store each element of x vector in a scalar and apply conjx + if( bli_is_noconj( conjx ) ) + { + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + } + else + { + bli_zcopycjs( conjx, *( x + 0*incx ), chi0 ); + bli_zcopycjs( conjx, *( x + 1*incx ), chi1 ); + bli_zcopycjs( conjx, *( x + 2*incx ), chi2 ); + bli_zcopycjs( conjx, *( x + 3*incx ), chi3 ); + } + + // Scale each chi scalar by alpha + bli_zscals( *alpha, chi0 ); + bli_zscals( *alpha, chi1 ); + bli_zscals( *alpha, chi2 ); + bli_zscals( *alpha, chi3 ); + + dim_t row = 0; + dim_t iter = m / 2; + dim_t rem = m % 2; + if (iter) + { + vec x0R, x1R, x2R, x3R; // x?R holds real part of x[?] + vec x0I, x1I, x2I, x3I; // x?I hold real part of x[?] + vec a_tile0, a_tile1; // a_tile? holds columns of a + vec temp1, temp2, temp3; // temp? registers for intermediate op + vec wR, wI; // holds real & imag components of w + vec z_vec; // holds the z vector + + // rho? 
registers hold results of fmadds for dotxf operation + vec rho0, rho1, rho2, rho3; + vec rho4, rho5, rho6, rho7; + + // For final computation, based on conjdot_use + // sign of imaginary component needs to be toggled + __m256d no_conju = _mm256_setr_pd( -1, 1, -1, 1 ); + __m256d conju = _mm256_setr_pd( 1, -1, 1, -1 ); + + // Clear the temp registers + temp1.v = _mm256_setzero_pd(); + temp2.v = _mm256_setzero_pd(); + temp3.v = _mm256_setzero_pd(); + + // Clear rho registers + // Once micro tile is computed, horizontal addition + // of all rho's will provide us with the result of + // dotxf opereation + rho0.v = _mm256_setzero_pd(); + rho1.v = _mm256_setzero_pd(); + rho2.v = _mm256_setzero_pd(); + rho3.v = _mm256_setzero_pd(); + rho4.v = _mm256_setzero_pd(); + rho5.v = _mm256_setzero_pd(); + rho6.v = _mm256_setzero_pd(); + rho7.v = _mm256_setzero_pd(); + + // Broadcast real & imag parts of 4 elements of x + // to perform axpyf operation with 4x8 tile of A + x0R.v = _mm256_broadcast_sd( &chi0.real ); // real part of x0 + x0I.v = _mm256_broadcast_sd( &chi0.imag ); // imag part of x0 + x1R.v = _mm256_broadcast_sd( &chi1.real ); // real part of x1 + x1I.v = _mm256_broadcast_sd( &chi1.imag ); // imag part of x1 + x2R.v = _mm256_broadcast_sd( &chi2.real ); // real part of x2 + x2I.v = _mm256_broadcast_sd( &chi2.imag ); // imag part of x2 + x3R.v = _mm256_broadcast_sd( &chi3.real ); // real part of x3 + x3I.v = _mm256_broadcast_sd( &chi3.imag ); // imag part of x3 + + for ( ; ( row + 1 ) < m; row += 2) + { + // Load first two columns of A + // a_tile0.v -> a00R a00I a10R a10I + // a_tile1.v -> a01R a01I a11R a11I + a_tile0.v = _mm256_loadu_pd( (double *)&a[row + 0 * lda] ); + a_tile1.v = _mm256_loadu_pd( (double *)&a[row + 1 * lda] ); + + temp1.v = _mm256_mul_pd( a_tile0.v, x0R.v ); + temp2.v = _mm256_mul_pd( a_tile0.v, x0I.v ); + + temp1.v = _mm256_fmadd_pd( a_tile1.v, x1R.v, temp1.v ); + temp2.v = _mm256_fmadd_pd( a_tile1.v, x1I.v, temp2.v ); + + // Load w vector + // wR.v -> w0R w0I w1R w1I + // wI.v ( shuf wR.v ) -> w0I w0I w1I w1I + // wR.v ( shuf wR.v ) -> w0R w0R w1R w1R + wR.v = _mm256_loadu_pd( (double *)&w[row] ); + wI.v = _mm256_permute_pd( wR.v, 15 ); + wR.v = _mm256_permute_pd( wR.v, 0 ); + + rho0.v = _mm256_fmadd_pd( a_tile0.v, wR.v, rho0.v); + rho4.v = _mm256_fmadd_pd( a_tile0.v, wI.v, rho4.v); + + rho1.v = _mm256_fmadd_pd( a_tile1.v, wR.v, rho1.v); + rho5.v = _mm256_fmadd_pd( a_tile1.v, wI.v, rho5.v); + + // Load 3rd and 4th columns of A + // a_tile0.v -> a20R a20I a30R a30I + // a_tile1.v -> a21R a21I a31R a31I + a_tile0.v = _mm256_loadu_pd( (double *)&a[row + 2 * lda] ); + a_tile1.v = _mm256_loadu_pd( (double *)&a[row + 3 * lda] ); + + temp1.v = _mm256_fmadd_pd( a_tile0.v, x2R.v, temp1.v ); + temp2.v = _mm256_fmadd_pd( a_tile0.v, x2I.v, temp2.v ); + + temp1.v = _mm256_fmadd_pd( a_tile1.v, x3R.v, temp1.v ); + temp2.v = _mm256_fmadd_pd( a_tile1.v, x3I.v, temp2.v ); + + rho2.v = _mm256_fmadd_pd( a_tile0.v, wR.v, rho2.v); + rho6.v = _mm256_fmadd_pd( a_tile0.v, wI.v, rho6.v); + + rho3.v = _mm256_fmadd_pd( a_tile1.v, wR.v, rho3.v); + rho7.v = _mm256_fmadd_pd( a_tile1.v, wI.v, rho7.v); + + // Load z vector + z_vec.v = _mm256_loadu_pd( (double *)&z[row] ); + + // Permute the result and alternatively add-sub final values + if( bli_is_noconj( conja ) ) + { + temp2.v = _mm256_permute_pd(temp2.v, 5); + temp3.v = _mm256_addsub_pd(temp1.v, temp2.v); + } + else + { + temp1.v = _mm256_permute_pd( temp1.v, 5 ); + temp3.v = _mm256_addsub_pd( temp2.v, temp1.v ); + temp3.v = _mm256_permute_pd( 
temp3.v, 5 );
+            }
+
+            // Add & store result to z_vec
+            z_vec.v = _mm256_add_pd( temp3.v, z_vec.v );
+            _mm256_storeu_pd( (double *)&z[row], z_vec.v );
+        }
+
+        // Swap the real- and imaginary-part products within each
+        // 128-bit lane so that the additions below produce the
+        // final dot product.
+        // The rho4-rho7 registers hold partial products that need
+        // to be rearranged in the following manner:
+        // a0R * w0I | a0I * w0I | a1R * w1I | a1I * w1I
+        //                  ||
+        //                  \/
+        // a0I * w0I | a0R * w0I | a1I * w1I | a1R * w1I
+
+        rho4.v = _mm256_permute_pd(rho4.v, 0x05);
+        rho5.v = _mm256_permute_pd(rho5.v, 0x05);
+        rho6.v = _mm256_permute_pd(rho6.v, 0x05);
+        rho7.v = _mm256_permute_pd(rho7.v, 0x05);
+
+        // Negate the imaginary-part products (sign chosen by the
+        // conjugation) to complete the dcomplex multiplication.
+        if ( bli_is_noconj( conjdot_use ) )
+        {
+            rho4.v = _mm256_mul_pd(rho4.v, no_conju);
+            rho5.v = _mm256_mul_pd(rho5.v, no_conju);
+            rho6.v = _mm256_mul_pd(rho6.v, no_conju);
+            rho7.v = _mm256_mul_pd(rho7.v, no_conju);
+        }
+        else
+        {
+            rho4.v = _mm256_mul_pd(rho4.v, conju);
+            rho5.v = _mm256_mul_pd(rho5.v, conju);
+            rho6.v = _mm256_mul_pd(rho6.v, conju);
+            rho7.v = _mm256_mul_pd(rho7.v, conju);
+        }
+
+        rho0.v = _mm256_add_pd(rho0.v, rho4.v);
+        rho1.v = _mm256_add_pd(rho1.v, rho5.v);
+        rho2.v = _mm256_add_pd(rho2.v, rho6.v);
+        rho3.v = _mm256_add_pd(rho3.v, rho7.v);
+
+        // After the horizontal additions below, rho0 & rho1 hold the
+        // final dot product results of the 4 dcomplex columns.
+        rho0.d[0] += rho0.d[2];
+        rho0.d[1] += rho0.d[3];
+
+        rho0.d[2] = rho1.d[0] + rho1.d[2];
+        rho0.d[3] = rho1.d[1] + rho1.d[3];
+
+        rho1.d[0] = rho2.d[0] + rho2.d[2];
+        rho1.d[1] = rho2.d[1] + rho2.d[3];
+
+        rho1.d[2] = rho3.d[0] + rho3.d[2];
+        rho1.d[3] = rho3.d[1] + rho3.d[3];
+
+        // Store the computed dot products
+        // in the temporary buffer rho for further computation.
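Editorial aside, not part of the patch: the two stores below spill the four reduced results into rho[]; each entry is the conjugation-aware dot product that the remainder loop later extends in scalar form. A plain-C sketch of one such entry, with sign_a playing the role the kernel gives to conjdotxf; zdot_t and zdotv_ref are names local to this example.

typedef struct { double real, imag; } zdot_t;   /* stand-in for BLIS dcomplex */

/* rho = sum_p a[p] * w[p]        when sign_a = +1
 * rho = sum_p conj(a[p]) * w[p]  when sign_a = -1                           */
static zdot_t zdotv_ref( long m, const zdot_t *a, const zdot_t *w, double sign_a )
{
    zdot_t rho = { 0.0, 0.0 };
    for ( long p = 0; p < m; ++p )
    {
        rho.real += a[p].real * w[p].real - a[p].imag * w[p].imag * sign_a;
        rho.imag += a[p].imag * w[p].real + a[p].real * w[p].imag * sign_a;
    }
    return rho;   /* later scaled by alpha and accumulated into y */
}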
+ _mm256_storeu_pd( (double *)rho, rho0.v ); + _mm256_storeu_pd( (double *)(rho+2) , rho1.v ); + } + + // To handle the remaining cases + if ( rem ) + { + PRAGMA_SIMD + for ( dim_t p = row; p < m; ++p ) + { + const dcomplex a0c = a[p + 0 * lda]; + const dcomplex a1c = a[p + 1 * lda]; + const dcomplex a2c = a[p + 2 * lda]; + const dcomplex a3c = a[p + 3 * lda]; + + // dot + dcomplex r0c = rho[0]; + dcomplex r1c = rho[1]; + dcomplex r2c = rho[2]; + dcomplex r3c = rho[3]; + + dcomplex w0c = w[p]; + + r0c.real += a0c.real * w0c.real - a0c.imag * w0c.imag + * conjdotxf; + r0c.imag += a0c.imag * w0c.real + a0c.real * w0c.imag + * conjdotxf; + r1c.real += a1c.real * w0c.real - a1c.imag * w0c.imag + * conjdotxf; + r1c.imag += a1c.imag * w0c.real + a1c.real * w0c.imag + * conjdotxf; + r2c.real += a2c.real * w0c.real - a2c.imag * w0c.imag + * conjdotxf; + r2c.imag += a2c.imag * w0c.real + a2c.real * w0c.imag + * conjdotxf; + r3c.real += a3c.real * w0c.real - a3c.imag * w0c.imag + * conjdotxf; + r3c.imag += a3c.imag * w0c.real + a3c.real * w0c.imag + * conjdotxf; + + rho[0] = r0c; + rho[1] = r1c; + rho[2] = r2c; + rho[3] = r3c; + + // axpy + dcomplex z0c = z[p]; + + z0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag + * conjaxpyf; + z0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag + * conjaxpyf; + z0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag + * conjaxpyf; + z0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag + * conjaxpyf; + z0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag + * conjaxpyf; + z0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag + * conjaxpyf; + z0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag + * conjaxpyf; + z0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag + * conjaxpyf; + + z[p] = z0c; + } + } + + // Conjugating the final result if conjat + if ( bli_is_conj( conjat ) ) + { + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(z,conjs)( rho[i] ); + } + } + + // Scaling the dot product result with alpha + // and adding the result to vector y + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(z,axpys)( *alpha, rho[i], y[i] ); + } + } + else + { + // For non-unit increments + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(z,type); + PASTECH(z,dotxf_ker_ft) kfp_df = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + PASTECH(z,axpyf_ker_ft) kfp_af = + bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + + kfp_df + ( + conjat, + conjw, + m, + b_n, + alpha, + a, inca, lda, + w, incw, + beta, + y, incy, + cntx + ); + + kfp_af + ( + conja, + conjx, + m, + b_n, + alpha, + a, inca, lda, + x, incx, + z, incz, + cntx + ); + } +} + +/** + * cdotxaxpyf kernel performs dot and apxy function together. + * y := conj(beta) * y + conj(alpha) * conj(A)^t * conj(w) (dotxf) + * z := z + alpha * conj(A) * conj(x) (axpyf) + * where, + * A is an m x b matrix. + * w, z are vectors of length m. + * x, y are vectors of length b. 
+ * alpha, beta are scalars + */ +void bli_cdotxaxpyf_zen_int_8 +( + conj_t conjat, + conj_t conja, + conj_t conjw, + conj_t conjx, + dim_t m, + dim_t b_n, + scomplex* restrict alpha, + scomplex* restrict a, inc_t inca, inc_t lda, + scomplex* restrict w, inc_t incw, + scomplex* restrict x, inc_t incx, + scomplex* restrict beta, + scomplex* restrict y, inc_t incy, + scomplex* restrict z, inc_t incz, + cntx_t* restrict cntx + ) +{ + // A: m x b + // w, z: m + // x, y: b + // + // y = beta * y + alpha * A^T w; + // z = z + alpha * A x; + if ( ( bli_cpuid_is_avx_supported() == TRUE ) && + ( inca == 1 ) && ( incw == 1 ) && ( incx == 1 ) + && ( incy == 1 ) && ( incz == 1 ) && ( b_n == 4 ) ) + { + // Temporary rho buffer holds computed dot product result + scomplex rho[ 4 ]; + + // chi? variables to hold scaled scaler values from x vector + scomplex chi0; + scomplex chi1; + scomplex chi2; + scomplex chi3; + + // If beta is zero, clear y + // Else, scale by beta + if ( PASTEMAC(c,eq0)( *beta ) ) + { + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(c,set0s)( y[i] ); + } + } + else + { + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(c,scals)( *beta, y[i] ); + } + } + + // If the vectors are empty or if alpha is zero, return early + if ( bli_zero_dim1( m ) || PASTEMAC(c,eq0)( *alpha ) ) return; + + // Initialize rho vector to 0 + for ( dim_t i = 0; i < 4; ++i ) PASTEMAC(c,set0s)( rho[i] ); + + // Set conj use variable for dot operation + conj_t conjdot_use = conjw; + if ( bli_is_conj( conjat ) ) + { + bli_toggle_conj( &conjdot_use ); + } + + // Set conj use variable for dotxf operation, scalar + dim_t conjdotxf = 1; + if ( bli_is_conj( conjdot_use ) ) + { + conjdotxf = -1; + } + + // Set conj use variable for axpyf operation, scalar + dim_t conjaxpyf = 1; + if ( bli_is_conj( conja ) ) + { + conjaxpyf = -1; + } + + // Store each element of x vector in a scalar and apply conjx + if( bli_is_noconj( conjx ) ) + { + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + } + else + { + bli_ccopycjs( conjx, *( x + 0*incx ), chi0 ); + bli_ccopycjs( conjx, *( x + 1*incx ), chi1 ); + bli_ccopycjs( conjx, *( x + 2*incx ), chi2 ); + bli_ccopycjs( conjx, *( x + 3*incx ), chi3 ); + } + + // Scale each chi scalar by alpha + bli_cscals( *alpha, chi0 ); + bli_cscals( *alpha, chi1 ); + bli_cscals( *alpha, chi2 ); + bli_cscals( *alpha, chi3 ); + + dim_t i = 0; + dim_t iter = m / 4; + dim_t rem = m % 4; + if (iter) + { + v8sf_t x0R, x1R, x2R, x3R; // x?R holds real part of x[?] + v8sf_t x0I, x1I, x2I, x3I; // x?I hold real part of x[?] + v8sf_t a_tile0, a_tile1; // a_tile? holds columns of a + v8sf_t temp1, temp2, temp3; // temp? 
registers for intermediate op + v8sf_t wR, wI; // holds real & imag components of w + v8sf_t z_vec; // holds the z vector + + // For final computation, based on conjdot_use + // sign of imaginary component needs to be toggled + __m256 no_conju = _mm256_setr_ps( -1, 1, -1, 1, -1, 1, -1, 1 ); + __m256 conju = _mm256_setr_ps( 1, -1, 1, -1, 1, -1, 1, -1 ); + + // Clear the temp registers + temp1.v = _mm256_setzero_ps(); + temp2.v = _mm256_setzero_ps(); + temp3.v = _mm256_setzero_ps(); + + // Clear rho registers + // Once micro tile is computed, horizontal addition + // of all rho's will provide us with the result of + // dotxf opereation + __m256 rho0v; rho0v = _mm256_setzero_ps(); + __m256 rho1v; rho1v = _mm256_setzero_ps(); + __m256 rho2v; rho2v = _mm256_setzero_ps(); + __m256 rho3v; rho3v = _mm256_setzero_ps(); + + __m256 rho4v; rho4v = _mm256_setzero_ps(); + __m256 rho5v; rho5v = _mm256_setzero_ps(); + __m256 rho6v; rho6v = _mm256_setzero_ps(); + __m256 rho7v; rho7v = _mm256_setzero_ps(); + + // Broadcast real & imag parts of 4 elements of x + // to perform axpyf operation with 4x8 tile of A + x0R.v = _mm256_broadcast_ss( &chi0.real ); // real part of x0 + x0I.v = _mm256_broadcast_ss( &chi0.imag ); // imag part of x0 + x1R.v = _mm256_broadcast_ss( &chi1.real ); // real part of x1 + x1I.v = _mm256_broadcast_ss( &chi1.imag ); // imag part of x1 + x2R.v = _mm256_broadcast_ss( &chi2.real ); // real part of x2 + x2I.v = _mm256_broadcast_ss( &chi2.imag ); // imag part of x2 + x3R.v = _mm256_broadcast_ss( &chi3.real ); // real part of x3 + x3I.v = _mm256_broadcast_ss( &chi3.imag ); // imag part of x3 + + for ( ; ( i + 3 ) < m; i += 4) + { + // Load first two columns of A + // a_tile0.v -> a00R a00I a10R a10I a20R a20I a30R a30I + // a_tile1.v -> a01R a01I a11R a11I a21R a21I a31R a31I + a_tile0.v = _mm256_loadu_ps( (float *)&a[i + 0 * lda] ); + a_tile1.v = _mm256_loadu_ps( (float *)&a[i + 1 * lda] ); + + temp1.v = _mm256_mul_ps( a_tile0.v, x0R.v ); + temp2.v = _mm256_mul_ps( a_tile0.v, x0I.v ); + + temp1.v = _mm256_fmadd_ps( a_tile1.v, x1R.v, temp1.v ); + temp2.v = _mm256_fmadd_ps( a_tile1.v, x1I.v, temp2.v ); + + // Load w vector + // wR.v -> w0R w0I w1R w1I w2R w2I w3R w3I + // wI.v ( shuf wR.v ) -> w0I w0I w1I w1I w2I w2I w3I w3I + // wR.v ( shuf wR.v ) -> w0R w0R w1R w1R w2R w2R w3R w3R + wR.v = _mm256_loadu_ps( (float *) (w + i) ); + wI.v = _mm256_permute_ps( wR.v, 0xf5 ); + wR.v = _mm256_permute_ps( wR.v, 0xa0); + + rho0v = _mm256_fmadd_ps( a_tile0.v, wR.v, rho0v ); + rho4v = _mm256_fmadd_ps( a_tile0.v, wI.v, rho4v ); + + rho1v = _mm256_fmadd_ps( a_tile1.v, wR.v, rho1v ); + rho5v = _mm256_fmadd_ps( a_tile1.v, wI.v, rho5v ); + + // Load 3rd and 4th columns of A + // a_tile0.v -> a20R a20I a30R a30I + // a_tile1.v -> a21R a21I a31R a31I + a_tile0.v = _mm256_loadu_ps( (float *)&a[i + 2 * lda] ); + a_tile1.v = _mm256_loadu_ps( (float *)&a[i + 3 * lda] ); + + temp1.v = _mm256_fmadd_ps( a_tile0.v, x2R.v, temp1.v ); + temp2.v = _mm256_fmadd_ps( a_tile0.v, x2I.v, temp2.v ); + + temp1.v = _mm256_fmadd_ps( a_tile1.v, x3R.v, temp1.v ); + temp2.v = _mm256_fmadd_ps( a_tile1.v, x3I.v, temp2.v ); + + rho2v = _mm256_fmadd_ps( a_tile0.v, wR.v, rho2v ); + rho6v = _mm256_fmadd_ps( a_tile0.v, wI.v, rho6v ); + + rho3v = _mm256_fmadd_ps( a_tile1.v, wR.v, rho3v ); + rho7v = _mm256_fmadd_ps( a_tile1.v, wI.v, rho7v ); + + // Load z vector + z_vec.v = _mm256_loadu_ps( (float *)&z[i] ); + + // Permute the result and alternatively add-sub final values + if( bli_is_noconj( conja ) ) + { + temp2.v = 
_mm256_permute_ps(temp2.v, 0xB1); + temp3.v = _mm256_addsub_ps(temp1.v, temp2.v); + } + else + { + temp1.v = _mm256_permute_ps( temp1.v, 0xB1 ); + temp3.v = _mm256_addsub_ps( temp2.v, temp1.v ); + temp3.v = _mm256_permute_ps( temp3.v, 0xB1 ); + } + + // Add & store result to z_vec + z_vec.v = _mm256_add_ps( temp3.v, z_vec.v ); + _mm256_storeu_ps( (float *)&z[i], z_vec.v ); + } + + // Swapping position of real and imag component + // for horizontal addition to get the final + // dot product computation + // rho register are holding computation which needs + // to be arranged in following manner. + // a0R * x0I | a0I * x0I | a1R * x1I | a1I * x1R | ... + // || + // \/ + // a0I * x0I | a0R * x0I | a1I * x1R | a1R * x1I | ... + + rho4v = _mm256_permute_ps(rho4v, 0xb1); + rho5v = _mm256_permute_ps(rho5v, 0xb1); + rho6v = _mm256_permute_ps(rho6v, 0xb1); + rho7v = _mm256_permute_ps(rho7v, 0xb1); + + // Negating imaginary part for computing + // the final result of dcomplex multiplication + if ( bli_is_noconj( conjdot_use ) ) + { + rho4v = _mm256_mul_ps(rho4v, no_conju); + rho5v = _mm256_mul_ps(rho5v, no_conju); + rho6v = _mm256_mul_ps(rho6v, no_conju); + rho7v = _mm256_mul_ps(rho7v, no_conju); + } + else + { + rho4v = _mm256_mul_ps(rho4v, conju); + rho5v = _mm256_mul_ps(rho5v, conju); + rho6v = _mm256_mul_ps(rho6v, conju); + rho7v = _mm256_mul_ps(rho7v, conju); + } + + rho0v = _mm256_add_ps(rho0v, rho4v); + rho1v = _mm256_add_ps(rho1v, rho5v); + rho2v = _mm256_add_ps(rho2v, rho6v); + rho3v = _mm256_add_ps(rho3v, rho7v); + + // Horizontal addition of rho elements for computing final dotxf + // and storing the results into rho buffer + scomplex *ptr = (scomplex *)&rho0v; + for(dim_t j = 0; j < 4; j++) + { + rho[0].real += ptr[j].real; + rho[0].imag += ptr[j].imag; + } + ptr = (scomplex *)&rho1v; + for(dim_t j = 0; j < 4; j++) + { + rho[1].real += ptr[j].real; + rho[1].imag += ptr[j].imag; + } + ptr = (scomplex *)&rho2v; + for(dim_t j = 0; j < 4; j++) + { + rho[2].real += ptr[j].real; + rho[2].imag += ptr[j].imag; + } + ptr = (scomplex *)&rho3v; + for(dim_t j = 0; j < 4; j++) + { + rho[3].real += ptr[j].real; + rho[3].imag += ptr[j].imag; + } + } + + // To handle the remaining cases + if ( rem ) + { + PRAGMA_SIMD + for ( dim_t p = i; p < m; ++p ) + { + const scomplex a0c = a[p + 0 * lda]; + const scomplex a1c = a[p + 1 * lda]; + const scomplex a2c = a[p + 2 * lda]; + const scomplex a3c = a[p + 3 * lda]; + + // dot + scomplex r0c = rho[0]; + scomplex r1c = rho[1]; + scomplex r2c = rho[2]; + scomplex r3c = rho[3]; + + scomplex w0c = w[p]; + + r0c.real += a0c.real * w0c.real - a0c.imag * w0c.imag + * conjdotxf; + r0c.imag += a0c.imag * w0c.real + a0c.real * w0c.imag + * conjdotxf; + r1c.real += a1c.real * w0c.real - a1c.imag * w0c.imag + * conjdotxf; + r1c.imag += a1c.imag * w0c.real + a1c.real * w0c.imag + * conjdotxf; + r2c.real += a2c.real * w0c.real - a2c.imag * w0c.imag + * conjdotxf; + r2c.imag += a2c.imag * w0c.real + a2c.real * w0c.imag + * conjdotxf; + r3c.real += a3c.real * w0c.real - a3c.imag * w0c.imag + * conjdotxf; + r3c.imag += a3c.imag * w0c.real + a3c.real * w0c.imag + * conjdotxf; + + rho[0] = r0c; + rho[1] = r1c; + rho[2] = r2c; + rho[3] = r3c; + + // axpy + scomplex z0c = z[p]; + + z0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag + * conjaxpyf; + z0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag + * conjaxpyf; + z0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag + * conjaxpyf; + z0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag + * conjaxpyf; + z0c.imag += 
chi0.imag * a0c.real + chi0.real * a0c.imag + * conjaxpyf; + z0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag + * conjaxpyf; + z0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag + * conjaxpyf; + z0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag + * conjaxpyf; + + z[p] = z0c; + } + } + + // Conjugating the final result if conjat + if ( bli_is_conj( conjat ) ) + { + for ( dim_t j = 0; j < 4; ++j ) + { + PASTEMAC(c,conjs)( rho[j] ); + } + } + + // Scaling the dot product result with alpha + // and adding the result to vector y + for ( dim_t j = 0; j < 4; ++j ) + { + PASTEMAC(c,axpys)( *alpha, rho[j], y[j] ); + } + } + else + { + // For non-unit increments + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(c,type); + PASTECH(c,dotxf_ker_ft) kfp_df = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + PASTECH(c,axpyf_ker_ft) kfp_af = + bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + + kfp_df + ( + conjat, + conjw, + m, + b_n, + alpha, + a, inca, lda, + w, incw, + beta, + y, incy, + cntx + ); + + kfp_af + ( + conja, + conjx, + m, + b_n, + alpha, + a, inca, lda, + x, incx, + z, incz, + cntx + ); + } +} \ No newline at end of file diff --git a/kernels/zen/1f/bli_dotxf_zen_int_8.c b/kernels/zen/1f/bli_dotxf_zen_int_8.c index 531a389b50..815e388f21 100644 --- a/kernels/zen/1f/bli_dotxf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxf_zen_int_8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin - Copyright (C) 2017 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -52,6 +52,14 @@ typedef union double d[4] __attribute__((aligned(64))); } v4df_t; +/* Union data structure to access AVX registers +* One 128-bit AVX register holds 2 DP elements. */ +typedef union +{ + __m128d v; + double d[2] __attribute__((aligned(64))); +} v2df_t; + // ----------------------------------------------------------------------------- void bli_sdotxf_zen_int_8 @@ -430,49 +438,46 @@ void bli_ddotxf_zen_int_8 cntx_t* restrict cntx ) { - const dim_t fuse_fac = 8; - const dim_t n_elem_per_reg = 4; + const dim_t fuse_fac = 8; + const dim_t n_elem_per_reg = 4; // If the b_n dimension is zero, y is empty and there is no computation. - if ( bli_zero_dim1( b_n ) ) return; + if (bli_zero_dim1(b_n)) + return; // If the m dimension is zero, or if alpha is zero, the computation // simplifies to updating y. - if ( bli_zero_dim1( m ) || PASTEMAC(d,eq0)( *alpha ) ) + if (bli_zero_dim1(m) || PASTEMAC(d, eq0)(*alpha)) { - bli_dscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - b_n, - beta, - y, incy, - cntx - ); + bli_dscalv_zen_int10( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx); return; } // If b_n is not equal to the fusing factor, then perform the entire // operation as a loop over dotxv. 
- if ( b_n != fuse_fac ) + if (b_n != fuse_fac) { - for ( dim_t i = 0; i < b_n; ++i ) + for (dim_t i = 0; i < b_n; ++i) { - double* a1 = a + (0 )*inca + (i )*lda; - double* x1 = x + (0 )*incx; - double* psi1 = y + (i )*incy; - - bli_ddotxv_zen_int - ( - conjat, - conjx, - m, - alpha, - a1, inca, - x1, incx, - beta, - psi1, - cntx - ); + double *a1 = a + (0) * inca + (i)*lda; + double *x1 = x + (0) * incx; + double *psi1 = y + (i)*incy; + + bli_ddotxv_zen_int( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx); } return; } @@ -493,115 +498,113 @@ void bli_ddotxf_zen_int_8 // distinguishes between (1) and (2). // Intermediate variables to hold the completed dot products - double rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0, - rho4 = 0, rho5 = 0, rho6 = 0, rho7 = 0; + double rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0; + double rho4 = 0, rho5 = 0, rho6 = 0, rho7 = 0; - if ( inca == 1 && incx == 1 ) + if (inca == 1 && incx == 1) { const dim_t n_iter_unroll = 1; // Use the unrolling factor and the number of elements per register // to compute the number of vectorized and leftover iterations. - dim_t m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll ); + dim_t m_viter; + + // Calculate the number of vector iterations that can occur + // for the given unroll factors. + m_viter = (m) / (n_elem_per_reg * n_iter_unroll); // Set up pointers for x and the b_n columns of A (rows of A^T). - double* restrict x0 = x; - double* restrict a0 = a + 0*lda; - double* restrict a1 = a + 1*lda; - double* restrict a2 = a + 2*lda; - double* restrict a3 = a + 3*lda; - double* restrict a4 = a + 4*lda; - double* restrict a5 = a + 5*lda; - double* restrict a6 = a + 6*lda; - double* restrict a7 = a + 7*lda; + double *restrict x0 = x; + double *restrict av[8]; + + av[0] = a + 0 * lda; + av[1] = a + 1 * lda; + av[2] = a + 2 * lda; + av[3] = a + 3 * lda; + av[4] = a + 4 * lda; + av[5] = a + 5 * lda; + av[6] = a + 6 * lda; + av[7] = a + 7 * lda; // Initialize b_n rho vector accumulators to zero. - v4df_t rho0v; rho0v.v = _mm256_setzero_pd(); - v4df_t rho1v; rho1v.v = _mm256_setzero_pd(); - v4df_t rho2v; rho2v.v = _mm256_setzero_pd(); - v4df_t rho3v; rho3v.v = _mm256_setzero_pd(); - v4df_t rho4v; rho4v.v = _mm256_setzero_pd(); - v4df_t rho5v; rho5v.v = _mm256_setzero_pd(); - v4df_t rho6v; rho6v.v = _mm256_setzero_pd(); - v4df_t rho7v; rho7v.v = _mm256_setzero_pd(); + v4df_t rhov[8]; - v4df_t x0v; - v4df_t a0v, a1v, a2v, a3v, a4v, a5v, a6v, a7v; + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); - // If there are vectorized iterations, perform them with vector - // instructions. - for ( dim_t i = 0; i < m_viter; ++i ) + v4df_t xv; + v4df_t avec[8]; + + for (dim_t i = 0; i < m_viter; ++i) { // Load the input values. 
- x0v.v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv.v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); - a0v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); - a1v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); - a2v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); - a3v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); - a4v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); - a5v.v = _mm256_loadu_pd( a5 + 0*n_elem_per_reg ); - a6v.v = _mm256_loadu_pd( a6 + 0*n_elem_per_reg ); - a7v.v = _mm256_loadu_pd( a7 + 0*n_elem_per_reg ); + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg); // perform: rho?v += a?v * x0v; - rho0v.v = _mm256_fmadd_pd( a0v.v, x0v.v, rho0v.v ); - rho1v.v = _mm256_fmadd_pd( a1v.v, x0v.v, rho1v.v ); - rho2v.v = _mm256_fmadd_pd( a2v.v, x0v.v, rho2v.v ); - rho3v.v = _mm256_fmadd_pd( a3v.v, x0v.v, rho3v.v ); - rho4v.v = _mm256_fmadd_pd( a4v.v, x0v.v, rho4v.v ); - rho5v.v = _mm256_fmadd_pd( a5v.v, x0v.v, rho5v.v ); - rho6v.v = _mm256_fmadd_pd( a6v.v, x0v.v, rho6v.v ); - rho7v.v = _mm256_fmadd_pd( a7v.v, x0v.v, rho7v.v ); + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv.v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv.v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv.v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv.v, rhov[3].v); + + avec[4].v = _mm256_loadu_pd(av[4] + 0 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[5] + 0 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[6] + 0 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[7] + 0 * n_elem_per_reg); + + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv.v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv.v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv.v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv.v, rhov[7].v); x0 += n_elem_per_reg * n_iter_unroll; - a0 += n_elem_per_reg * n_iter_unroll; - a1 += n_elem_per_reg * n_iter_unroll; - a2 += n_elem_per_reg * n_iter_unroll; - a3 += n_elem_per_reg * n_iter_unroll; - a4 += n_elem_per_reg * n_iter_unroll; - a5 += n_elem_per_reg * n_iter_unroll; - a6 += n_elem_per_reg * n_iter_unroll; - a7 += n_elem_per_reg * n_iter_unroll; + av[0] += n_elem_per_reg * n_iter_unroll; + av[1] += n_elem_per_reg * n_iter_unroll; + av[2] += n_elem_per_reg * n_iter_unroll; + av[3] += n_elem_per_reg * n_iter_unroll; + av[4] += n_elem_per_reg * n_iter_unroll; + av[5] += n_elem_per_reg * n_iter_unroll; + av[6] += n_elem_per_reg * n_iter_unroll; + av[7] += n_elem_per_reg * n_iter_unroll; } -#if 0 - rho0 += rho0v.d[0] + rho0v.d[1] + rho0v.d[2] + rho0v.d[3]; - rho1 += rho1v.d[0] + rho1v.d[1] + rho1v.d[2] + rho1v.d[3]; - rho2 += rho2v.d[0] + rho2v.d[1] + rho2v.d[2] + rho2v.d[3]; - rho3 += rho3v.d[0] + rho3v.d[1] + rho3v.d[2] + rho3v.d[3]; - rho4 += rho4v.d[0] + rho4v.d[1] + rho4v.d[2] + rho4v.d[3]; - rho5 += rho5v.d[0] + rho5v.d[1] + rho5v.d[2] + rho5v.d[3]; - rho6 += rho6v.d[0] + rho6v.d[1] + rho6v.d[2] + rho6v.d[3]; - rho7 += rho7v.d[0] + rho7v.d[1] + rho7v.d[2] + rho7v.d[3]; -#else // Sum the elements of a given rho?v. This computes the sum of // elements within lanes and stores the sum to both elements. 
- rho0v.v = _mm256_hadd_pd( rho0v.v, rho0v.v ); - rho1v.v = _mm256_hadd_pd( rho1v.v, rho1v.v ); - rho2v.v = _mm256_hadd_pd( rho2v.v, rho2v.v ); - rho3v.v = _mm256_hadd_pd( rho3v.v, rho3v.v ); - rho4v.v = _mm256_hadd_pd( rho4v.v, rho4v.v ); - rho5v.v = _mm256_hadd_pd( rho5v.v, rho5v.v ); - rho6v.v = _mm256_hadd_pd( rho6v.v, rho6v.v ); - rho7v.v = _mm256_hadd_pd( rho7v.v, rho7v.v ); + rhov[0].v = _mm256_hadd_pd(rhov[0].v, rhov[0].v); + rhov[1].v = _mm256_hadd_pd(rhov[1].v, rhov[1].v); + rhov[2].v = _mm256_hadd_pd(rhov[2].v, rhov[2].v); + rhov[3].v = _mm256_hadd_pd(rhov[3].v, rhov[3].v); + rhov[4].v = _mm256_hadd_pd(rhov[4].v, rhov[4].v); + rhov[5].v = _mm256_hadd_pd(rhov[5].v, rhov[5].v); + rhov[6].v = _mm256_hadd_pd(rhov[6].v, rhov[6].v); + rhov[7].v = _mm256_hadd_pd(rhov[7].v, rhov[7].v); // Manually add the results from above to finish the sum. - rho0 = rho0v.d[0] + rho0v.d[2]; - rho1 = rho1v.d[0] + rho1v.d[2]; - rho2 = rho2v.d[0] + rho2v.d[2]; - rho3 = rho3v.d[0] + rho3v.d[2]; - rho4 = rho4v.d[0] + rho4v.d[2]; - rho5 = rho5v.d[0] + rho5v.d[2]; - rho6 = rho6v.d[0] + rho6v.d[2]; - rho7 = rho7v.d[0] + rho7v.d[2]; -#endif + rho0 = rhov[0].d[0] + rhov[0].d[2]; + rho1 = rhov[1].d[0] + rhov[1].d[2]; + rho2 = rhov[2].d[0] + rhov[2].d[2]; + rho3 = rhov[3].d[0] + rhov[3].d[2]; + rho4 = rhov[4].d[0] + rhov[4].d[2]; + rho5 = rhov[5].d[0] + rhov[5].d[2]; + rho6 = rhov[6].d[0] + rhov[6].d[2]; + rho7 = rhov[7].d[0] + rhov[7].d[2]; + // Adjust for scalar subproblem. m -= n_elem_per_reg * n_iter_unroll * m_viter; a += n_elem_per_reg * n_iter_unroll * m_viter /* * inca */; x += n_elem_per_reg * n_iter_unroll * m_viter /* * incx */; - } - else if ( lda == 1 ) + + }else if (lda == 1) { const dim_t n_iter_unroll = 3; const dim_t n_reg_per_row = 2; // fuse_fac / n_elem_per_reg; @@ -672,127 +675,1771 @@ void bli_ddotxf_zen_int_8 a += n_iter_unroll * m_viter * inca; x += n_iter_unroll * m_viter * incx; } + + // Initialize pointers for x and the b_n columns of A (rows of A^T). + double *restrict x0 = x; + double *restrict a0 = a + 0 * lda; + double *restrict a1 = a + 1 * lda; + double *restrict a2 = a + 2 * lda; + double *restrict a3 = a + 3 * lda; + double *restrict a4 = a + 4 * lda; + double *restrict a5 = a + 5 * lda; + double *restrict a6 = a + 6 * lda; + double *restrict a7 = a + 7 * lda; + + // If there are leftover iterations, perform them with scalar code. + for (dim_t i = 0; i < m; ++i) + { + const double x0c = *x0; + + const double a0c = *a0; + const double a1c = *a1; + const double a2c = *a2; + const double a3c = *a3; + const double a4c = *a4; + const double a5c = *a5; + const double a6c = *a6; + const double a7c = *a7; + + rho0 += a0c * x0c; + rho1 += a1c * x0c; + rho2 += a2c * x0c; + rho3 += a3c * x0c; + rho4 += a4c * x0c; + rho5 += a5c * x0c; + rho6 += a6c * x0c; + rho7 += a7c * x0c; + + x0 += incx; + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + a4 += inca; + a5 += inca; + a6 += inca; + a7 += inca; + } + + // Now prepare the final rho values to output/accumulate back into + // the y vector. + + v4df_t rho0v, rho1v, y0v, y1v; + + // Insert the scalar rho values into a single vector. + rho0v.d[0] = rho0; + rho0v.d[1] = rho1; + rho0v.d[2] = rho2; + rho0v.d[3] = rho3; + rho1v.d[0] = rho4; + rho1v.d[1] = rho5; + rho1v.d[2] = rho6; + rho1v.d[3] = rho7; + + // Broadcast the alpha scalar. + v4df_t alphav; + alphav.v = _mm256_broadcast_sd(alpha); + + // We know at this point that alpha is nonzero; however, beta may still + // be zero. 
If beta is indeed zero, we must overwrite y rather than scale + // by beta (in case y contains NaN or Inf). + if (PASTEMAC(d, eq0)(*beta)) + { + // Apply alpha to the accumulated dot product in rho: + // y := alpha * rho + y0v.v = _mm256_mul_pd(alphav.v, rho0v.v); + y1v.v = _mm256_mul_pd(alphav.v, rho1v.v); + } else { - // No vectorization possible; use scalar iterations for the entire - // problem. + // Broadcast the beta scalar. + v4df_t betav; + betav.v = _mm256_broadcast_sd(beta); + + // Load y. + if (incy == 1) + { + y0v.v = _mm256_loadu_pd(y + 0 * n_elem_per_reg); + y1v.v = _mm256_loadu_pd(y + 1 * n_elem_per_reg); + } + else + { + y0v.d[0] = *(y + 0 * incy); + y0v.d[1] = *(y + 1 * incy); + y0v.d[2] = *(y + 2 * incy); + y0v.d[3] = *(y + 3 * incy); + y1v.d[0] = *(y + 4 * incy); + y1v.d[1] = *(y + 5 * incy); + y1v.d[2] = *(y + 6 * incy); + y1v.d[3] = *(y + 7 * incy); + } + + // Apply beta to y and alpha to the accumulated dot product in rho: + // y := beta * y + alpha * rho + y0v.v = _mm256_mul_pd(betav.v, y0v.v); + y1v.v = _mm256_mul_pd(betav.v, y1v.v); + y0v.v = _mm256_fmadd_pd(alphav.v, rho0v.v, y0v.v); + y1v.v = _mm256_fmadd_pd(alphav.v, rho1v.v, y1v.v); } - // Scalar edge case. + if (incy == 1) { - // Initialize pointers for x and the b_n columns of A (rows of A^T). - double* restrict x0 = x; - double* restrict a0 = a + 0*lda; - double* restrict a1 = a + 1*lda; - double* restrict a2 = a + 2*lda; - double* restrict a3 = a + 3*lda; - double* restrict a4 = a + 4*lda; - double* restrict a5 = a + 5*lda; - double* restrict a6 = a + 6*lda; - double* restrict a7 = a + 7*lda; + // Store the output. + _mm256_storeu_pd((y + 0 * n_elem_per_reg), y0v.v); + _mm256_storeu_pd((y + 1 * n_elem_per_reg), y1v.v); + } + else + { + *(y + 0 * incy) = y0v.d[0]; + *(y + 1 * incy) = y0v.d[1]; + *(y + 2 * incy) = y0v.d[2]; + *(y + 3 * incy) = y0v.d[3]; + *(y + 4 * incy) = y1v.d[0]; + *(y + 5 * incy) = y1v.d[1]; + *(y + 6 * incy) = y1v.d[2]; + *(y + 7 * incy) = y1v.d[3]; + } +} - // If there are leftover iterations, perform them with scalar code. - for ( dim_t i = 0; i < m ; ++i ) + +void bli_ddotxf_zen_int_4 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + double *restrict alpha, + double *restrict a, inc_t inca, inc_t lda, + double *restrict x, inc_t incx, + double *restrict beta, + double *restrict y, inc_t incy, + cntx_t *restrict cntx + ) +{ + const dim_t fuse_fac = 4; + const dim_t n_elem_per_reg = 4; + + // If the b_n dimension is zero, y is empty and there is no computation. + if (bli_zero_dim1(b_n)) + return; + + // If the m dimension is zero, or if alpha is zero, the computation + // simplifies to updating y. + if (bli_zero_dim1(m) || PASTEMAC(d, eq0)(*alpha)) + { + bli_dscalv_zen_int10( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx); + return; + } + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over dotxv. + if (b_n != fuse_fac) + { + for (dim_t i = 0; i < b_n; ++i) { - const double x0c = *x0; + double *a1 = a + (0) * inca + (i)*lda; + double *x1 = x + (0) * incx; + double *psi1 = y + (i)*incy; + + bli_ddotxv_zen_int( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx); + } + return; + } - const double a0c = *a0; - const double a1c = *a1; - const double a2c = *a2; - const double a3c = *a3; - const double a4c = *a4; - const double a5c = *a5; - const double a6c = *a6; - const double a7c = *a7; + // At this point, we know that b_n is exactly equal to the fusing factor. 
+ // However, m may not be a multiple of the number of elements per vector. - rho0 += a0c * x0c; - rho1 += a1c * x0c; - rho2 += a2c * x0c; - rho3 += a3c * x0c; - rho4 += a4c * x0c; - rho5 += a5c * x0c; - rho6 += a6c * x0c; - rho7 += a7c * x0c; + // Going forward, we handle two possible storage formats of A explicitly: + // (1) A is stored by columns, or (2) A is stored by rows. Either case is + // further split into two subproblems along the m dimension: + // (a) a vectorized part, starting at m = 0 and ending at any 0 <= m' <= m. + // (b) a scalar part, starting at m' and ending at m. If no vectorization + // is possible then m' == 0 and thus the scalar part is the entire + // problem. If 0 < m', then the a and x pointers and m variable will + // be adjusted accordingly for the second subproblem. + // Note: since parts (b) for both (1) and (2) are so similar, they are + // factored out into one code block after the following conditional, which + // distinguishes between (1) and (2). - x0 += incx; - a0 += inca; - a1 += inca; - a2 += inca; - a3 += inca; - a4 += inca; - a5 += inca; - a6 += inca; - a7 += inca; + // Intermediate variables to hold the completed dot products + double rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0; + + if (inca == 1 && incx == 1) + { + const dim_t n_iter_unroll[4] = {4, 3, 2, 1}; + + // Use the unrolling factor and the number of elements per register + // to compute the number of vectorized and leftover iterations. + dim_t m_viter[4], m_left = m, i; + + // Calculate the number of vector iterations that can occur for + // various unroll factors. + for (i = 0; i < 4; ++i) + { + m_viter[i] = (m_left) / (n_elem_per_reg * n_iter_unroll[i]); + m_left = (m_left) % (n_elem_per_reg * n_iter_unroll[i]); + } + + // Set up pointers for x and the b_n columns of A (rows of A^T). + double *restrict x0 = x; + double *restrict av[4]; + + av[0] = a + 0 * lda; + av[1] = a + 1 * lda; + av[2] = a + 2 * lda; + av[3] = a + 3 * lda; + + // Initialize b_n rho vector accumulators to zero. + v4df_t rhov[8]; + + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); + + v4df_t xv[4]; + v4df_t avec[16]; + + // If there are vectorized iterations, perform them with vector + // instructions. + for (i = 0; i < m_viter[0]; ++i) + { + // Load the input values. 
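+			// Each iteration consumes 16 rows of x and of each of the
+			// 4 columns of A, accumulating into rhov[0..7].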
+ xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v); + + avec[4].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[2] + 1 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[3] + 1 * n_elem_per_reg); + + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[1].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[1].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[1].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[1].v, rhov[7].v); + + avec[8].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); + avec[9].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); + avec[10].v = _mm256_loadu_pd(av[2] + 2 * n_elem_per_reg); + avec[11].v = _mm256_loadu_pd(av[3] + 2 * n_elem_per_reg); + + rhov[0].v = _mm256_fmadd_pd(avec[8].v, xv[2].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[9].v, xv[2].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[10].v, xv[2].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[11].v, xv[2].v, rhov[3].v); + + avec[12].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg); + avec[13].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg); + avec[14].v = _mm256_loadu_pd(av[2] + 3 * n_elem_per_reg); + avec[15].v = _mm256_loadu_pd(av[3] + 3 * n_elem_per_reg); + + rhov[4].v = _mm256_fmadd_pd(avec[12].v, xv[3].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[13].v, xv[3].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[14].v, xv[3].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[15].v, xv[3].v, rhov[7].v); + + x0 += n_elem_per_reg * n_iter_unroll[0]; + av[0] += n_elem_per_reg * n_iter_unroll[0]; + av[1] += n_elem_per_reg * n_iter_unroll[0]; + av[2] += n_elem_per_reg * n_iter_unroll[0]; + av[3] += n_elem_per_reg * n_iter_unroll[0]; + } + + for (i = 0; i < m_viter[1]; ++i) + { + // Load the input values. 
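+			// 12 rows per iteration: 3 vectors of x against all 4 columns.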
+ xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v); + + avec[4].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[2] + 1 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[3] + 1 * n_elem_per_reg); + + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[1].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[1].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[1].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[1].v, rhov[7].v); + + avec[8].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); + avec[9].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); + avec[10].v = _mm256_loadu_pd(av[2] + 2 * n_elem_per_reg); + avec[11].v = _mm256_loadu_pd(av[3] + 2 * n_elem_per_reg); + + rhov[0].v = _mm256_fmadd_pd(avec[8].v, xv[2].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[9].v, xv[2].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[10].v, xv[2].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[11].v, xv[2].v, rhov[3].v); + + x0 += n_elem_per_reg * n_iter_unroll[1]; + av[0] += n_elem_per_reg * n_iter_unroll[1]; + av[1] += n_elem_per_reg * n_iter_unroll[1]; + av[2] += n_elem_per_reg * n_iter_unroll[1]; + av[3] += n_elem_per_reg * n_iter_unroll[1]; + } + + for (i = 0; i < m_viter[2]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg); + + avec[4].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[2] + 1 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[3] + 1 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v); + + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[1].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[1].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[1].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[1].v, rhov[7].v); + + x0 += n_elem_per_reg * n_iter_unroll[2]; + av[0] += n_elem_per_reg * n_iter_unroll[2]; + av[1] += n_elem_per_reg * n_iter_unroll[2]; + av[2] += n_elem_per_reg * n_iter_unroll[2]; + av[3] += n_elem_per_reg * n_iter_unroll[2]; + } + + for (i = 0; i < m_viter[3]; ++i) + { + // Load the input values. 
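+			// 4 rows per iteration: a single vector of x against all 4 columns.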
+ xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v); + + x0 += n_elem_per_reg * n_iter_unroll[3]; + av[0] += n_elem_per_reg * n_iter_unroll[3]; + av[1] += n_elem_per_reg * n_iter_unroll[3]; + av[2] += n_elem_per_reg * n_iter_unroll[3]; + av[3] += n_elem_per_reg * n_iter_unroll[3]; + } + + // Sum the elements of a given rho?v. This computes the sum of + // elements within lanes and stores the sum to both elements. + rhov[0].v = _mm256_add_pd(rhov[0].v, rhov[4].v); + rhov[1].v = _mm256_add_pd(rhov[1].v, rhov[5].v); + rhov[2].v = _mm256_add_pd(rhov[2].v, rhov[6].v); + rhov[3].v = _mm256_add_pd(rhov[3].v, rhov[7].v); + + rhov[0].v = _mm256_hadd_pd(rhov[0].v, rhov[0].v); + rhov[1].v = _mm256_hadd_pd(rhov[1].v, rhov[1].v); + rhov[2].v = _mm256_hadd_pd(rhov[2].v, rhov[2].v); + rhov[3].v = _mm256_hadd_pd(rhov[3].v, rhov[3].v); + + // Manually add the results from above to finish the sum. + rho0 = rhov[0].d[0] + rhov[0].d[2]; + rho1 = rhov[1].d[0] + rhov[1].d[2]; + rho2 = rhov[2].d[0] + rhov[2].d[2]; + rho3 = rhov[3].d[0] + rhov[3].d[2]; + + // Adjust for scalar subproblem. + for (i = 0; i < 4; ++i) + { + m -= n_elem_per_reg * n_iter_unroll[i] * m_viter[i]; + a += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * inca */; + x += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * incx */; } } + // Initialize pointers for x and the b_n columns of A (rows of A^T). + double *restrict x0 = x; + double *restrict a0 = a + 0 * lda; + double *restrict a1 = a + 1 * lda; + double *restrict a2 = a + 2 * lda; + double *restrict a3 = a + 3 * lda; + + // If there are leftover iterations, perform them with scalar code. + for (dim_t i = 0; i < m; ++i) + { + const double x0c = *x0; + + const double a0c = *a0; + const double a1c = *a1; + const double a2c = *a2; + const double a3c = *a3; + + rho0 += a0c * x0c; + rho1 += a1c * x0c; + rho2 += a2c * x0c; + rho3 += a3c * x0c; + + x0 += incx; + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + } + // Now prepare the final rho values to output/accumulate back into // the y vector. - v4df_t rho0v, rho1v, y0v, y1v; + v4df_t rho0v, y0v; // Insert the scalar rho values into a single vector. rho0v.d[0] = rho0; rho0v.d[1] = rho1; rho0v.d[2] = rho2; rho0v.d[3] = rho3; - rho1v.d[0] = rho4; - rho1v.d[1] = rho5; - rho1v.d[2] = rho6; - rho1v.d[3] = rho7; // Broadcast the alpha scalar. - v4df_t alphav; alphav.v = _mm256_broadcast_sd( alpha ); + v4df_t alphav; + alphav.v = _mm256_broadcast_sd(alpha); // We know at this point that alpha is nonzero; however, beta may still // be zero. If beta is indeed zero, we must overwrite y rather than scale // by beta (in case y contains NaN or Inf). - if ( PASTEMAC(d,eq0)( *beta ) ) + if (PASTEMAC(d, eq0)(*beta)) { // Apply alpha to the accumulated dot product in rho: // y := alpha * rho - y0v.v = _mm256_mul_pd( alphav.v, rho0v.v ); - y1v.v = _mm256_mul_pd( alphav.v, rho1v.v ); + y0v.v = _mm256_mul_pd(alphav.v, rho0v.v); } else { // Broadcast the beta scalar. 
- v4df_t betav; betav.v = _mm256_broadcast_sd( beta ); + v4df_t betav; + betav.v = _mm256_broadcast_sd(beta); // Load y. - if ( incy == 1 ) + if (incy == 1) { - y0v.v = _mm256_loadu_pd( y + 0*n_elem_per_reg ); - y1v.v = _mm256_loadu_pd( y + 1*n_elem_per_reg ); + y0v.v = _mm256_loadu_pd(y + 0 * n_elem_per_reg); } else { - y0v.d[0] = *(y + 0*incy); y0v.d[1] = *(y + 1*incy); - y0v.d[2] = *(y + 2*incy); y0v.d[3] = *(y + 3*incy); - y1v.d[0] = *(y + 4*incy); y1v.d[1] = *(y + 5*incy); - y1v.d[2] = *(y + 6*incy); y1v.d[3] = *(y + 7*incy); + y0v.d[0] = *(y + 0 * incy); + y0v.d[1] = *(y + 1 * incy); + y0v.d[2] = *(y + 2 * incy); + y0v.d[3] = *(y + 3 * incy); } // Apply beta to y and alpha to the accumulated dot product in rho: // y := beta * y + alpha * rho - y0v.v = _mm256_mul_pd( betav.v, y0v.v ); - y1v.v = _mm256_mul_pd( betav.v, y1v.v ); - y0v.v = _mm256_fmadd_pd( alphav.v, rho0v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( alphav.v, rho1v.v, y1v.v ); + y0v.v = _mm256_mul_pd(betav.v, y0v.v); + y0v.v = _mm256_fmadd_pd(alphav.v, rho0v.v, y0v.v); } - if ( incy == 1 ) + if (incy == 1) + { + // Store the output. + _mm256_storeu_pd((y + 0 * n_elem_per_reg), y0v.v); + } + else + { + *(y + 0 * incy) = y0v.d[0]; + *(y + 1 * incy) = y0v.d[1]; + *(y + 2 * incy) = y0v.d[2]; + *(y + 3 * incy) = y0v.d[3]; + } +} + +void bli_ddotxf_zen_int_2 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + double *restrict alpha, + double *restrict a, inc_t inca, inc_t lda, + double *restrict x, inc_t incx, + double *restrict beta, + double *restrict y, inc_t incy, + cntx_t *restrict cntx + ) +{ + const dim_t fuse_fac = 2; + const dim_t n_elem_per_reg = 4; + + // If the b_n dimension is zero, y is empty and there is no computation. + if (bli_zero_dim1(b_n)) + return; + + // If the m dimension is zero, or if alpha is zero, the computation + // simplifies to updating y. + if (bli_zero_dim1(m) || PASTEMAC(d, eq0)(*alpha)) + { + bli_dscalv_zen_int10( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx); + return; + } + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over dotxv. + if (b_n != fuse_fac) + { + for (dim_t i = 0; i < b_n; ++i) + { + double *a1 = a + (0) * inca + (i)*lda; + double *x1 = x + (0) * incx; + double *psi1 = y + (i)*incy; + + bli_ddotxv_zen_int( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx); + } + return; + } + + // At this point, we know that b_n is exactly equal to the fusing factor. + // However, m may not be a multiple of the number of elements per vector. + + // Going forward, we handle two possible storage formats of A explicitly: + // (1) A is stored by columns, or (2) A is stored by rows. Either case is + // further split into two subproblems along the m dimension: + // (a) a vectorized part, starting at m = 0 and ending at any 0 <= m' <= m. + // (b) a scalar part, starting at m' and ending at m. If no vectorization + // is possible then m' == 0 and thus the scalar part is the entire + // problem. If 0 < m', then the a and x pointers and m variable will + // be adjusted accordingly for the second subproblem. + // Note: since parts (b) for both (1) and (2) are so similar, they are + // factored out into one code block after the following conditional, which + // distinguishes between (1) and (2). 
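+	// For the column-stored case below, m is drained with progressively
+	// smaller unroll factors (8, 4, 2, then 1 vector per iteration) before
+	// falling back to scalar code.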
+ + // Intermediate variables to hold the completed dot products + double rho0 = 0, rho1 = 0; + + if (inca == 1 && incx == 1) + { + const dim_t n_iter_unroll[4] = {8, 4, 2, 1}; + + // Use the unrolling factor and the number of elements per register + // to compute the number of vectorized and leftover iterations. + dim_t m_viter[4], i, m_left = m; + + for (i = 0; i < 4; ++i) + { + m_viter[i] = (m_left) / (n_elem_per_reg * n_iter_unroll[i]); + m_left = (m_left) % (n_elem_per_reg * n_iter_unroll[i]); + } + + // Set up pointers for x and the b_n columns of A (rows of A^T). + double *restrict x0 = x; + double *restrict av[2]; + + av[0] = a + 0 * lda; + av[1] = a + 1 * lda; + + // Initialize b_n rho vector accumulators to zero. + v4df_t rhov[8]; + + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); + + v4df_t xv[4]; + v4df_t avec[8]; + + for (i = 0; i < m_viter[0]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[4].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v); + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[2].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[2].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[3].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[3].v, rhov[7].v); + + // Load the input values. 
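+			// Second half of the 8-vector unroll: offsets 4..7 of x and of
+			// both columns of A.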
+ xv[0].v = _mm256_loadu_pd(x0 + 4 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 5 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 6 * n_elem_per_reg); + xv[3].v = _mm256_loadu_pd(x0 + 7 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 4 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 4 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[0] + 5 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[1] + 5 * n_elem_per_reg); + avec[4].v = _mm256_loadu_pd(av[0] + 6 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 6 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[0] + 7 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[1] + 7 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v); + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[2].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[2].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[3].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[3].v, rhov[7].v); + + x0 += n_elem_per_reg * n_iter_unroll[0]; + av[0] += n_elem_per_reg * n_iter_unroll[0]; + av[1] += n_elem_per_reg * n_iter_unroll[0]; + } + + for (i = 0; i < m_viter[1]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[4].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v); + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[2].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[2].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[3].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[3].v, rhov[7].v); + + x0 += n_elem_per_reg * n_iter_unroll[1]; + av[0] += n_elem_per_reg * n_iter_unroll[1]; + av[1] += n_elem_per_reg * n_iter_unroll[1]; + } + + rhov[0].v = _mm256_add_pd(rhov[0].v, rhov[4].v); + rhov[1].v = _mm256_add_pd(rhov[1].v, rhov[5].v); + rhov[2].v = _mm256_add_pd(rhov[2].v, rhov[6].v); + rhov[3].v = _mm256_add_pd(rhov[3].v, rhov[7].v); + + for (i = 0; i < m_viter[2]; ++i) + { + // Load the input values. 
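+			// 8 rows per iteration: 2 vectors of x against both columns.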
+ xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v); + + x0 += n_elem_per_reg * n_iter_unroll[2]; + av[0] += n_elem_per_reg * n_iter_unroll[2]; + av[1] += n_elem_per_reg * n_iter_unroll[2]; + } + + rhov[0].v = _mm256_add_pd(rhov[0].v, rhov[2].v); + rhov[1].v = _mm256_add_pd(rhov[1].v, rhov[3].v); + + for (i = 0; i < m_viter[3]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + + x0 += n_elem_per_reg * n_iter_unroll[3]; + av[0] += n_elem_per_reg * n_iter_unroll[3]; + av[1] += n_elem_per_reg * n_iter_unroll[3]; + } + + // Sum the elements of a given rho?v. This computes the sum of + // elements within lanes and stores the sum to both elements. + rhov[0].v = _mm256_hadd_pd(rhov[0].v, rhov[0].v); + rhov[1].v = _mm256_hadd_pd(rhov[1].v, rhov[1].v); + + // Manually add the results from above to finish the sum. + rho0 = rhov[0].d[0] + rhov[0].d[2]; + rho1 = rhov[1].d[0] + rhov[1].d[2]; + + // Adjust for scalar subproblem. + for (i = 0; i < 4; ++i) + { + m -= n_elem_per_reg * n_iter_unroll[i] * m_viter[i]; + a += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * inca */; + x += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * incx */; + } + } + + // Initialize pointers for x and the b_n columns of A (rows of A^T). + double *restrict x0 = x; + double *restrict a0 = a + 0 * lda; + double *restrict a1 = a + 1 * lda; + + // If there are leftover iterations, perform them with scalar code. + for (dim_t i = 0; i < m; ++i) + { + const double x0c = *x0; + + const double a0c = *a0; + const double a1c = *a1; + + rho0 += a0c * x0c; + rho1 += a1c * x0c; + + x0 += incx; + a0 += inca; + a1 += inca; + } + + // Now prepare the final rho values to output/accumulate back into + // the y vector. + + v2df_t rho0v, y0v; + + // Insert the scalar rho values into a single vector. + rho0v.d[0] = rho0; + rho0v.d[1] = rho1; + + // Broadcast the alpha scalar. + v2df_t alphav; + + alphav.v = _mm_load1_pd(alpha); + + // We know at this point that alpha is nonzero; however, beta may still + // be zero. If beta is indeed zero, we must overwrite y rather than scale + // by beta (in case y contains NaN or Inf). + if (PASTEMAC(d, eq0)(*beta)) + { + // Apply alpha to the accumulated dot product in rho: + // y := alpha * rho + y0v.v = _mm_mul_pd(alphav.v, rho0v.v); + } + else + { + // Broadcast the beta scalar. + v2df_t betav; + betav.v = _mm_load1_pd(beta); + + // Load y. 
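+		// Only two elements of y are involved here (fuse_fac == 2), so a
+		// single 128-bit load (or a two-element strided gather) suffices.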
+ if (incy == 1) + { + y0v.v = _mm_loadu_pd(y + 0 * 2); + } + else + { + y0v.d[0] = *(y + 0 * incy); + y0v.d[1] = *(y + 1 * incy); + } + + // Apply beta to y and alpha to the accumulated dot product in rho: + // y := beta * y + alpha * rho + y0v.v = _mm_mul_pd(betav.v, y0v.v); + y0v.v = _mm_fmadd_pd(alphav.v, rho0v.v, y0v.v); + } + + if (incy == 1) { // Store the output. - _mm256_storeu_pd( (y + 0*n_elem_per_reg), y0v.v ); - _mm256_storeu_pd( (y + 1*n_elem_per_reg), y1v.v ); + _mm_storeu_pd((y + 0 * 2), y0v.v); } else { - *(y + 0*incy) = y0v.d[0]; *(y + 1*incy) = y0v.d[1]; - *(y + 2*incy) = y0v.d[2]; *(y + 3*incy) = y0v.d[3]; - *(y + 4*incy) = y1v.d[0]; *(y + 5*incy) = y1v.d[1]; - *(y + 6*incy) = y1v.d[2]; *(y + 7*incy) = y1v.d[3]; + *(y + 0 * incy) = y0v.d[0]; + *(y + 1 * incy) = y0v.d[1]; + } +} + +/** + * Performs dotxf operation on dcomplex. + * x and y are vectors and a is the matrix. + * Computation is done on 6 columns at a time + * Marches through vectors in multiple of 2. + */ +void bli_zdotxf_zen_int_6 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t inca, inc_t lda, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict beta, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + /** + * Handles only unit stride cases and 6 column at a time + * b_n check for columns to be 6. + */ + if ( (inca == 1) && (incx == 1) && (incy == 1) && (b_n == 6) ) + { + /* Temporary rho buffer holds computed dot product result */ + dcomplex r[ 6 ]; + + /* If beta is zero, clear y. Otherwise, scale by beta. */ + if ( PASTEMAC(z,eq0)( *beta ) ) + { + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(z,set0s)( y[i] ); + } + } + else + { + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(z,scals)( *beta, y[i] ); + } + } + + /* If the vectors are empty or if alpha is zero, return early*/ + if ( bli_zero_dim1( m ) || PASTEMAC(z,eq0)( *alpha ) ) return; + + /* Initialize r vector to 0. */ + for ( dim_t i = 0; i < 6; ++i ) PASTEMAC(z,set0s)( r[i] ); + + /* If a must be conjugated, we do so indirectly by first + * toggling the effective conjugation of x and then conjugating + * the resulting do products. + * Rather conjugating each element of a matrix, final computed result + * can be conjugated at the end of loop. This takes off the overhead + * of conjugating each element inside the loop and improves the + * performance. + */ + conj_t conjx_use = conjx; + + if ( bli_is_conj( conjat ) ) + { + bli_toggle_conj( &conjx_use ); + } + + /* Setting rho vectors to 0 */ + v4df_t rho0v; rho0v.v = _mm256_setzero_pd(); + v4df_t rho1v; rho1v.v = _mm256_setzero_pd(); + v4df_t rho2v; rho2v.v = _mm256_setzero_pd(); + v4df_t rho3v; rho3v.v = _mm256_setzero_pd(); + v4df_t rho4v; rho4v.v = _mm256_setzero_pd(); + v4df_t rho5v; rho5v.v = _mm256_setzero_pd(); + + v4df_t rho6v; rho6v.v = _mm256_setzero_pd(); + v4df_t rho7v; rho7v.v = _mm256_setzero_pd(); + v4df_t rho8v; rho8v.v = _mm256_setzero_pd(); + v4df_t rho9v; rho9v.v = _mm256_setzero_pd(); + v4df_t rho10v; rho10v.v = _mm256_setzero_pd(); + v4df_t rho11v; rho11v.v = _mm256_setzero_pd(); + + /* Holds 2 dcomplex element of x vector + * for computing dot product with A tile + */ + v4df_t x0v, x1v; + /* Holds 2x6 tile of matrix A */ + v4df_t a0v, a1v, a2v, a3v, a4v, a5v; + /** + * Since complex datatype multiplication is + * being held in two sets of rho vectors. + * Where first set holds the computaion with + * real part of vector x and other holds + * imaginary part of vector x. 
+ * For final computation, based on conj sign + * of imaginary component needs to be toggled. + */ + __m256d no_conju = _mm256_setr_pd(-1, 1, -1, 1); + __m256d conju = _mm256_setr_pd(1, -1, 1, -1); + dim_t iter = m / 2; + dim_t rem = m % 2; + dim_t i = 0; + + if ( bli_is_noconj( conjx_use ) ) + { + if(iter) + { + for ( ; (i+1) < m; i+=2) + { + /*Load 2 dcomplex elements from + * vector x + */ + x0v.v = _mm256_loadu_pd( + (double *)(x + i) ); + /* x1v.v holds imaginary part of dcomplex + * elements from vector x + * It will do following operation. + * R0 I0 R1 I1 => I0 I0 I1 I1 + * + */ + x1v.v = _mm256_permute_pd( x0v.v, 15 ); + /* x1v.v holds real part of dcomplex + * elements from vector x + * It will do following operation. + * R0 I0 R1 I1 => R0 R0 R1 R1 + */ + x0v.v = _mm256_permute_pd( x0v.v, 0 ); + + /*Load 2x6 tile of matrix A*/ + a0v.v = _mm256_loadu_pd( (double *) + (a + i + 0 * lda) ); + a1v.v = _mm256_loadu_pd( (double *) + (a + i + 1 * lda) ); + a2v.v = _mm256_loadu_pd( (double *) + (a + i + 2 * lda) ); + a3v.v = _mm256_loadu_pd( (double *) + (a + i + 3 * lda) ); + a4v.v = _mm256_loadu_pd( (double *) + (a + i + 4 * lda) ); + a5v.v = _mm256_loadu_pd( (double *) + (a + i + 5 * lda) ); + + // perform: rho?v += a?v * x0v; + rho0v.v = _mm256_fmadd_pd( a0v.v, + x0v.v, rho0v.v ); + rho6v.v = _mm256_fmadd_pd( a0v.v, + x1v.v, rho6v.v ); + + rho1v.v = _mm256_fmadd_pd( a1v.v, + x0v.v, rho1v.v ); + rho7v.v = _mm256_fmadd_pd( a1v.v, + x1v.v, rho7v.v ); + + rho2v.v = _mm256_fmadd_pd( a2v.v, + x0v.v, rho2v.v ); + rho8v.v = _mm256_fmadd_pd( a2v.v, + x1v.v, rho8v.v ); + + rho3v.v = _mm256_fmadd_pd( a3v.v, + x0v.v, rho3v.v ); + rho9v.v = _mm256_fmadd_pd( a3v.v, + x1v.v, rho9v.v ); + + rho4v.v = _mm256_fmadd_pd( a4v.v, + x0v.v, rho4v.v ); + rho10v.v = _mm256_fmadd_pd( a4v.v, + x1v.v, rho10v.v ); + + rho5v.v = _mm256_fmadd_pd( a5v.v, + x0v.v, rho5v.v ); + rho11v.v = _mm256_fmadd_pd( a5v.v, + x1v.v, rho11v.v ); + } + + /*Swapping position of real and imag component + * for horizontal addition to get the final + * dot product computation + * rho register are holding computation which needs + * to be arranged in following manner. + * Ra0*Ix0 | Ia0*Ix0 | Ra1*Ix1 | Ia1*Ix1 + * || + * \/ + * Ia0*Ix0 | Ra0*Ix0 | Ia1*Ix1 | Ra1*Ix1 + */ + rho6v.v = _mm256_permute_pd(rho6v.v, 0x05); + rho7v.v = _mm256_permute_pd(rho7v.v, 0x05); + rho8v.v = _mm256_permute_pd(rho8v.v, 0x05); + rho9v.v = _mm256_permute_pd(rho9v.v, 0x05); + rho10v.v = _mm256_permute_pd(rho10v.v, 0x05); + rho11v.v = _mm256_permute_pd(rho11v.v, 0x05); + + /*Negating imaginary part for computing + * the final result of dcomplex multiplication + */ + rho6v.v = _mm256_mul_pd(rho6v.v, no_conju); + rho7v.v = _mm256_mul_pd(rho7v.v, no_conju); + rho8v.v = _mm256_mul_pd(rho8v.v, no_conju); + rho9v.v = _mm256_mul_pd(rho9v.v, no_conju); + rho10v.v = _mm256_mul_pd(rho10v.v, no_conju); + rho11v.v = _mm256_mul_pd(rho11v.v, no_conju); + + rho0v.v = _mm256_add_pd(rho0v.v, rho6v.v); + rho1v.v = _mm256_add_pd(rho1v.v, rho7v.v); + rho2v.v = _mm256_add_pd(rho2v.v, rho8v.v); + rho3v.v = _mm256_add_pd(rho3v.v, rho9v.v); + rho4v.v = _mm256_add_pd(rho4v.v, rho10v.v); + rho5v.v = _mm256_add_pd(rho5v.v, rho11v.v); + + /*rho0, rho1, rho2 holds final dot product + * result of 6 dcomplex elements. 
+ */ + rho0v.d[0] += rho0v.d[2]; + rho0v.d[1] += rho0v.d[3]; + + rho0v.d[2] = rho1v.d[0] + rho1v.d[2]; + rho0v.d[3] = rho1v.d[1] + rho1v.d[3]; + + rho1v.d[0] = rho2v.d[0] + rho2v.d[2]; + rho1v.d[1] = rho2v.d[1] + rho2v.d[3]; + + rho1v.d[2] = rho3v.d[0] + rho3v.d[2]; + rho1v.d[3] = rho3v.d[1] + rho3v.d[3]; + + rho2v.d[0] = rho4v.d[0] + rho4v.d[2]; + rho2v.d[1] = rho4v.d[1] + rho4v.d[3]; + + rho2v.d[2] = rho5v.d[0] + rho5v.d[2]; + rho2v.d[3] = rho5v.d[1] + rho5v.d[3]; + + /*Computed dot product result is being stored + * in temp buffer r for further computation. + */ + _mm256_storeu_pd((double *)r, rho0v.v); + _mm256_storeu_pd((double *)(r+2) , rho1v.v); + _mm256_storeu_pd((double *)(r+4) , rho2v.v); + + } + /*handles remainder cases*/ + if(rem) + { + PRAGMA_SIMD + for(dim_t p = 0; p < 6 ; p++) + { + PASTEMAC(z,axpys)( a[i + p*lda] + , x[i], r[p] ); + } + } + } + else + { + if(iter) + { + for ( ; (i+1) < m; i+=2) + { + /*Load 2 dcomplex elements from + * vector x + */ + x0v.v = _mm256_loadu_pd( (double *) + (x + i) ); + /* x1v.v holds imaginary part of dcomplex + * elements from vector x + */ + x1v.v = _mm256_permute_pd( x0v.v, 15 ); + /* x1v.v holds real part of dcomplex + * elements from vector x + */ + x0v.v = _mm256_permute_pd( x0v.v, 0 ); + + /*Load 2x6 tile of matrix A*/ + a0v.v = _mm256_loadu_pd( (double *) + (a + i + 0 * lda)); + a1v.v = _mm256_loadu_pd( (double *) + (a + i + 1 * lda)); + a2v.v = _mm256_loadu_pd( (double *) + (a + i + 2 * lda)); + a3v.v = _mm256_loadu_pd( (double *) + (a + i + 3 * lda)); + a4v.v = _mm256_loadu_pd( (double *) + (a + i + 4 * lda)); + a5v.v = _mm256_loadu_pd( (double *) + (a + i + 5 * lda)); + + // perform: rho?v += a?v * x0v; + rho0v.v = _mm256_fmadd_pd( a0v.v, + x0v.v, rho0v.v ); + rho6v.v = _mm256_fmadd_pd( a0v.v, + x1v.v, rho6v.v ); + + rho1v.v = _mm256_fmadd_pd( a1v.v, + x0v.v, rho1v.v ); + rho7v.v = _mm256_fmadd_pd( a1v.v, + x1v.v, rho7v.v ); + + rho2v.v = _mm256_fmadd_pd( a2v.v, + x0v.v, rho2v.v ); + rho8v.v = _mm256_fmadd_pd( a2v.v, + x1v.v, rho8v.v ); + + rho3v.v = _mm256_fmadd_pd( a3v.v, + x0v.v, rho3v.v ); + rho9v.v = _mm256_fmadd_pd( a3v.v, + x1v.v, rho9v.v ); + + rho4v.v = _mm256_fmadd_pd( a4v.v, + x0v.v, rho4v.v ); + rho10v.v = _mm256_fmadd_pd( a4v.v, + x1v.v, rho10v.v ); + + rho5v.v = _mm256_fmadd_pd( a5v.v, + x0v.v, rho5v.v ); + rho11v.v = _mm256_fmadd_pd( a5v.v, + x1v.v, rho11v.v ); + } + + /*Swapping position of real and imag component + * for horizontal addition to get the final + * dot product computation + * rho register are holding computation which needs + * to be arranged in following manner. 
+ * Ra0*Ix0 | Ia0*Ix0 | Ra1*Ix1 | Ia1*Ix1 + * || + * \/ + * Ia0*Ix0 | Ra0*Ix0 | Ia1*Ix1 | Ra1*Ix1 + */ + rho6v.v = _mm256_permute_pd(rho6v.v, 0x05); + rho7v.v = _mm256_permute_pd(rho7v.v, 0x05); + rho8v.v = _mm256_permute_pd(rho8v.v, 0x05); + rho9v.v = _mm256_permute_pd(rho9v.v, 0x05); + rho10v.v = _mm256_permute_pd(rho10v.v, 0x05); + rho11v.v = _mm256_permute_pd(rho11v.v, 0x05); + + /*Negating imaginary part for computing + * the final result of dcomplex multiplication + */ + rho6v.v = _mm256_mul_pd(rho6v.v, conju); + rho7v.v = _mm256_mul_pd(rho7v.v, conju); + rho8v.v = _mm256_mul_pd(rho8v.v, conju); + rho9v.v = _mm256_mul_pd(rho9v.v, conju); + rho10v.v = _mm256_mul_pd(rho10v.v, conju); + rho11v.v = _mm256_mul_pd(rho11v.v, conju); + + rho0v.v = _mm256_add_pd(rho0v.v, rho6v.v); + rho1v.v = _mm256_add_pd(rho1v.v, rho7v.v); + rho2v.v = _mm256_add_pd(rho2v.v, rho8v.v); + rho3v.v = _mm256_add_pd(rho3v.v, rho9v.v); + rho4v.v = _mm256_add_pd(rho4v.v, rho10v.v); + rho5v.v = _mm256_add_pd(rho5v.v, rho11v.v); + + /*rho0, rho1, rho2 holds final dot product + * result of 6 dcomplex elements. + */ + rho0v.d[0] += rho0v.d[2]; + rho0v.d[1] += rho0v.d[3]; + + rho0v.d[2] = rho1v.d[0] + rho1v.d[2]; + rho0v.d[3] = rho1v.d[1] + rho1v.d[3]; + + rho1v.d[0] = rho2v.d[0] + rho2v.d[2]; + rho1v.d[1] = rho2v.d[1] + rho2v.d[3]; + + rho1v.d[2] = rho3v.d[0] + rho3v.d[2]; + rho1v.d[3] = rho3v.d[1] + rho3v.d[3]; + + rho2v.d[0] = rho4v.d[0] + rho4v.d[2]; + rho2v.d[1] = rho4v.d[1] + rho4v.d[3]; + + rho2v.d[2] = rho5v.d[0] + rho5v.d[2]; + rho2v.d[3] = rho5v.d[1] + rho5v.d[3]; + + /*Computed dot product result is being stored + * in temp buffer r for further computation. + */ + _mm256_storeu_pd((double *)r, rho0v.v); + _mm256_storeu_pd((double *)(r+2) , rho1v.v); + _mm256_storeu_pd((double *)(r+4) , rho2v.v); + + } + if(rem) + { + PRAGMA_SIMD + for(dim_t p = 0; p < 6 ; p++) + { + PASTEMAC(z,axpyjs)(a[i + p*lda] + , x[i], r[p] ); + } + } + } + + if ( bli_is_conj( conjat ) ) + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(z,conjs)( r[i] ); + } + + /*scaling dot product result with alpha and + * adding the result to vector + */ + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(z,axpys)( *alpha, r[i], y[i] ); + } + } + else + { + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(z,type); + PASTECH(z,dotxv_ker_ft) kfp_dv + = + bli_cntx_get_l1v_ker_dt( dt, BLIS_DOTXV_KER, cntx ); + + for ( dim_t i = 0; i < b_n; ++i ) + { + dcomplex* restrict a1 = a + (0 )*inca + (i )*lda; + dcomplex* restrict x1 = x + (0 )*incx; + dcomplex* restrict psi1 = y + (i )*incy; + + kfp_dv + ( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx + ); + } + } + +} + + +/** + * Performs dotxf operation on scomplex. + * x and y are vectors and a is the matrix. + * Computation is done on 6 columns at a time + * Marches through vectors in multiple of 4 and 2. + */ +void bli_cdotxf_zen_int_6 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + scomplex* restrict alpha, + scomplex* restrict a, inc_t inca, inc_t lda, + scomplex* restrict x, inc_t incx, + scomplex* restrict beta, + scomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + if ( (inca == 1) && (incx == 1) && (incy == 1) && (b_n == 6) ) + { + /* Temporary rho buffer holds computed dot product result */ + scomplex r[ 6 ]; + + /* If beta is zero, clear y. Otherwise, scale by beta. 
*/ + if ( PASTEMAC(c,eq0)( *beta ) ) + { + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(c,set0s)( y[i] ); + } + } + else + { + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(c,scals)( *beta, y[i] ); + } + } + + /* If the vectors are empty or if alpha is zero, return early. */ + if ( bli_zero_dim1( m ) || PASTEMAC(c,eq0)( *alpha ) ) return; + + /* Initialize r vector to 0. */ + for ( dim_t i = 0; i < 6; ++i ) PASTEMAC(c,set0s)( r[i] ); + + /* If a must be conjugated, we do so indirectly by first toggling the + effective conjugation of x and then conjugating the resulting do + products. */ + conj_t conjx_use = conjx; + + if ( bli_is_conj( conjat ) ) + bli_toggle_conj( &conjx_use ); + + dim_t iter = m / 2; + dim_t iter4 = m / 4; + dim_t rem = m % 2; + dim_t i = 0; + if(iter) + { + if(iter4) + { + /* Setting rho vectors to 0 */ + __m256 rho0v; rho0v = _mm256_setzero_ps(); + __m256 rho1v; rho1v = _mm256_setzero_ps(); + __m256 rho2v; rho2v = _mm256_setzero_ps(); + __m256 rho3v; rho3v = _mm256_setzero_ps(); + __m256 rho4v; rho4v = _mm256_setzero_ps(); + __m256 rho5v; rho5v = _mm256_setzero_ps(); + + __m256 rho6v; rho6v = _mm256_setzero_ps(); + __m256 rho7v; rho7v = _mm256_setzero_ps(); + __m256 rho8v; rho8v = _mm256_setzero_ps(); + __m256 rho9v; rho9v = _mm256_setzero_ps(); + __m256 rho10v; rho10v = _mm256_setzero_ps(); + __m256 rho11v; rho11v = _mm256_setzero_ps(); + /* Holds 2 dcomplex element of x vector + * for computing dot product with A tile + */ + __m256 x0v, x1v; + /* Holds 2x6 tile of matrix A */ + __m256 a0v, a1v, a2v, a3v, a4v, a5v; + /** + * Since complex datatype multiplication is + * being held in two sets of rho vectors. + * Where first set holds the computaion with + * real part of vector x and other holds + * imaginary part of vector x. + * For final computation, based on conj sign + * of imaginary component needs to be toggled. + */ + __m256 no_conju = _mm256_setr_ps(-1, 1, -1, 1, -1, 1, -1, 1); + __m256 conju = _mm256_setr_ps(1, -1, 1, -1, 1, -1, 1, -1); + + // March through vectos in multiple of 4. 
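+			/* Note on the arithmetic below: for a column element
+			 * a = ar + i*ai and a vector element x = xr + i*xi,
+			 *   a * x       = (ar*xr - ai*xi) + i*(ai*xr + ar*xi)
+			 *   a * conj(x) = (ar*xr + ai*xi) + i*(ai*xr - ar*xi)
+			 * rho0v..rho5v accumulate a * xr (real part of x broadcast)
+			 * and rho6v..rho11v accumulate a * xi (imaginary part of x
+			 * broadcast); the permute, sign-mask and add steps after the
+			 * loop combine the two sets into the products above.
+			 */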
+ for ( ; (i+3) < m; i+=4) + { + /*Load 4 scomplex elements from vector x*/ + x0v = _mm256_loadu_ps( (float *) (x + i) ); + /* x1v.v holds imaginary part of dcomplex + * elements from vector x + */ + x1v = _mm256_permute_ps( x0v, 0xf5 ); + /* x1v.v holds real part of dcomplex + * elements from vector x + */ + x0v = _mm256_permute_ps( x0v, 0xa0); + /* x1v.v holds imag part of dcomplex + Load 4x6 tile of matrix A*/ + a0v = _mm256_loadu_ps( (float *)(a + i + 0 * lda)); + a1v = _mm256_loadu_ps( (float *)(a + i + 1 * lda)); + a2v = _mm256_loadu_ps( (float *)(a + i + 2 * lda)); + a3v = _mm256_loadu_ps( (float *)(a + i + 3 * lda)); + a4v = _mm256_loadu_ps( (float *)(a + i + 4 * lda)); + a5v = _mm256_loadu_ps( (float *)(a + i + 5 * lda)); + + // perform: rho?v += a?v * x0v; + + rho0v = _mm256_fmadd_ps( a0v, x0v, rho0v ); + rho6v = _mm256_fmadd_ps( a0v, x1v, rho6v ); + + rho1v = _mm256_fmadd_ps( a1v, x0v, rho1v ); + rho7v = _mm256_fmadd_ps( a1v, x1v, rho7v ); + + rho2v = _mm256_fmadd_ps( a2v, x0v, rho2v ); + rho8v = _mm256_fmadd_ps( a2v, x1v, rho8v ); + + rho3v = _mm256_fmadd_ps( a3v, x0v, rho3v ); + rho9v = _mm256_fmadd_ps( a3v, x1v, rho9v ); + + rho4v = _mm256_fmadd_ps( a4v, x0v, rho4v ); + rho10v = _mm256_fmadd_ps( a4v, x1v, rho10v ); + + rho5v = _mm256_fmadd_ps( a5v, x0v, rho5v ); + rho11v = _mm256_fmadd_ps( a5v, x1v, rho11v ); + } + + + /*Swapping position of real and imag component + * for horizontal addition to get the final + * dot product computation + * rho register are holding computation which needs + * to be arranged in following manner. + * Ra0*Ix0 | Ia0*Ix0 | Ra1*Ix1 | Ia1*Ix1 + * || + * \/ + * Ia0*Ix0 | Ra0*Ix0 | Ia1*Ix1 | Ra1*Ix1 + */ + + rho6v = _mm256_permute_ps(rho6v, 0xb1); + rho7v = _mm256_permute_ps(rho7v, 0xb1); + rho8v = _mm256_permute_ps(rho8v, 0xb1); + rho9v = _mm256_permute_ps(rho9v, 0xb1); + rho10v = _mm256_permute_ps(rho10v, 0xb1); + rho11v = _mm256_permute_ps(rho11v, 0xb1); + + /*Negating imaginary part for computing + * the final result of dcomplex multiplication*/ + if ( bli_is_noconj( conjx_use ) ) + { + rho6v = _mm256_mul_ps(rho6v, no_conju); + rho7v = _mm256_mul_ps(rho7v, no_conju); + rho8v = _mm256_mul_ps(rho8v, no_conju); + rho9v = _mm256_mul_ps(rho9v, no_conju); + rho10v = _mm256_mul_ps(rho10v, no_conju); + rho11v = _mm256_mul_ps(rho11v, no_conju); + } + else + { + + rho6v = _mm256_mul_ps(rho6v, conju); + rho7v = _mm256_mul_ps(rho7v, conju); + rho8v = _mm256_mul_ps(rho8v, conju); + rho9v = _mm256_mul_ps(rho9v, conju); + rho10v = _mm256_mul_ps(rho10v, conju); + rho11v = _mm256_mul_ps(rho11v, conju); + + } + + rho0v = _mm256_add_ps(rho0v, rho6v); + rho1v = _mm256_add_ps(rho1v, rho7v); + rho2v = _mm256_add_ps(rho2v, rho8v); + rho3v = _mm256_add_ps(rho3v, rho9v); + rho4v = _mm256_add_ps(rho4v, rho10v); + rho5v = _mm256_add_ps(rho5v, rho11v); + + /** + * Horizontal addition of rho elements + * for computing final dotxf result. + * ptr pointer addresses all 6 rho + * register one by one and store the + * computed result into r buffer. 
+ */ + scomplex *ptr = (scomplex *)&rho0v; + for(dim_t i = 0; i < 4; i++) + { + r[0].real += ptr[i].real; + r[0].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho1v; + for(dim_t i = 0; i < 4; i++) + { + r[1].real += ptr[i].real; + r[1].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho2v; + for(dim_t i = 0; i < 4; i++) + { + r[2].real += ptr[i].real; + r[2].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho3v; + for(dim_t i = 0; i < 4; i++) + { + r[3].real += ptr[i].real; + r[3].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho4v; + for(dim_t i = 0; i < 4; i++) + { + r[4].real += ptr[i].real; + r[4].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho5v; + for(dim_t i = 0; i < 4; i++) + { + r[5].real += ptr[i].real; + r[5].imag += ptr[i].imag; + } + } + // March through vectos in multiple of 2. + if(i+1 < m) + { + /* Setting rho vectors to 0 */ + __m128 rho0v; rho0v = _mm_setzero_ps(); + __m128 rho1v; rho1v = _mm_setzero_ps(); + __m128 rho2v; rho2v = _mm_setzero_ps(); + __m128 rho3v; rho3v = _mm_setzero_ps(); + __m128 rho4v; rho4v = _mm_setzero_ps(); + __m128 rho5v; rho5v = _mm_setzero_ps(); + + __m128 rho6v; rho6v = _mm_setzero_ps(); + __m128 rho7v; rho7v = _mm_setzero_ps(); + __m128 rho8v; rho8v = _mm_setzero_ps(); + __m128 rho9v; rho9v = _mm_setzero_ps(); + __m128 rho10v; rho10v = _mm_setzero_ps(); + __m128 rho11v; rho11v = _mm_setzero_ps(); + /* Holds 2 dcomplex element of x vector + * for computing dot product with A tile + */ + __m128 x0v, x1v; + /* Holds 2x6 tile of matrix A */ + __m128 a0v, a1v, a2v, a3v, a4v, a5v; + /** + * Since complex datatype multiplication is + * being held in two sets of rho vectors. + * Where first set holds the computaion with + * real part of vector x and other holds + * imaginary part of vector x. + * For final computation, based on conj sign + * of imaginary component needs to be toggled. + */ + __m128 no_conju = _mm_setr_ps(-1, 1, -1, 1); + __m128 conju = _mm_setr_ps(1, -1, 1, -1); + + for ( ; (i+1) < m; i+=2) + { + /*Load 4 scomplex elements from vector x*/ + x0v = _mm_loadu_ps( (float *)(x + i) ); + /* x1v.v holds imaginary part of dcomplex + * elements from vector x + */ + x1v = _mm_permute_ps( x0v, 0xf5 ); + /* x1v.v holds real part of dcomplex + * elements from vector x + */ + x0v = _mm_permute_ps( x0v, 0xa0); + /* x1v.v holds imag part of dcomplex + Load 4x6 tile of matrix A*/ + + a0v = _mm_loadu_ps( (float *)(a + i + 0 * lda)); + a1v = _mm_loadu_ps( (float *)(a + i + 1 * lda)); + a2v = _mm_loadu_ps( (float *)(a + i + 2 * lda)); + a3v = _mm_loadu_ps( (float *)(a + i + 3 * lda)); + a4v = _mm_loadu_ps( (float *)(a + i + 4 * lda)); + a5v = _mm_loadu_ps( (float *)(a + i + 5 * lda)); + + // perform: rho?v += a?v * x0v; + + rho0v = _mm_fmadd_ps( a0v, x0v, rho0v ); + rho6v = _mm_fmadd_ps( a0v, x1v, rho6v ); + + rho1v = _mm_fmadd_ps( a1v, x0v, rho1v ); + rho7v = _mm_fmadd_ps( a1v, x1v, rho7v ); + + rho2v = _mm_fmadd_ps( a2v, x0v, rho2v ); + rho8v = _mm_fmadd_ps( a2v, x1v, rho8v ); + + rho3v = _mm_fmadd_ps( a3v, x0v, rho3v ); + rho9v = _mm_fmadd_ps( a3v, x1v, rho9v ); + + rho4v = _mm_fmadd_ps( a4v, x0v, rho4v ); + rho10v = _mm_fmadd_ps( a4v, x1v, rho10v ); + + rho5v = _mm_fmadd_ps( a5v, x0v, rho5v ); + rho11v = _mm_fmadd_ps( a5v, x1v, rho11v ); + } + /*Swapping position of real and imag component + * for horizontal addition to get the final + * dot product computation + * rho register are holding computation which needs + * to be arranged in following manner. 
+ * Ra0*Ix0 | Ia0*Ix0 | Ra1*Ix1 | Ia1*Ix1 + * || + * \/ + * Ia0*Ix0 | Ra0*Ix0 | Ia1*Ix1 | Ra1*Ix1 + */ + rho6v = _mm_permute_ps(rho6v, 0xb1); + rho7v = _mm_permute_ps(rho7v, 0xb1); + rho8v = _mm_permute_ps(rho8v, 0xb1); + rho9v = _mm_permute_ps(rho9v, 0xb1); + rho10v = _mm_permute_ps(rho10v, 0xb1); + rho11v = _mm_permute_ps(rho11v, 0xb1); + + /*Negating imaginary part for computing + * the final result of dcomplex multiplication*/ + if ( bli_is_noconj( conjx_use ) ) + { + + rho6v = _mm_mul_ps(rho6v, no_conju); + rho7v = _mm_mul_ps(rho7v, no_conju); + rho8v = _mm_mul_ps(rho8v, no_conju); + rho9v = _mm_mul_ps(rho9v, no_conju); + rho10v = _mm_mul_ps(rho10v, no_conju); + rho11v = _mm_mul_ps(rho11v, no_conju); + } + else + { + rho6v = _mm_mul_ps(rho6v, conju); + rho7v = _mm_mul_ps(rho7v, conju); + rho8v = _mm_mul_ps(rho8v, conju); + rho9v = _mm_mul_ps(rho9v, conju); + rho10v = _mm_mul_ps(rho10v, conju); + rho11v = _mm_mul_ps(rho11v, conju); + } + + rho0v = _mm_add_ps(rho0v, rho6v); + rho1v = _mm_add_ps(rho1v, rho7v); + rho2v = _mm_add_ps(rho2v, rho8v); + rho3v = _mm_add_ps(rho3v, rho9v); + rho4v = _mm_add_ps(rho4v, rho10v); + rho5v = _mm_add_ps(rho5v, rho11v); + + /** + * Horizontal addition of rho elements + * for computing final dotxf result. + * ptr pointer addresses all 6 rho + * register one by one and store the + * computed result into r buffer. + */ + scomplex *ptr = (scomplex *)&rho0v; + for(dim_t i = 0; i < 2; i++) + { + r[0].real += ptr[i].real; + r[0].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho1v; + for(dim_t i = 0; i < 2; i++) + { + r[1].real += ptr[i].real; + r[1].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho2v; + for(dim_t i = 0; i < 2; i++) + { + r[2].real += ptr[i].real; + r[2].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho3v; + for(dim_t i = 0; i < 2; i++) + { + r[3].real += ptr[i].real; + r[3].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho4v; + for(dim_t i = 0; i < 2; i++) + { + r[4].real += ptr[i].real; + r[4].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho5v; + for(dim_t i = 0; i < 2; i++) + { + r[5].real += ptr[i].real; + r[5].imag += ptr[i].imag; + } + } + } + /*handles remainder cases*/ + if(rem) + { + if ( bli_is_noconj( conjx_use ) ) + { + + PRAGMA_SIMD + for(dim_t p = 0; p < 6 ; p++) + { + PASTEMAC(c,axpys)( a[i + p*lda], x[i], r[p] ); + } + } + else + { + PRAGMA_SIMD + for(dim_t p = 0; p < 6 ; p++) + { + PASTEMAC(c,axpyjs)( a[i + p*lda], x[i], r[p] ); + } + + } + } + + if ( bli_is_conj( conjat ) ) + { + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(c,conjs)( r[i] ); + } + } + + /*scaling dot product result with alpha and + * adding the result to vector + */ + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(c,axpys)( *alpha, r[i], y[i] ); + } + } + else + { + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(c,type); + PASTECH(c,dotxv_ker_ft) kfp_dv + = + bli_cntx_get_l1v_ker_dt( dt, BLIS_DOTXV_KER, cntx ); + + for ( dim_t i = 0; i < b_n; ++i ) + { + scomplex* restrict a1 = a + (0 )*inca + (i )*lda; + scomplex* restrict x1 = x + (0 )*incx; + scomplex* restrict psi1 = y + (i )*incy; + + kfp_dv + ( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx + ); + } } } diff --git a/kernels/zen/2/CMakeLists.txt b/kernels/zen/2/CMakeLists.txt index 480837c023..d4ad0143ed 100644 --- a/kernels/zen/2/CMakeLists.txt +++ b/kernels/zen/2/CMakeLists.txt @@ -1,8 +1,10 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. 
All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_ref.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_zen_int_4.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_int_4.c ) diff --git a/kernels/zen/2/bli_gemv_zen_int_4.c b/kernels/zen/2/bli_gemv_zen_int_4.c index b3c92b551c..74904605ee 100644 --- a/kernels/zen/2/bli_gemv_zen_int_4.c +++ b/kernels/zen/2/bli_gemv_zen_int_4.c @@ -35,6 +35,24 @@ #include "immintrin.h" #include "blis.h" +/* Union data structure to access AVX registers + One 256-bit AVX register holds 8 SP elements. */ +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + + +/* Union data structure to access AVX registers +* One 128-bit AVX register holds 4 SP elements. */ +typedef union +{ + __m128 v; + float f[4] __attribute__((aligned(64))); +} v4sf_t; + + /* This implementation uses 512 bits of cache line efficiently for column stored matrix and vectors. @@ -477,3 +495,380 @@ void bli_cgemv_zen_int_4x4 } } + +/* +Function performs multithreaded GEMV for float datatype +All parameters are similar to single thread GEMV except +n_thread which specifies the number of threads to be used +*/ +void bli_multi_sgemv_4x2 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + float* restrict alpha, + float* restrict a, inc_t inca, inc_t lda, + float* restrict x, inc_t incx, + float* restrict beta, + float* restrict y, inc_t incy, + cntx_t* restrict cntx, + dim_t n_threads + ) +{ + const dim_t b_fuse = 4; + const dim_t n_elem_per_reg = 8; + dim_t total_iteration = 0; + + // If the b_n dimension is zero, y is empty and there is no computation. + if (bli_zero_dim1(b_n)) + return; + + // If the m dimension is zero, or if alpha is zero, the computation + // simplifies to updating y. + if (bli_zero_dim1(m) || PASTEMAC(s, eq0)(*alpha)) + { + + bli_sscalv_zen_int10( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx); + return; + } + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over dotxv. + if (b_n < b_fuse) + { + for (dim_t i = 0; i < b_n; ++i) + { + float *a1 = a + (0) * inca + (i)*lda; + float *x1 = x + (0) * incx; + float *psi1 = y + (i)*incy; + + bli_sdotxv_zen_int( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx); + } + return; + } + + // Calculate the total number of multithreaded iteration + total_iteration = b_n / b_fuse; + +#pragma omp parallel for num_threads(n_threads) + for (dim_t j = 0; j < total_iteration; j++) + { + float *A1 = a + (b_fuse * j) * lda; + float *x1 = x; + float *y1 = y + (b_fuse * j) * incy; + + // Intermediate variables to hold the completed dot products + float rho0[4] = {0, 0, 0, 0}; + + // If vectorization is possible, perform them with vector + // instructions. + if (inca == 1 && incx == 1) + { + const dim_t n_iter_unroll = 2; + + // Use the unrolling factor and the number of elements per register + // to compute the number of vectorized and leftover iterations. + dim_t l, unroll_inc, m_viter[2], m_left = m; + + unroll_inc = n_elem_per_reg * n_iter_unroll; + + m_viter[0] = m_left / unroll_inc; + m_left = m_left % unroll_inc; + + m_viter[1] = m_left / n_elem_per_reg ; + m_left = m_left % n_elem_per_reg; + + // Set up pointers for x and the b_n columns of A (rows of A^T). + float *restrict x0 = x1; + float *restrict av[4]; + + av[0] = A1 + 0 * lda; + av[1] = A1 + 1 * lda; + av[2] = A1 + 2 * lda; + av[3] = A1 + 3 * lda; + + // Initialize b_n rho vector accumulators to zero. 
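+			// Each rhov[j] holds eight partial sums of dot( A1 column j, x );
+			// the hadd/extract sequence after the loops reduces them into
+			// the scalar results rho0[0..3].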
+ v8sf_t rhov[4]; + + rhov[0].v = _mm256_setzero_ps(); + rhov[1].v = _mm256_setzero_ps(); + rhov[2].v = _mm256_setzero_ps(); + rhov[3].v = _mm256_setzero_ps(); + + v8sf_t xv[2]; + v8sf_t a_vec[8]; + + // FMA operation is broken down to mul and add + // to reduce backend stalls + for (l = 0; l < m_viter[0]; ++l) + { + xv[0].v = _mm256_loadu_ps(x0); + x0 += n_elem_per_reg; + xv[1].v = _mm256_loadu_ps(x0); + x0 += n_elem_per_reg; + + a_vec[0].v = _mm256_loadu_ps(av[0]); + a_vec[4].v = _mm256_loadu_ps(av[0] + n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + a_vec[0].v = _mm256_mul_ps(a_vec[0].v, xv[0].v); + rhov[0].v = _mm256_fmadd_ps(a_vec[4].v, xv[1].v, rhov[0].v); + rhov[0].v = _mm256_add_ps(a_vec[0].v, rhov[0].v); + + a_vec[1].v = _mm256_loadu_ps(av[1]); + a_vec[5].v = _mm256_loadu_ps(av[1] + n_elem_per_reg); + + a_vec[1].v = _mm256_mul_ps(a_vec[1].v, xv[0].v); + rhov[1].v = _mm256_fmadd_ps(a_vec[5].v, xv[1].v, rhov[1].v); + rhov[1].v = _mm256_add_ps(a_vec[1].v, rhov[1].v); + + a_vec[2].v = _mm256_loadu_ps(av[2]); + a_vec[6].v = _mm256_loadu_ps(av[2] + n_elem_per_reg); + + a_vec[2].v = _mm256_mul_ps(a_vec[2].v, xv[0].v); + rhov[2].v = _mm256_fmadd_ps(a_vec[6].v, xv[1].v, rhov[2].v); + rhov[2].v = _mm256_add_ps(a_vec[2].v, rhov[2].v); + + a_vec[3].v = _mm256_loadu_ps(av[3]); + a_vec[7].v = _mm256_loadu_ps(av[3] + n_elem_per_reg); + + a_vec[3].v = _mm256_mul_ps(a_vec[3].v, xv[0].v); + rhov[3].v = _mm256_fmadd_ps(a_vec[7].v, xv[1].v, rhov[3].v); + rhov[3].v = _mm256_add_ps(a_vec[3].v, rhov[3].v); + + av[0] += unroll_inc; + av[1] += unroll_inc; + av[2] += unroll_inc; + av[3] += unroll_inc; + } + + for (l = 0; l < m_viter[1]; ++l) + { + // Load the input values. + xv[0].v = _mm256_loadu_ps(x0); + x0 += n_elem_per_reg; + + a_vec[0].v = _mm256_loadu_ps(av[0]); + a_vec[1].v = _mm256_loadu_ps(av[1]); + + rhov[0].v = _mm256_fmadd_ps(a_vec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_ps(a_vec[1].v, xv[0].v, rhov[1].v); + + av[0] += n_elem_per_reg; + av[1] += n_elem_per_reg; + + a_vec[2].v = _mm256_loadu_ps(av[2]); + a_vec[3].v = _mm256_loadu_ps(av[3]); + + rhov[2].v = _mm256_fmadd_ps(a_vec[2].v, xv[0].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_ps(a_vec[3].v, xv[0].v, rhov[3].v); + + av[2] += n_elem_per_reg; + av[3] += n_elem_per_reg; + } + + // Sum the elements within each vector. + // Sum the elements of a given rho?v with hadd. + rhov[0].v = _mm256_hadd_ps(rhov[0].v, rhov[1].v); + rhov[2].v = _mm256_hadd_ps(rhov[2].v, rhov[3].v); + rhov[0].v = _mm256_hadd_ps(rhov[0].v, rhov[0].v); + rhov[2].v = _mm256_hadd_ps(rhov[2].v, rhov[2].v); + + // Manually add the results from above to finish the sum. + rho0[0] = rhov[0].f[0] + rhov[0].f[4]; + rho0[1] = rhov[0].f[1] + rhov[0].f[5]; + rho0[2] = rhov[2].f[0] + rhov[2].f[4]; + rho0[3] = rhov[2].f[1] + rhov[2].f[5]; + + // If leftover elements are more than 4, perform SSE + if (m_left > 4) + { + v4sf_t xv128, a_vec128[4], rhov128[4]; + + rhov128[0].v = _mm_set1_ps(0); + rhov128[1].v = _mm_set1_ps(0); + rhov128[2].v = _mm_set1_ps(0); + rhov128[3].v = _mm_set1_ps(0); + + // Load the input values. 
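+				// One 128-bit pass: x0 advances by 4 and m_left is reduced
+				// by 4, so the scalar cleanup loop below only sees the
+				// remaining 0-3 elements.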
+ xv128.v = _mm_loadu_ps(x0 + 0 * n_elem_per_reg); + x0 += 4; + m_left -= 4; + + a_vec128[0].v = _mm_loadu_ps(av[0]); + a_vec128[1].v = _mm_loadu_ps(av[1]); + + // perform: rho?v += a?v * x0v; + rhov128[0].v = _mm_fmadd_ps(a_vec128[0].v, xv128.v, rhov128[0].v); + rhov128[1].v = _mm_fmadd_ps(a_vec128[1].v, xv128.v, rhov128[1].v); + rhov128[0].v = _mm_hadd_ps(rhov128[0].v, rhov128[1].v); + rhov128[0].v = _mm_hadd_ps(rhov128[0].v, rhov128[0].v); + + a_vec128[2].v = _mm_loadu_ps(av[2]); + a_vec128[3].v = _mm_loadu_ps(av[3]); + + rhov128[2].v = _mm_fmadd_ps(a_vec128[2].v, xv128.v, rhov128[2].v); + rhov128[3].v = _mm_fmadd_ps(a_vec128[3].v, xv128.v, rhov128[3].v); + rhov128[2].v = _mm_hadd_ps(rhov128[2].v, rhov128[3].v); + rhov128[2].v = _mm_hadd_ps(rhov128[2].v, rhov128[2].v); + + rho0[0] += rhov128[0].f[0]; + rho0[1] += rhov128[0].f[1]; + rho0[2] += rhov128[2].f[0]; + rho0[3] += rhov128[2].f[1]; + + av[0] += 4; + av[1] += 4; + av[2] += 4; + av[3] += 4; + } + + // If there are leftover iterations, perform them with scalar code. + for (l = 0; l < m_left; ++l) + { + rho0[0] += *(av[0]) * (*x0); + rho0[1] += *(av[1]) * (*x0); + rho0[2] += *(av[2]) * (*x0); + rho0[3] += *(av[3]) * (*x0); + + x0 += incx; + av[0] += inca; + av[1] += inca; + av[2] += inca; + av[3] += inca; + } + + } + else + { + // When vectorization is not possible, perform with scalar code + + // Initialize pointers for x and the b_n columns of A (rows of A^T). + float *restrict x0 = x1; + float *restrict a0 = A1 + 0 * lda; + float *restrict a1 = A1 + 1 * lda; + float *restrict a2 = A1 + 2 * lda; + float *restrict a3 = A1 + 3 * lda; + + for (dim_t l = 0; l < m; ++l) + { + const float x0c = *x0; + + const float a0c = *a0; + const float a1c = *a1; + const float a2c = *a2; + const float a3c = *a3; + + rho0[0] += a0c * x0c; + rho0[1] += a1c * x0c; + rho0[2] += a2c * x0c; + rho0[3] += a3c * x0c; + + x0 += incx; + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + } + } + + v4sf_t rho0v, y0v; + + rho0v.v = _mm_loadu_ps(rho0); + + // Broadcast the alpha scalar. + v4sf_t alphav; + alphav.v = _mm_broadcast_ss(alpha); + + // We know at this point that alpha is nonzero; however, beta may still + // be zero. If beta is indeed zero, we must overwrite y rather than scale + // by beta (in case y contains NaN or Inf). + if (PASTEMAC(s, eq0)(*beta)) + { + // Apply alpha to the accumulated dot product in rho: + // y := alpha * rho + y0v.v = _mm_mul_ps(alphav.v, rho0v.v); + } + else + { + // Broadcast the beta scalar. + v4sf_t betav; + betav.v = _mm_broadcast_ss(beta); + + if (incy == 0) + { + // Load y. + y0v.v = _mm_loadu_ps(y1 + 0 * n_elem_per_reg); + } + else + { + // Load y. + y0v.f[0] = *(y1 + 0 * incy); + y0v.f[1] = *(y1 + 1 * incy); + y0v.f[2] = *(y1 + 2 * incy); + y0v.f[3] = *(y1 + 3 * incy); + } + + // Apply beta to y and alpha to the accumulated dot product in rho: + // y := beta * y + alpha * rho + y0v.v = _mm_mul_ps(betav.v, y0v.v); + y0v.v = _mm_fmadd_ps(alphav.v, rho0v.v, y0v.v); + } + + // Store the output. + if (incy == 1) + { + _mm_storeu_ps((y1 + 0 * n_elem_per_reg), y0v.v); + } + else + { + // Store the output. 
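+			// Non-unit incy: write the four results back element by
+			// element at stride incy.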
+ *(y1 + 0 * incy) = y0v.f[0]; + *(y1 + 1 * incy) = y0v.f[1]; + *(y1 + 2 * incy) = y0v.f[2]; + *(y1 + 3 * incy) = y0v.f[3]; + } + } + + // Performs the complete computation if OpenMP is not enabled + dim_t start = total_iteration * b_fuse; + dim_t new_fuse = 8, f; + + // Left over corner cases completed using fused kernel + for (dim_t i = start; i < b_n; i += f) + { + f = bli_determine_blocksize_dim_f(i, b_n, new_fuse); + + float *A1 = a + (i)*lda + (0) * inca; + float *x1 = x + (0) * incx; + float *y1 = y + (i)*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + bli_sdotxf_zen_int_8( + conjat, + conjx, + m, + f, + alpha, + A1, inca, lda, + x1, incx, + beta, + y1, incy, + cntx); + } +} diff --git a/kernels/zen/2/bli_her2_zen_int_4.c b/kernels/zen/2/bli_her2_zen_int_4.c new file mode 100644 index 0000000000..9b181aa278 --- /dev/null +++ b/kernels/zen/2/bli_her2_zen_int_4.c @@ -0,0 +1,396 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "immintrin.h" +#include "blis.h" + +void bli_dher2_trans_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ) +{ + dim_t row = 0; + dim_t rem = m % 4; + + /*holds 4 diagonal elements of triangular part of 4x4 tile*/ + double a_diag[4] = {0}; + /*alpha_chi holds x*alpha and alpha_psi holds y*alpha*/ + double alpha_chi[4] = {0}; + double alpha_psi[4] = {0}; + /*Extracts diagonal element and store into a_diag buffer*/ + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + a_diag[i] = *(a + m + i + (i * lda)); + } + + __m256d x0, x1, x2, x3; + __m256d y0, y1, y2, y3; + + __m256d xr, yr, zero, gamma; + __m256d a0, a1, a2, a3; + + zero = _mm256_setzero_pd(); + + /*Loading elements of x & y vectors*/ + x0 = _mm256_loadu_pd(x + m); + y0 = _mm256_loadu_pd(y + m); + /*Broadcasting alpha to compute alpha_psi and alpha_chi*/ + x1 = _mm256_broadcast_sd(alpha); + + x2 = _mm256_mul_pd(x0, x1); + y0 = _mm256_mul_pd(y0, x1); + + /*Storing alpha_chi and alpha_psi for later usage in computation loop*/ + _mm256_storeu_pd(alpha_chi, x2); + _mm256_storeu_pd(alpha_psi, y0); + + x0 = _mm256_mul_pd(x0, y0); + gamma = _mm256_loadu_pd(a_diag); + gamma = _mm256_add_pd(gamma, x0); + gamma = _mm256_add_pd(gamma, x0); + _mm256_storeu_pd(a_diag, gamma); + + /* Broadcasting 4 alpha_psis and alpha_chis which + * are to be used througout the computation of 4x4 tile + * upto m rows. + */ + x0 = _mm256_broadcast_sd(&alpha_chi[0]); + x1 = _mm256_broadcast_sd(&alpha_chi[1]); + x2 = _mm256_broadcast_sd(&alpha_chi[2]); + x3 = _mm256_broadcast_sd(&alpha_chi[3]); + + y0 = _mm256_broadcast_sd(&alpha_psi[0]); + y1 = _mm256_broadcast_sd(&alpha_psi[1]); + y2 = _mm256_broadcast_sd(&alpha_psi[2]); + y3 = _mm256_broadcast_sd(&alpha_psi[3]); + + /* Loading 4x4 tile of A matrix for + * triangular part computation + */ + a0 = _mm256_loadu_pd(a + (0 * lda) + m); + a1 = _mm256_loadu_pd(a + (1 * lda) + m); + a2 = _mm256_loadu_pd(a + (2 * lda) + m); + a3 = _mm256_loadu_pd(a + (3 * lda) + m); + + yr = _mm256_loadu_pd(y); + xr = _mm256_loadu_pd(x); + + /*Setting first element of x & y vectors to zero + * to eliminate diagonal element of 1st column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x1); + yr = _mm256_blend_pd(yr, zero, 0x1); + a0 = _mm256_blend_pd(a0, zero, 0x1); + + a1 = _mm256_blend_pd(a1, zero, 0x3); + a2 = _mm256_blend_pd(a2, zero, 0x7); + a3 = _mm256_blend_pd(a3, zero, 0xF); + + a0 = _mm256_fmadd_pd(xr, y0, a0); + a0 = _mm256_fmadd_pd(yr, x0, a0); + + /*Setting two elements of x & y vectors to zero + * to eliminate diagonal element of 2nd column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x3); + yr = _mm256_blend_pd(yr, zero, 0x3); + a1 = _mm256_fmadd_pd(xr, y1, a1); + a1 = _mm256_fmadd_pd(yr, x1, a1); + + /*Setting three elements of x & y vectors to zero + * to eliminate diagonal element of 3rd column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x7); + yr = _mm256_blend_pd(yr, zero, 0x7); + a2 = _mm256_fmadd_pd(xr, y2, a2); + a2 = _mm256_fmadd_pd(yr, x2, a2); + + _mm256_storeu_pd(a + (0 * lda) + m, a0 ); + + /* Loading data from memory location first + * so it could be blend with and finally + * gets stored at same location to prevent + * unnecessary data overwriting at nearby + * memory locations + */ + a3 = _mm256_loadu_pd(a + (1 * lda) + m ); + a1 = _mm256_blend_pd(a1, a3, 0x1); + _mm256_storeu_pd(a + (1 * lda) + m, a1 ); + + a3 = _mm256_loadu_pd(a + (2 * lda) + m ); + a2 = _mm256_blend_pd(a2, a3, 0x3); + _mm256_storeu_pd(a + 
(2 * lda) + m, a2 ); + + /* Triangular part of matrix is computed, remaining + * part is computed in below loop upto m rows. + */ + for(; (row + 4) <= m; row+=4) + { + /* Loading elements of x and y vector */ + xr = _mm256_loadu_pd(x + row); + yr = _mm256_loadu_pd(y + row); + /* Loading tile of A matrix of size 4x4 */ + a0 = _mm256_loadu_pd(a + row + (0 * lda) ); + a1 = _mm256_loadu_pd(a + row + (1 * lda) ); + a2 = _mm256_loadu_pd(a + row + (2 * lda) ); + a3 = _mm256_loadu_pd(a + row + (3 * lda) ); + + a0 = _mm256_fmadd_pd(xr, y0, a0); + a1 = _mm256_fmadd_pd(xr, y1, a1); + a2 = _mm256_fmadd_pd(xr, y2, a2); + a3 = _mm256_fmadd_pd(xr, y3, a3); + + a0 = _mm256_fmadd_pd(yr, x0, a0); + a1 = _mm256_fmadd_pd(yr, x1, a1); + a2 = _mm256_fmadd_pd(yr, x2, a2); + a3 = _mm256_fmadd_pd(yr, x3, a3); + + _mm256_storeu_pd(a + row + (0 * lda), a0); + _mm256_storeu_pd(a + row + (1 * lda), a1); + _mm256_storeu_pd(a + row + (2 * lda), a2); + _mm256_storeu_pd(a + row + (3 * lda), a3); + } + + /* Computes remainder cases where m is less than 4 */ + if(rem) + { + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + for(dim_t j = row; j < m; j++) + { + a[ j + (i * lda)] += x[j] * (y[i] * (*alpha)); + a[ j + (i * lda)] += y[j] * (x[i] * (*alpha)); + } + } + } + + /* Computing 4 diagonal elements of triangular part of matrix + * and storing result back at corresponding location in matrix A + */ + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + *(a + m + i + (i * lda)) = a_diag[i]; + } +} + + +void bli_dher2_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ) +{ + dim_t row = 4; + dim_t rem = m % 4; + + /*holds 4 diagonal elements of triangular part of 4x4 tile*/ + double a_diag[4] = {0}; + /*alpha_chi holds x*alpha and alpha_psi holds y*alpha*/ + double alpha_chi[4] = {0}; + double alpha_psi[4] = {0}; + /*Extracts diagonal element and store into a_diag buffer*/ + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + a_diag[i] = *(a + i + (i * lda)); + } + + __m256d x0, x1, x2, x3; + __m256d y0, y1, y2, y3; + + __m256d xr, yr, zero, gamma; + __m256d a0, a1, a2, a3; + + zero = _mm256_setzero_pd(); + + /*Loading elements of x & y vectors*/ + x0 = _mm256_loadu_pd(x); + y0 = _mm256_loadu_pd(y); + /*Broadcasting alpha to compute alpha_psi and alpha_chi*/ + x1 = _mm256_broadcast_sd(alpha); + + x2 = _mm256_mul_pd(x0, x1); + y0 = _mm256_mul_pd(y0, x1); + + /*Storing alpha_chi and alpha_psi for later usage in computation loop*/ + _mm256_storeu_pd(alpha_chi, x2); + _mm256_storeu_pd(alpha_psi, y0); + + x0 = _mm256_mul_pd(x0, y0); + gamma = _mm256_loadu_pd(a_diag); + gamma = _mm256_add_pd(gamma, x0); + gamma = _mm256_add_pd(gamma, x0); + _mm256_storeu_pd(a_diag, gamma); + + /* Broadcasting 4 alpha_psis and alpha_chis which + * are to be used througout the computation of 4x4 tile + * upto m rows. 
+ */ + x0 = _mm256_broadcast_sd(&alpha_chi[0]); + x1 = _mm256_broadcast_sd(&alpha_chi[1]); + x2 = _mm256_broadcast_sd(&alpha_chi[2]); + x3 = _mm256_broadcast_sd(&alpha_chi[3]); + + y0 = _mm256_broadcast_sd(&alpha_psi[0]); + y1 = _mm256_broadcast_sd(&alpha_psi[1]); + y2 = _mm256_broadcast_sd(&alpha_psi[2]); + y3 = _mm256_broadcast_sd(&alpha_psi[3]); + + /* Loading 4x4 tile of A matrix for + * triangular part computation + */ + a0 = _mm256_loadu_pd(a + (0 * lda) ); + a1 = _mm256_loadu_pd(a + (1 * lda) ); + a2 = _mm256_loadu_pd(a + (2 * lda) ); + a3 = _mm256_loadu_pd(a + (3 * lda) ); + + yr = _mm256_loadu_pd(y); + xr = _mm256_loadu_pd(x); + + /*Setting first element of x & y vectors to zero + * to eliminate diagonal element of 1st column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x1); + yr = _mm256_blend_pd(yr, zero, 0x1); + a0 = _mm256_blend_pd(a0, zero, 0x1); + a1 = _mm256_blend_pd(a1, zero, 0x3); + a2 = _mm256_blend_pd(a2, zero, 0x7); + a3 = _mm256_blend_pd(a3, zero, 0xF); + + a0 = _mm256_fmadd_pd(xr, y0, a0); + a0 = _mm256_fmadd_pd(yr, x0, a0); + + /*Setting two elements of x & y vectors to zero + * to eliminate diagonal element of 2nd column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x3); + yr = _mm256_blend_pd(yr, zero, 0x3); + a1 = _mm256_fmadd_pd(xr, y1, a1); + a1 = _mm256_fmadd_pd(yr, x1, a1); + + /*Setting three elements of x & y vectors to zero + * to eliminate diagonal element of 3rd column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x7); + yr = _mm256_blend_pd(yr, zero, 0x7); + a2 = _mm256_fmadd_pd(xr, y2, a2); + a2 = _mm256_fmadd_pd(yr, x2, a2); + + _mm256_storeu_pd(a + (0 * lda), a0 ); + + /* Loading data from memory location first + * so it could be blend with and finally + * gets stored at same location to prevent + * unnecessary data overwriting at nearby + * memory locations + */ + a3 = _mm256_loadu_pd(a + (1 * lda) ); + a1 = _mm256_blend_pd(a1, a3, 0x1); + _mm256_storeu_pd(a + (1 * lda), a1 ); + + a3 = _mm256_loadu_pd(a + (2 * lda) ); + a2 = _mm256_blend_pd(a2, a3, 0x3); + _mm256_storeu_pd(a + (2 * lda), a2 ); + + /* Triangular part of matrix is computed, remaining + * part is computed in below loop upto m rows. 
+ */ + for(; (row + 4) <= m; row+=4) + { + /* Loading elements of x and y vector */ + xr = _mm256_loadu_pd(x + row); + yr = _mm256_loadu_pd(y + row); + /* Loading tile of A matrix of size 4x4 */ + a0 = _mm256_loadu_pd(a + row + (0 * lda) ); + a1 = _mm256_loadu_pd(a + row + (1 * lda) ); + a2 = _mm256_loadu_pd(a + row + (2 * lda) ); + a3 = _mm256_loadu_pd(a + row + (3 * lda) ); + + a0 = _mm256_fmadd_pd(xr, y0, a0); + a1 = _mm256_fmadd_pd(xr, y1, a1); + a2 = _mm256_fmadd_pd(xr, y2, a2); + a3 = _mm256_fmadd_pd(xr, y3, a3); + + a0 = _mm256_fmadd_pd(yr, x0, a0); + a1 = _mm256_fmadd_pd(yr, x1, a1); + a2 = _mm256_fmadd_pd(yr, x2, a2); + a3 = _mm256_fmadd_pd(yr, x3, a3); + + _mm256_storeu_pd(a + row + (0 * lda), a0); + _mm256_storeu_pd(a + row + (1 * lda), a1); + _mm256_storeu_pd(a + row + (2 * lda), a2); + _mm256_storeu_pd(a + row + (3 * lda), a3); + } + + /* Computes remainder cases where m is less than 4 */ + if(rem) + { + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + for(dim_t j = row; j < m; j++) + { + a[ j + (i * lda)] += x[j] * (y[i] * (*alpha)); + a[ j + (i * lda)] += y[j] * (x[i] * (*alpha)); + } + } + } + + /* Computing 4 diagonal elements of triangular part of matrix + * and storing result back at corresponding location in matrix A + */ + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + *(a + i + (i * lda)) = a_diag[i]; + } +} diff --git a/kernels/zen/3/bli_dgemm_ref_k1.c b/kernels/zen/3/bli_dgemm_ref_k1.c index 659975cdb7..03a2b789bb 100644 --- a/kernels/zen/3/bli_dgemm_ref_k1.c +++ b/kernels/zen/3/bli_dgemm_ref_k1.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -394,6 +394,7 @@ void bli_dgemm_ref_k1_nn if(m_rem == 1) { + ymm0 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); @@ -690,6 +691,7 @@ void bli_dgemm_ref_k1_nn if(m_rem == 1) { + ymm0 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); @@ -897,6 +899,7 @@ void bli_dgemm_ref_k1_nn if(m_rem == 1) { + ymm0 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm15 = _mm256_setzero_pd(); @@ -1052,6 +1055,7 @@ void bli_dgemm_ref_k1_nn if(m_rem == 1) { + ymm0 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm15 = _mm256_setzero_pd(); diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index d9c4047ec4..0cf5c8c5ce 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017-2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +40,7 @@ #define MR 32 #define D_MR (MR >> 1) +#define Z_MR (MR >> 3) #define NR 3 #define D_BLIS_SMALL_MATRIX_K_THRES_ROME 256 @@ -47,7 +48,7 @@ #define D_BLIS_SMALL_MATRIX_THRES (BLIS_SMALL_MATRIX_THRES / 2 ) #define D_BLIS_SMALL_M_RECT_MATRIX_THRES (BLIS_SMALL_M_RECT_MATRIX_THRES / 2) #define D_BLIS_SMALL_K_RECT_MATRIX_THRES (BLIS_SMALL_K_RECT_MATRIX_THRES / 2) -#define BLIS_ATBN_M_THRES 40 // Threshold value of M for/below which small matrix code is called. +#define BLIS_ATBN_M_THRES 40 // Threshold value of M for/below which small matrix code is called. #define AT_MR 4 // The kernel dimension of the A transpose GEMM kernel.(AT_MR * NR). static err_t bli_sgemm_small ( @@ -70,7 +71,26 @@ err_t bli_dgemm_small cntx_t* cntx, cntl_t* cntl ); - +err_t bli_zgemm_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); +err_t bli_zgemm_small_At + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); static err_t bli_sgemm_small_atbn ( obj_t* alpha, @@ -108,22 +128,18 @@ err_t bli_gemm_small cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); - + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + #ifdef BLIS_ENABLE_MULTITHREADING - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_NOT_YET_IMPLEMENTED; #else - // When dynamic dispatch is enabled i.e. library is built for 'amdzen' configuration. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (0 == bamdzen) - { - return BLIS_NOT_YET_IMPLEMENTED; - } + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + return BLIS_NOT_YET_IMPLEMENTED; + } #endif // If alpha is zero, scale by beta and return. @@ -156,6 +172,18 @@ err_t bli_gemm_small return bli_dgemm_small_At(alpha, a, b, beta, c, cntx, cntl); #endif } + if(dt == BLIS_DCOMPLEX) + { +#ifndef BLIS_ENABLE_MULTITHREADING + // bli_zgemm_small_At is called directly from blas interface for + // sizes within thresholds. + // Avoinding calling of bli_zgemm_small_At from gemm_front + // and directing to native implementation. + return BLIS_NOT_YET_IMPLEMENTED; +#else + return bli_zgemm_small_At(alpha, a, b, beta, c, cntx, cntl); +#endif + } if (bli_obj_has_notrans( b )) { @@ -184,16 +212,28 @@ err_t bli_gemm_small #endif } + if (dt == BLIS_DCOMPLEX) + { +#ifndef BLIS_ENABLE_MULTITHREADING + // bli_zgemm_small is called directly from BLAS interface for sizes within thresholds. + // Avoiding calling bli_zgemm_small from gemm_front and directing to + // native implementation. 
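+		// Same policy as the transposed-A dcomplex branch above: in
+		// single-threaded builds the small kernel is reached through the
+		// BLAS layer, so gemm_front falls back to the native path here.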
+ return BLIS_NOT_YET_IMPLEMENTED; +#else + return bli_zgemm_small(alpha, a, b, beta, c, cntx, cntl); +#endif + } + + if (dt == BLIS_FLOAT) { return bli_sgemm_small(alpha, a, b, beta, c, cntx, cntl); } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_NOT_YET_IMPLEMENTED; }; - static err_t bli_sgemm_small ( obj_t* alpha, @@ -205,13 +245,13 @@ static err_t bli_sgemm_small cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . gint_t L = M * N; - // when N is equal to 1 call GEMV instead of GEMM + // when N is equal to 1 call GEMV instead of GEMM if (N == 1) { bli_gemv @@ -222,7 +262,7 @@ static err_t bli_sgemm_small beta, c ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_SUCCESS; } @@ -248,7 +288,7 @@ static err_t bli_sgemm_small dim_t tb_inc_row = 1; // row stride of matrix B dim_t tb_inc_col = ldb; // column stride of matrix B - __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm4, ymm5, ymm6, ymm7; __m256 ymm8, ymm9, ymm10, ymm11; __m256 ymm12, ymm13, ymm14, ymm15; __m256 ymm0, ymm1, ymm2, ymm3; @@ -262,7 +302,7 @@ static err_t bli_sgemm_small const num_t dt_exec = bli_obj_dt( c ); float* restrict alpha_cast = bli_obj_buffer_for_1x1( dt_exec, alpha ); - float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta ); + float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta ); /*Beta Zero Check*/ bool is_beta_non_zero=0; @@ -270,7 +310,7 @@ static err_t bli_sgemm_small is_beta_non_zero = 1; } - //update the pointer math if matrix B needs to be transposed. + //update the pointer math if matrix B needs to be transposed. if (bli_obj_has_trans( b )) { tb_inc_col = 1; //switch row and column strides tb_inc_row = ldb; @@ -299,11 +339,11 @@ static err_t bli_sgemm_small bli_membrk_rntm_set_membrk( &rntm ); // Get the current size of the buffer pool for A block packing. - // We will use the same size to avoid pool re-initialization + // We will use the same size to avoid pool re-initialization siz_t buffer_size = bli_pool_block_size(bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), bli_rntm_membrk(&rntm))); - // Based on the available memory in the buffer we will decide if + // Based on the available memory in the buffer we will decide if // we want to do packing or not. // // This kernel assumes that "A" will be un-packged if N <= 3. @@ -315,18 +355,18 @@ static err_t bli_sgemm_small // If this check is removed it will result in the crash as // reported in CPUPL-587. // - + if ((N <= 3) || (((MR * K) << 2) > buffer_size)) { required_packing_A = 0; } - else + else { #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_sgemm_small: Requesting mem pool block of size %lu\n", buffer_size); #endif // Get the buffer from the pool, if there is no pool with - // required size, it will be created. + // required size, it will be created. 
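+		// The packing check above guarantees buffer_size >= (MR * K) << 2
+		// bytes, i.e. the pool block is large enough for one packed
+		// MR x K panel of floats.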
bli_membrk_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, @@ -1628,7 +1668,7 @@ static err_t bli_sgemm_small if(is_beta_non_zero){ ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); } - _mm256_storeu_ps(f_temp, ymm7); + _mm256_storeu_ps(f_temp, ymm7); for (int i = 0; i < m_remainder; i++) { tC[i] = f_temp[i]; @@ -1730,18 +1770,18 @@ static err_t bli_sgemm_small bli_membrk_release(&rntm, &local_mem_buf_A_s); } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_SUCCESS; } else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for small gemm." - ); - return BLIS_NONCONFORMAL_DIMENSIONS; - } + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } }; @@ -1756,22 +1796,25 @@ static err_t bli_sgemm_small cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + if (bli_cpuid_is_avx_supported() == FALSE) + { + return BLIS_NOT_YET_IMPLEMENTED; + } gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . gint_t L = M * N; /* if (N<3) //Implemenation assumes that N is atleast 3. VK */ - /* { */ - /* AOCL_DTL_TRACE_EXIT_ERR( */ - /* AOCL_DTL_LEVEL_INFO, */ + /* { */ + /* AOCL_DTL_TRACE_EXIT_ERR( */ + /* AOCL_DTL_LEVEL_INFO, */ /* "N < 3 cannot be processed by small_gemm" */ - /* ); */ + /* ); */ /* return BLIS_NOT_YET_IMPLEMENTED; VK */ - /* } */ - + /* } */ + if(L && K ) // Non-zero dimensions will be handled by either sup or native kernels { @@ -1844,7 +1887,7 @@ static err_t bli_sgemm_small bli_membrk_rntm_set_membrk( &rntm ); // Get the current size of the buffer pool for A block packing. - // We will use the same size to avoid pool re-initliazaton + // We will use the same size to avoid pool re-initliazaton siz_t buffer_size = bli_pool_block_size( bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), bli_rntm_membrk(&rntm))); @@ -1860,12 +1903,12 @@ static err_t bli_sgemm_small // reported in CPUPL-587. // - // if ((N <= 3) || ((D_MR * K) << 3) > buffer_size) - if ((N < 3) || ((D_MR * K) << 3) > buffer_size) + // if ((N <= 3) || ((D_MR * K) << 3) > buffer_size) + if ((N < 3) || ((D_MR * K) << 3) > buffer_size) { required_packing_A = 0; } - + if (required_packing_A == 1) { #ifdef BLIS_ENABLE_MEM_TRACING @@ -2869,7 +2912,6 @@ static err_t bli_sgemm_small if (m_remainder >= 4) { - //printf("HERE\n"); m_remainder -= 4; for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) @@ -3320,17 +3362,17 @@ static err_t bli_sgemm_small bli_membrk_release(&rntm, &local_mem_buf_A_s); } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); return BLIS_SUCCESS; } else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for small gemm." - ); + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." 
+ ); return BLIS_NONCONFORMAL_DIMENSIONS; - } + } }; static err_t bli_sgemm_small_atbn @@ -3344,9 +3386,9 @@ static err_t bli_sgemm_small_atbn cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - - gint_t M = bli_obj_length( c ); // number of rows of Matrix C + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + + gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_length( b ); // number of rows of Matrix B @@ -3371,7 +3413,7 @@ static err_t bli_sgemm_small_atbn float scratch[8] = {0.0}; const num_t dt_exec = bli_obj_dt( c ); float* restrict alpha_cast = bli_obj_buffer_for_1x1( dt_exec, alpha ); - float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta ); + float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta ); /*Beta Zero Check*/ bool is_beta_non_zero=0; @@ -3797,17 +3839,17 @@ static err_t bli_sgemm_small_atbn } } } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); return BLIS_SUCCESS; } else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for small gemm." - ); + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); return BLIS_NONCONFORMAL_DIMENSIONS; - } + } } static err_t bli_dgemm_small_atbn @@ -3821,8 +3863,8 @@ static err_t bli_dgemm_small_atbn cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_length( b ); // number of rows of Matrix B @@ -4237,17 +4279,17 @@ static err_t bli_dgemm_small_atbn } } } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); return BLIS_SUCCESS; } else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for small gemm." - ); - return BLIS_NONCONFORMAL_DIMENSIONS; - } + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } } err_t bli_dgemm_small_At @@ -4263,7 +4305,10 @@ err_t bli_dgemm_small_At { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - + if (bli_cpuid_is_avx_supported() == FALSE) + { + return BLIS_NOT_YET_IMPLEMENTED; + } gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_width_after_trans( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . @@ -4313,14 +4358,14 @@ err_t bli_dgemm_small_At if( bli_obj_has_trans( b ) ) { - tb_inc_col = 1; // switch row and column strides + tb_inc_col = 1; // switch row and column strides tb_inc_row = ldb; } __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm0, ymm1, ymm2, ymm3; double result; double scratch[8] = {0.0}; @@ -4358,7 +4403,7 @@ err_t bli_dgemm_small_At bli_membrk_rntm_set_membrk( &rntm ); // Get the current size of the buffer pool for A block packing. 
- // We will use the same size to avoid pool re-initliazaton + // We will use the same size to avoid pool re-initliazaton siz_t buffer_size = bli_pool_block_size( bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), bli_rntm_membrk(&rntm))); @@ -5381,7 +5426,6 @@ err_t bli_dgemm_small_At if (m_remainder >= 4) { - //printf("HERE\n"); m_remainder -= 4; tA = A + row_idx * lda; @@ -5709,5 +5753,7684 @@ err_t bli_dgemm_small_At return BLIS_NONCONFORMAL_DIMENSIONS; } }; + + +#define BLIS_SET_YMM_REG_ZEROS \ + ymm4 = _mm256_setzero_pd(); \ + ymm5 = _mm256_setzero_pd(); \ + ymm6 = _mm256_setzero_pd(); \ + ymm7 = _mm256_setzero_pd(); \ + ymm14 = _mm256_setzero_pd(); \ + ymm15 = _mm256_setzero_pd(); \ + ymm16 = _mm256_setzero_pd(); \ + ymm17 = _mm256_setzero_pd(); \ + ymm18 = _mm256_setzero_pd(); \ + ymm19 = _mm256_setzero_pd(); \ + ymm20 = _mm256_setzero_pd(); \ + ymm21 = _mm256_setzero_pd(); \ + + +#define BLIS_SET_ALL_YMM_REG_ZEROS \ + ymm4 = _mm256_setzero_pd(); \ + ymm5 = _mm256_setzero_pd(); \ + ymm6 = _mm256_setzero_pd(); \ + ymm7 = _mm256_setzero_pd(); \ + ymm8 = _mm256_setzero_pd(); \ + ymm9 = _mm256_setzero_pd(); \ + ymm10 = _mm256_setzero_pd(); \ + ymm11 = _mm256_setzero_pd(); \ + ymm12 = _mm256_setzero_pd(); \ + ymm13 = _mm256_setzero_pd(); \ + ymm14 = _mm256_setzero_pd(); \ + ymm15 = _mm256_setzero_pd(); \ + + + +err_t bli_zgemm_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + if (bli_cpuid_is_avx_supported() == FALSE) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + bool conjtransa = bli_obj_has_conj(a); + bool conjtransb = bli_obj_has_conj(b); + + gint_t M = bli_obj_length( c ); // number of rows of Matrix C + gint_t N = bli_obj_width( c ); // number of columns of Matrix C + // number of columns of OP(A), will be updated if OP(A) is Transpose(A) + gint_t K = bli_obj_width( a ); + gint_t L = M * N; + + if(L && K ) + { + guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A). + guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B). + guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C + guint_t row_idx, col_idx, k; + dcomplex *A = bli_obj_buffer_at_off(a); //pointer to elements of Matrix A + dcomplex *B = bli_obj_buffer_at_off(b); //pointer to elements of Matrix B + dcomplex *C = bli_obj_buffer_at_off(c); //pointer to elements of Matrix C + + dcomplex *tA = A, *tB = B, *tC = C;//, *tA_pack; + dcomplex *tA_packed; //temprorary pointer to hold packed A memory pointer + guint_t row_idx_packed; //packed A memory row index + guint_t lda_packed; //lda of packed A + guint_t col_idx_start; //starting index after A matrix is packed. + dim_t tb_inc_row = 1; // row stride of matrix B + dim_t tb_inc_col = ldb; // column stride of matrix B + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19, ymm20, ymm21; + __m256d ymm0, ymm1, ymm2, ymm3; + + gint_t n_remainder; // If the N is non multiple of 3.(N%3) + gint_t m_remainder; // If the M is non multiple of 4.(M%4) + + dcomplex *alpha_cast, *beta_cast; // alpha, beta multiples + alpha_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, alpha); + beta_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, beta); + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s; + dcomplex *D_A_pack = NULL; + rntm_t rntm; + + //update the pointer math if matrix B needs to be transposed. 
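+		// Transposition of B is handled purely through strides: swapping
+		// tb_inc_row and tb_inc_col lets the same inner loops read B in
+		// transposed order without any extra data movement.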
+ if (bli_obj_has_trans( b )) + { + tb_inc_col = 1; //switch row and column strides + tb_inc_row = ldb; + } + + //checking whether beta value is zero. + //if true, we should perform C=alpha * A*B operation + //instead of C = beta * C + alpha * (A * B) + bool is_beta_non_zero = 0; + if(!bli_obj_equals(beta, &BLIS_ZERO)) + is_beta_non_zero = 1; + + /* + * This function was using global array to pack part of A input when + * needed. However, using this global array make the function + * non-reentrant. Instead of using a global array we should allocate + * buffer for each invocation. Since the buffer size is too big or stack + * and doing malloc every time will be too expensive, better approach is + * to get the buffer from the pre-allocated pool and it the pool once we + * are doing. + * + * In order to get the buffer from pool, we need access to memory broker, + * currently this function is not invoked in such a way that it can + * receive the memory broker (via rntm). Following hack will get the + * global memory broker that can be use it to access the pool. + * + * Note there will be memory allocation at least on first innovation + * as there will not be any pool created for this size. + * Subsequent invocations will just reuse the buffer from the pool. + */ + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + // Get the current size of the buffer pool for A block packing. + // We will use the same size to avoid pool re-initliazaton + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + // + // This kernel assumes that "A" will be unpackged if N <= 3. + // Usually this range (N <= 3) is handled by SUP, however, + // if SUP is disabled or for any other condition if we do + // enter this kernel with N <= 3, we want to make sure that + // "A" remains unpacked. + // + + if ((N < 3) || ((Z_MR * K) << 4) > buffer_size) + { + required_packing_A = 0; + } + + if (required_packing_A == 1) + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_zgemm_small: Requesting mem pool block of size %lu\n", + buffer_size); +#endif + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + + /* + * The computation loop runs for Z_MRxN columns of C matrix, thus + * accessing the Z_MRxK A matrix data and KxNR B matrix data. + * The computation is organized as inner loops of dimension Z_MRxNR. + */ + // Process D_MR rows of C matrix at a time. + for (row_idx = 0; (row_idx + (Z_MR - 1)) < M; row_idx += Z_MR) + { + col_idx_start = 0; + tA_packed = A; + row_idx_packed = row_idx; + lda_packed = lda; + + /** + * This is the part of the pack and compute optimization. + * During the first column iteration, we store the accessed A + * matrix into contiguous static memory. This helps to keep te A + * matrix in Cache and aviods the TLB misses. 
+ */ + if (required_packing_A) + { + col_idx = 0; + + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + tA_packed = D_A_pack; + +#ifdef BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); #endif + // clear scratch registers. + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B + // matrix i data and multiplies it with + // the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *) + (tA_packed + 2), ymm1); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) * + 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. 
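+ //(This pattern repeats in every k loop of this kernel: ymm2/ymm3 hold
+ // the broadcast real/imaginary parts of one B element; ymm8-ymm13
+ // accumulate A*real(B) while ymm4-ymm7/ymm14/ymm15 accumulate
+ // A*imag(B), and the two halves are recombined after the k loop with
+ // _mm256_permute_pd + _mm256_addsub_pd to form the complex products.
+ // conj(A) is applied by multiplying the loaded A vectors by
+ // ymm20 = {1,-1,1,-1}; conj(B) by negating the imaginary broadcast
+ // with ymm21.)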
+ ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *)(tA_packed + 2) + , ymm1); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *)(tA_packed + 2) + , ymm1); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *)(tA_packed + 2) + , ymm1); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + } + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + (&beta_cast->imag)); + + + BLIS_SET_YMM_REG_ZEROS + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
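+ // beta is split into a real broadcast (ymm2) and an imaginary
+ // broadcast (ymm3); each pair of fmadds below forms C*real(beta) and
+ // C*imag(beta) separately, and the permute + addsub sequence that
+ // follows recombines them into beta*C before it is added to the
+ // alpha-scaled A*B result.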
+ ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + // col 2 + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + // col 3 + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2))); + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2) + 2)); + ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + _mm256_storeu_pd((double *)(tC + 2), ymm13); + + // modify the pointer arithematic to use packed A matrix. + col_idx_start = NR; + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = Z_MR; + } + // Process NR columns of C matrix at a time. + for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; + col_idx += NR) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + +#ifdef BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); +#endif + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
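+ //(When A was packed above, tA now walks the packed copy in D_A_pack
+ // and the k loop advances with lda_packed = Z_MR; otherwise tA still
+ // walks A with lda_packed = lda.)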
+ ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K The inner loop broadcasts + // the B matrix data and multiplies it + // with the A matrix. This loop is + // processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K The inner loop broadcasts the + // B matrix data and multiplies it with + // the A matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + + // alpha, beta multiplication. 
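+ // Each accumulator is multiplied by the complex scalar alpha:
+ // with alpha = ar + ai*i broadcast in ymm0, permuting ymm0 and
+ // multiplying by {1,-1,1,-1} yields {ai,-ar,ai,-ar}; the two
+ // element-wise products are then combined by _mm256_hsub_pd so that
+ // each complex lane holds (xr*ar - xi*ai, xr*ai + xi*ar).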
+ ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + BLIS_SET_YMM_REG_ZEROS + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2 + 2)); + ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + _mm256_storeu_pd((double *)(tC + 2), ymm13); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + 
ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + + tptr += (tb_inc_row * 2); + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. 
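+ //(two-column fringe: only the column-0 accumulators ymm8/ymm11 and
+ // the column-1 accumulators ymm9/ymm12, together with their
+ // imaginary-part counterparts ymm4-ymm7, are live here.)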
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + + tptr += (tb_inc_row * 2); + tA += lda; + } + + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)(tC + 0), ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + } + + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += (tb_inc_row * 2); + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += (tb_inc_row * 2); + tA += lda; + } + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + } + } + m_remainder = M - row_idx; + + if ((m_remainder == 3)) + { + m_remainder -= 3; + __m128d xmm0; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + + BLIS_SET_ALL_YMM_REG_ZEROS + + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *)(tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *)(tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
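+ // Only three dcomplex elements per column are valid in this
+ // m_remainder == 3 fringe, so the third element is loaded through
+ // xmm0 (_mm_loadu_pd) and merged into the low lane of ymm1 with
+ // _mm256_insertf128_pd; the stores below likewise write it back with
+ // _mm256_extractf128_pd + _mm_storeu_pd to avoid touching a fourth
+ // element of C.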
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc * 2 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + ymm20 = _mm256_fmadd_pd(ymm1, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm1, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + xmm0 = _mm256_extractf128_pd(ymm13, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + xmm0 = _mm_loadu_pd((double const *)(tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + } + if ((m_remainder == 2)) + { + m_remainder -= 2; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + + + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + BLIS_SET_YMM_REG_ZEROS + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + _mm256_storeu_pd((double *)tC, ymm8); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. 
+ ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + _mm256_storeu_pd((double *)tC, ymm8); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + _mm256_storeu_pd((double *)tC, ymm8); + } + } + if ((m_remainder == 1)) + { + m_remainder -= 1; + __m128d xmm0; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
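+ // Only one dcomplex element of A remains in this
+ // m_remainder == 1 case, so it is loaded with a
+ // 128-bit load and placed in the low lane of ymm0;
+ // only the low 128 bits of each result are stored
+ // back to C later.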
+ xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. 
+ xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
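+ // Each of the three C columns contributes a single
+ // dcomplex here; it is loaded as 128 bits and widened
+ // so the same fmadd/permute/addsub beta path as the
+ // full-width case can be reused.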
+ xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc * 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm10, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ ymm2 = _mm256_broadcast_sd(
+ (double const *)
+ (tptr + tb_inc_col * 0));
+ ymm3 = _mm256_broadcast_sd(
+ (double const *)
+ (tptr + tb_inc_col * 0 +
+ 1));
+
+ //broadcasted matrix B elements are
+ //multiplied
+ //with matrix A columns.
+ xmm0 = _mm_loadu_pd((double const *)(tA));
+ ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0);
+
+ ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);
+ ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
+
+ tptr += tb_inc_row*2;
+ tA += lda;
+ }
+
+ }
+ ymm4 = _mm256_permute_pd(ymm4, 0x5);
+
+ ymm8 = _mm256_addsub_pd(ymm8, ymm4);
+
+ // alpha, beta multiplication.
+ ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast);
+ ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);
+
+ ymm14 = _mm256_permute_pd(ymm0, 0x5);
+ ymm14 = _mm256_mul_pd(ymm14, ymm1);
+ ymm15 = _mm256_mul_pd(ymm8, ymm0);
+ ymm14 = _mm256_mul_pd(ymm8, ymm14);
+ ymm8 = _mm256_hsub_pd(ymm15, ymm14);
+
+
+
+ BLIS_SET_YMM_REG_ZEROS
+ xmm0 = _mm_setzero_pd();
+
+ ymm2 = _mm256_broadcast_sd((double const *)
+ &beta_cast->real);
+ ymm3 = _mm256_broadcast_sd((double const *)
+ &beta_cast->imag);
+
+ if(is_beta_non_zero)
+ {
+ // multiply C by beta and accumulate col 1.
+ xmm0 = _mm_loadu_pd((double const *)(tC));
+ ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0);
+
+ ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);
+ ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
+ }
+ ymm5 = _mm256_permute_pd(ymm5, 0x5);
+
+ ymm4 = _mm256_addsub_pd(ymm4, ymm5);
+
+ ymm8 = _mm256_add_pd(ymm8, ymm4);
+
+ xmm0 = _mm256_extractf128_pd(ymm8, 0);
+ _mm_storeu_pd((double *)tC, xmm0);
+
+ }
+ }
+ // Return the buffer to pool
+ if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) {
+#ifdef BLIS_ENABLE_MEM_TRACING
+ printf( "bli_zgemm_small(): releasing mem pool block\n" );
+#endif
+ bli_membrk_release(&rntm,
+ &local_mem_buf_A_s);
+ }
+ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
+ return BLIS_SUCCESS;
+ }
+ else
+ {
+ AOCL_DTL_TRACE_EXIT_ERR(
+ AOCL_DTL_LEVEL_INFO,
+ "Invalid dimensions for small gemm."
+ );
+ return BLIS_NONCONFORMAL_DIMENSIONS;
+ }
+};
+
+err_t bli_zgemm_small_At
+ (
+ obj_t* alpha,
+ obj_t* a,
+ obj_t* b,
+ obj_t* beta,
+ obj_t* c,
+ cntx_t* cntx,
+ cntl_t* cntl
+ )
+{
+ AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO);
+ if (bli_cpuid_is_avx_supported() == FALSE)
+ {
+ return BLIS_NOT_YET_IMPLEMENTED;
+ }
+ bool conjtransa = bli_obj_has_conj(a);
+ bool conjtransb = bli_obj_has_conj(b);
+
+ gint_t M = bli_obj_length( c ); // number of rows of Matrix C
+ gint_t N = bli_obj_width( c ); // number of columns of Matrix C
+ gint_t K = bli_obj_width_after_trans( a ); // number of columns of OP(A)
+
+ if (N<3) //Implementation assumes that N is at least 3.
+ {
+ AOCL_DTL_TRACE_EXIT_ERR(
+ AOCL_DTL_LEVEL_INFO,
+ "N < 3, cannot be processed by small gemm"
+ );
+ return BLIS_NOT_YET_IMPLEMENTED;
+ }
+
+ if( M && N && K )
+ {
+ guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A)
+ guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B)
+ guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C
+ guint_t row_idx, col_idx, k;
+ dcomplex *A = bli_obj_buffer_at_off(a); //pointer to elements of Matrix A
+ dcomplex *B = bli_obj_buffer_at_off(b); //pointer to elements of Matrix B
+ dcomplex *C = bli_obj_buffer_at_off(c); //pointer to elements of Matrix C
+
+ dcomplex *tA = A, *tB = B, *tC = C;//, *tA_pack;
+ dcomplex *tA_packed; // temporary pointer to hold packed A memory pointer
+ guint_t row_idx_packed; //packed A memory row index
+ guint_t lda_packed; //lda of packed A
+ dim_t tb_inc_row = 1; // row stride of matrix B
+ dim_t tb_inc_col = ldb; // column stride of matrix B
+
+ dcomplex *alpha_cast, *beta_cast; // alpha, beta multipliers
+ alpha_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, alpha);
+ beta_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, beta);
+
+ gint_t required_packing_A = 1;
+ mem_t local_mem_buf_A_s;
+ dcomplex *D_A_pack = NULL;
+ rntm_t rntm;
+
+ if( bli_obj_has_trans( b ) )
+ {
+ tb_inc_col = 1; // switch row and column strides
+ tb_inc_row = ldb;
+ }
+
+ __m256d ymm4, ymm5, ymm6, ymm7;
+ __m256d ymm8, ymm9, ymm10, ymm11;
+ __m256d ymm12, ymm13, ymm14, ymm15;
+ __m256d ymm16, ymm17, ymm18, ymm19, ymm20, ymm21;
+ __m256d ymm0, ymm1, ymm2, ymm3;
+
+ gint_t n_remainder; // remainder when N is not a multiple of 3 (N % 3)
+ gint_t m_remainder; // remainder when M is not a multiple of Z_MR (M % Z_MR)
+
+ //check whether the beta value is zero.
+ //if true, perform C = alpha * (A * B)
+ //instead of C = beta * C + alpha * (A * B)
+ bool is_beta_non_zero = 0;
+ if(!bli_obj_equals(beta, &BLIS_ZERO))
+ is_beta_non_zero = 1;
+
+ /*
+ * This function used to pack part of the A input into a global array
+ * when needed. However, using a global array makes the function
+ * non-reentrant, so a buffer must be allocated for each invocation
+ * instead. Since the buffer is too big for the stack and calling
+ * malloc on every invocation would be too expensive, the better
+ * approach is to get the buffer from the pre-allocated pool and
+ * return it to the pool once we are done.
+ *
+ * In order to get the buffer from the pool we need access to the
+ * memory broker, but this function is currently not invoked in a way
+ * that lets it receive the memory broker (via rntm). The following
+ * workaround fetches the global memory broker, which can then be
+ * used to access the pool.
+ *
+ * Note that there will be a memory allocation at least on the first
+ * invocation, since no pool will exist for this size yet. Subsequent
+ * invocations will simply reuse the buffer from the pool.
+ */
+
+ bli_rntm_init_from_global( &rntm );
+ bli_rntm_set_num_threads_only( 1, &rntm );
+ bli_membrk_rntm_set_membrk( &rntm );
+
+ // Get the current size of the buffer pool for A block packing.
+ // We will use the same size to avoid pool re-initialization.
+ siz_t buffer_size = bli_pool_block_size(
+ bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
+ bli_rntm_membrk(&rntm)));
+
+ //
+ // This kernel assumes that "A" will be unpacked if N <= 3.
+ // Usually this range (N <= 3) is handled by SUP, however, + // if SUP is disabled or for any other condition if we do + // enter this kernel with N <= 3, we want to make sure that + // "A" remains unpacked. + // + // If this check is removed it will result in the crash as + // reported in CPUPL-587. + // + + if ((N < 3) || ((Z_MR * K) << 4) > buffer_size) + { + required_packing_A = 0; + return BLIS_NOT_YET_IMPLEMENTED; + } + + if (required_packing_A == 1) + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_zgemm_small_At: Requesting mem pool block of size %lu\n", + buffer_size); +#endif + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + + /* + * The computation loop runs for D_MRxN columns of C matrix, thus + * accessing the D_MRxK A matrix data and KxNR B matrix data. + * The computation is organized as inner loops of dimension D_MRxNR. + */ + // Process D_MR rows of C matrix at a time. + for (row_idx = 0; (row_idx + (Z_MR - 1)) < M; row_idx += Z_MR) + { + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = Z_MR; + + // Pack 16xk of matrix A into buffer + // continuous access for A and strided stores to B + for(inc_t x = 0; (x) < 2; x += 1) + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + ymm2 = _mm256_loadu_pd((double const *) + (tA_temp + 1 * lda)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + + _mm256_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + ymm6); + _mm256_storeu_pd((double *) + (tA_packed + 1 * lda_packed), + ymm7); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + tA_packed[1].real = tA_temp[1 * lda].real; + tA_packed[1].imag = tA_temp[1 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + + tA += 2 * lda; + tA_packed = D_A_pack + (x + 1)*2; + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = Z_MR; + + // Process NR columns of C matrix at a time. + for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + +#ifdef BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); +#endif + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
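+ // A is read from the packed buffer (D_A_pack) here,
+ // so the two 256-bit loads below are contiguous and
+ // the k loop advances by lda_packed (Z_MR) instead of
+ // the original lda.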
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + + // alpha, beta multiplication. 
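+ // The hsub trick below forms alpha*acc for each packed
+ // dcomplex pair: with acc = a + b*i and alpha = p + q*i,
+ // ymm15 holds [a*p, b*q, ...], ymm14 holds [a*q, -b*p, ...],
+ // and _mm256_hsub_pd yields [a*p - b*q, a*q + b*p, ...],
+ // i.e. the real and imaginary parts of alpha*acc.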
+ ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + (&beta_cast->imag)); + + + + BLIS_SET_YMM_REG_ZEROS + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + // col 2 + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + // col 3 + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2))); + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2) + 2)); + ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + _mm256_storeu_pd((double *)(tC + 2), ymm13); + + } + n_remainder = N - col_idx; + + // if the N is not multiple of 3. 
+ // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
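+ // beta*C uses the same split-accumulator scheme: ymm2/ymm3 broadcast
+ // beta_r and beta_i, each C vector loaded below is FMA'd into a
+ // (real-product, imag-product) register pair, and a permute(0x5)
+ // followed by addsub turns each pair into
+ //   (c_r*beta_r - c_i*beta_i, c_i*beta_r + c_r*beta_i) = beta*c,
+ // which is then added to the alpha*A*B accumulators before the store.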
+ ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)(tC + 0), ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *)(tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((double const *)(tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((double const *)(tptr + tb_inc_col * 0 + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
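+ // Single-column tail of this row block: only one column of B is
+ // broadcast per k iteration, so one accumulator pair per A load
+ // (ymm8/ymm4 and ymm11/ymm5) is live, and C is just the two 256-bit
+ // vectors loaded below.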
+ ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + } + } + + m_remainder = M - row_idx; + if ((m_remainder == 3)) + { + m_remainder -= 3; + __m128d xmm0; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 3; + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + ymm2 = _mm256_loadu_pd((double const *) + (tA_temp + 1 * lda)); + ymm3 = _mm256_loadu_pd((double const *) + (tA_temp + 2 * lda)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + + _mm256_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + ymm6); + xmm0 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *) + (tA_packed + 0 * lda_packed + 2), + xmm0); + + _mm256_storeu_pd((double *) + (tA_packed + 1 * lda_packed), + ymm7); + xmm0 = _mm256_extractf128_pd(ymm3, 1); + _mm_storeu_pd((double *) + (tA_packed + 1 * lda_packed + 2), + xmm0); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + tA_packed[1].real = tA_temp[1 * lda].real; + tA_packed[1].imag = tA_temp[1 * lda].imag; + tA_packed[2].real = tA_temp[2 * lda].real; + tA_packed[2].imag = tA_temp[2 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 3; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
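+ // 3-row fringe: rows 0-1 of each C column are one full 256-bit load,
+ // while row 2 is loaded with _mm_loadu_pd into the low lane of ymm1
+ // and later written back with extractf128(...,0)/_mm_storeu_pd, so the
+ // element just past the fringe is never read or written. The upper
+ // lane of ymm1 carries don't-care data that is discarded on store.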
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc * 2 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + ymm20 = _mm256_fmadd_pd(ymm1, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm1, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + xmm0 = _mm256_extractf128_pd(ymm13, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
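+ // (When beta == 0 this block is skipped entirely: C is never read, so
+ // NaNs/Infs in the output buffer are not propagated, and because the
+ // scratch accumulators were cleared just above, the permute/addsub/add
+ // sequence that follows leaves alpha*A*B unchanged before the store.)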
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + } + if ((m_remainder == 2)) + { + m_remainder -= 2; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 2; + + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + ymm2 = _mm256_loadu_pd((double const *) + (tA_temp + 1 * lda)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + + _mm256_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + ymm6); + _mm256_storeu_pd((double *) + (tA_packed + 1 * lda_packed), + ymm7); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + tA_packed[1].real = tA_temp[1 * lda].real; + tA_packed[1].imag = tA_temp[1 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 2; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + BLIS_SET_YMM_REG_ZEROS + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
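+ // 2-row fringe: each of the three C columns fits in a single 256-bit
+ // vector (two dcomplex), so one load and one (real, imag) product pair
+ // per column, (ymm4,ymm5), (ymm14,ymm15) and (ymm18,ymm19), suffices.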
+ ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + _mm256_storeu_pd((double *)tC, ymm8); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + BLIS_SET_YMM_REG_ZEROS + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + _mm256_storeu_pd((double *)tC, ymm8); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. 
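+ // The four k-loop variants below differ only in conjugation handling:
+ // ymm20 = {1,-1,1,-1} flips the sign of the imaginary parts of the
+ // loaded A elements (conjugate of A), and ymm21 = {-1,-1,-1,-1} negates
+ // the broadcast imaginary part of B (conjugate of B); the plain branch
+ // applies neither.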
+ + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matri + // x data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matri + // x data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + _mm256_storeu_pd((double *)tC, ymm8); + } + } + if ((m_remainder == 1)) + { + m_remainder -= 1; + __m128d xmm0; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 1; + + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + + xmm0 = _mm256_extractf128_pd(ymm0, 0); + _mm_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + xmm0); + + xmm0 = _mm256_extractf128_pd(ymm0, 1); + _mm_storeu_pd((double *)(tA_packed + 1 + * lda_packed), xmm0); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 1; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
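The packing loop earlier in this m_remainder == 1 block copies K consecutive dcomplex elements of the leftover row of A into the contiguous D_A_pack buffer (lda_packed == 1), two per iteration with a scalar tail for odd K. A plain-C sketch of that copy follows; dcplx_t and the function name are illustrative stand-ins.

typedef struct { double real, imag; } dcplx_t;   /* stand-in for dcomplex */

/* Scalar model of the packing loop: a straight contiguous copy, unrolled by
 * two in the intrinsics version (_mm256_loadu_pd + two _mm_storeu_pd). */
static void pack_last_row_model(dcplx_t *dst, const dcplx_t *src, long K)
{
    long k = 0;
    for (; k + 1 < K; k += 2)   /* vector body */
    {
        dst[k]     = src[k];
        dst[k + 1] = src[k + 1];
    }
    for (; k < K; ++k)          /* odd-K tail */
        dst[k] = src[k];
}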
+ // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
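Throughout these k-loops the complex product is built from two real FMAs per B element: one accumulator collects A*(B.real) and another collects A*(B.imag); only after the loop do _mm256_permute_pd(.., 0x5) and _mm256_addsub_pd recombine them into real and imaginary parts. Below is a scalar sketch of that recombination, with dcplx_t and the function name as illustrative stand-ins.

typedef struct { double real, imag; } dcplx_t;   /* stand-in for dcomplex */

/* Scalar model of one accumulator pair: br_* holds (a.real, a.imag)*b.real
 * (ymm8 in the kernel) and bi_* holds (a.real, a.imag)*b.imag (ymm4); the
 * permute(0x5) + addsub step after the k-loop turns them into the complex
 * sum of products. */
static dcplx_t zdot_model(const dcplx_t *a, const dcplx_t *b, long K)
{
    double br_r = 0.0, br_i = 0.0;
    double bi_r = 0.0, bi_i = 0.0;

    for (long k = 0; k < K; ++k)
    {
        br_r += a[k].real * b[k].real;
        br_i += a[k].imag * b[k].real;
        bi_r += a[k].real * b[k].imag;
        bi_i += a[k].imag * b[k].imag;
    }

    /* permute swaps (bi_r, bi_i); addsub subtracts in the even slot and
     * adds in the odd slot: */
    dcplx_t c;
    c.real = br_r - bi_i;   /* Re( sum a[k]*b[k] ) */
    c.imag = br_i + bi_r;   /* Im( sum a[k]*b[k] ) */
    return c;
}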
+ xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc * 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm10, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. 
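In this m_remainder == 1 region every access to C and to the packed A column goes through 128-bit loads and stores: exactly one dcomplex (16 bytes) is read or written at a time, so the kernel never touches memory past the last row even though the arithmetic still uses 256-bit registers. A small sketch of that access pattern (helper names are illustrative):

#include <immintrin.h>

/* Load a single dcomplex into the low lane of a YMM register (the high lane
 * is zeroed here and ignored by the kernel), and store only the low lane. */
static inline __m256d load_one_zelem(const double *p)   /* p -> {re, im} */
{
    __m128d lo = _mm_loadu_pd(p);
    return _mm256_insertf128_pd(_mm256_setzero_pd(), lo, 0);
}

static inline void store_one_zelem(double *p, __m256d v)
{
    _mm_storeu_pd(p, _mm256_extractf128_pd(v, 0));      /* low 128 bits */
}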
+ + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
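The conjugate branches above reuse the same FMA body and only change signs up front: multiplying the loaded A pair by ymm20 = {1,-1,1,-1} negates its imaginary parts (conj(A)), while multiplying the broadcast B.imag by ymm21 = {-1,-1,-1,-1} conjugates B. A scalar sketch of that folding, with dcplx_t and the function name as illustrative stand-ins:

typedef struct { double real, imag; } dcplx_t;   /* stand-in for dcomplex */

/* Scalar model of how conjtransa/conjtransb are handled: flip signs on the
 * inputs, then run the one shared complex multiply. */
static dcplx_t zconj_mul_model(dcplx_t a, dcplx_t b,
                               int conjtransa, int conjtransb)
{
    if (conjtransa) a.imag = -a.imag;   /* ymm0 *= { 1,-1, 1,-1} */
    if (conjtransb) b.imag = -b.imag;   /* ymm3 (b.imag) *= {-1,-1,-1,-1} */

    dcplx_t c;
    c.real = a.real * b.real - a.imag * b.imag;
    c.imag = a.imag * b.real + a.real * b.imag;
    return c;
}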
+ xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matri + // x data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + } + } + // Return the buffer to pool + if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )){ +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_zgemm_small_At(): releasing mem pool block\n" ); +#endif + bli_membrk_release(&rntm, + &local_mem_buf_A_s); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return BLIS_SUCCESS; + } + else + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for dgemm_small_At." 
+ ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } +}; +#endif diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index c782a08a49..bb8a2e9cc5 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -2847,6 +2847,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_PRE_STRSM_SMALL_3N_2M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ \ + xmm5 = _mm_setzero_ps();\ xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11));\ ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ @@ -3009,6 +3010,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_PRE_STRSM_SMALL_2N_2M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ \ + xmm5 = _mm_setzero_ps();\ xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11));\ ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ @@ -3116,6 +3118,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_PRE_STRSM_SMALL_1N_2M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ \ + xmm5 = _mm_setzero_ps();\ xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11));\ ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3); @@ -3818,15 +3821,22 @@ err_t bli_trsm_small num_t dt = bli_obj_dt(a); switch(dt) { - case BLIS_DOUBLE: - case BLIS_FLOAT: - case BLIS_SCOMPLEX: - { - if(m > 1000 || n > 1000) { + case BLIS_DOUBLE: + { + bool nt = bli_thread_get_is_parallel(); + if((nt == 0) && (m > 1000 || n > 1000)) { + return BLIS_NOT_YET_IMPLEMENTED; + } + break; + } + case BLIS_FLOAT: + case BLIS_SCOMPLEX: + { + if(m > 1000 || n > 1000) { return BLIS_NOT_YET_IMPLEMENTED; } break; - } + } case BLIS_DCOMPLEX: { if(m > 500 || n > 500) { @@ -3883,38 +3893,145 @@ err_t bli_trsm_small return err; }; +#ifdef BLIS_ENABLE_OPENMP +/* + * Parallelized dtrsm_small across m-dimension or n-dimension based on side(Left/Right) + */ + +err_t bli_trsm_small_mt +( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + gint_t m = bli_obj_length( b ); // number of rows of matrix b + gint_t n = bli_obj_width( b ); // number of columns of Matrix b + dim_t d_mr = 8,d_nr = 6; + + num_t dt = bli_obj_dt(a); + switch(dt) + { + case BLIS_DOUBLE: + { + d_mr = 8,d_nr = 6; + break; + } + default: + { + return BLIS_NOT_YET_IMPLEMENTED; + break; + } + } + + rntm_t rntm; + bli_rntm_init_from_global( &rntm ); + + #ifdef AOCL_DYNAMIC + // If dynamic-threading is enabled, calculate optimum number + // of threads. + // rntm will be updated with optimum number of threads. + if( bli_obj_is_double(b)) + { + bli_nthreads_optimum(a, b, b, BLIS_TRSM, &rntm); + } + #endif + + // Query the total number of threads from the rntm_t object. + dim_t n_threads = bli_rntm_num_threads( &rntm ); + + if (n_threads < 0 ) n_threads = 1; + + err_t status = BLIS_SUCCESS; + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Query the thread's id from OpenMP. 
+ const dim_t tid = omp_get_thread_num(); + + obj_t b_t; + dim_t start; // Each thread start Index + dim_t end; // Each thread end Index + thrinfo_t thread; + + thread.n_way = n_threads; + thread.work_id = tid; + thread.ocomm_id = tid; + + + // Compute start and end indexes of matrix partitioning for each thread + if ( bli_is_right( side ) ) + { + bli_thread_range_sub ( &thread, + m, + d_mr,// Need to decide based on type + FALSE, + &start, + &end + ); + // For each thread acquire matrix block on which they operate + // Data-based parallelism + + bli_acquire_mpart_mdim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t); + } + else + { + bli_thread_range_sub ( &thread, + n, + d_nr,// Need to decide based on type + FALSE, + &start, + &end + ); + // For each thread acquire matrix block on which they operate + // Data-based parallelism + + bli_acquire_mpart_ndim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t); + } + + // Parallelism is only across m-dimension/n-dimension - therefore matrix a is common to + // all threads + err_t status_l = BLIS_SUCCESS; + + status_l = bli_trsm_small + ( + side, + alpha, + a, + &b_t, + NULL, + NULL + ); + // To capture the error populated from any of the threads + _Pragma( "omp critical" ) + status = (status != BLIS_NOT_YET_IMPLEMENTED)?status_l:status; + } + + return status; +}// End of function +#endif + /* * ZTRSM utilities and kernel functions */ #define DCOMPLEX_INV(a, b) {\ - a.real = b.real;\ - a.imag = (b.imag * -1.0);\ - /*Compute denominator eliminating imaginary component*/\ - double dnm = (b.real * b.real);\ - /*multiply two times with -1 for correct result as - * dcomplex number with positive imaginary part will - * invert the sign if not multiplied twice with -1*/\ - dnm += ((-1.0 * (b.imag * b.imag)) * -1.0);\ - /*Compute the final result by dividing real and imag part by dnm*/\ - a.real /= dnm;\ - a.imag /= dnm;\ +/* dcomplex inva = {1.0, 0.0};*/\ + a.real = 1.0;\ + a.imag = 0.0;\ + bli_zinvscals(b, a);\ } #define DCOMPLEX_MUL(a, b, c) {\ - double real = a.real * b.real;\ - real += ((a.imag * b.imag) * -1.0);\ - double imag = (a.real * b.imag);\ - imag += (a.imag * b.real);\ - c.real = real;\ - c.imag = imag;\ + c.real = b.real;\ + c.imag = b.imag;\ + bli_zscals(a,c);\ } #define DCOMPLEX_DIV(a, b){\ - double dnm = b.real * b.real;\ - dnm += (-1.0 * (b.imag * (b.imag * -1.0) ));\ - a.real /= dnm;\ - a.imag /= dnm;\ + bli_zinvscals(b,a); \ } @@ -3943,11 +4060,8 @@ err_t bli_trsm_small #define ZTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\ if(!is_unitdiag)\ {\ - a.real = b.real;\ - a.imag = (b.imag * -1.0);\ - DCOMPLEX_MUL(c, a, c)\ - DCOMPLEX_DIV(c, b)\ - }\ + bli_zinvscals(b, c);\ + }\ } #endif @@ -4296,6 +4410,213 @@ BLIS_INLINE err_t ztrsm_AuXB_ref _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9);\ } + +#define BLIS_ZTRSM_SMALL_GEMM_3mx3n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm16 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm16);\ + ymm1 = _mm256_mul_pd(ymm1, ymm16);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + \ + ymm2 = 
_mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + \ + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\ + \ + tptr += 2; \ + a10 += p_lda; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + \ + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\ + \ + tptr += 2; \ + a10 += p_lda; \ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm7 = _mm256_permute_pd(ymm7, 0x5);\ + ymm14 = _mm256_permute_pd(ymm14, 0x5);\ + ymm15 = _mm256_permute_pd(ymm15, 0x5);\ + \ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm11 = _mm256_addsub_pd(ymm11, ymm5);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm7);\ + ymm10 = _mm256_addsub_pd(ymm10, ymm14);\ + ymm13 = _mm256_addsub_pd(ymm13, ymm15);\ +} + + +#define BLIS_ZTRSM_SMALL_GEMM_3mx2n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double * )b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + ymm1 = _mm256_mul_pd(ymm1, ymm18);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = 
_mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm7 = _mm256_permute_pd(ymm7, 0x5);\ +\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ + ymm13 = _mm256_addsub_pd(ymm13, ymm7);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_3mx1n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + ymm1 = _mm256_mul_pd(ymm1, ymm18);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + \ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + \ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ +} + + /** * Performs GEMM operation. * Two elements of column in ymm0 @@ -5577,68 +5898,58 @@ BLIS_INLINE err_t ztrsm_AuXB_ref * Performs dcomplex division of vec1 and vec2 with ymm1. * vec1 and vec2 gets divided by ymm1 which holds * diagonal element from buffer. 
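The rewritten BLIS_ZTRSM_TWO_DIV / BLIS_ZTRSM_DIV macros below spill the accumulators to a small dcomplex array and let bli_zinvscals() divide each element by the diagonal value (with the effect y := y / a, as used consistently by the new DCOMPLEX macros), instead of the explicit conjugate-multiply-and-divide the old vector code performed. The motivation is numerical robustness; the sketch below shows one well-known way such a division can be made overflow/underflow-safe (Smith-style scaling). It is an illustration only and is not claimed to be the actual bli_zinvscals implementation.

#include <math.h>
typedef struct { double real, imag; } dcplx_t;   /* stand-in for dcomplex */

/* Illustrative y := y / a in the spirit of the bli_zinvscals() calls used by
 * the new macros: scale by the larger component of the divisor so the
 * intermediate products stay bounded. */
static void zinvscals_model(dcplx_t a, dcplx_t *y)
{
    if (fabs(a.real) >= fabs(a.imag))
    {
        double r  = a.imag / a.real;             /* |r| <= 1 */
        double d  = a.real + a.imag * r;
        double yr = (y->real + y->imag * r) / d;
        double yi = (y->imag - y->real * r) / d;
        y->real = yr;  y->imag = yi;
    }
    else
    {
        double r  = a.real / a.imag;             /* |r| < 1 */
        double d  = a.real * r + a.imag;
        double yr = (y->real * r + y->imag) / d;
        double yi = (y->imag * r - y->real) / d;
        y->real = yr;  y->imag = yi;
    }
}
/* e.g. DCOMPLEX_INV(inv, d) then reduces to: inv = {1.0, 0.0};
 * zinvscals_model(d, &inv);  i.e. inv = 1/d. */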
- * Function gets called while performing TRSM. + * Using bli_zinvscals() to avoid overflow and underflow + * scenarios. Function gets called while performing TRSM. */ #define BLIS_ZTRSM_TWO_DIV(vec1, vec2) {\ if(!is_unitdiag) {\ if(conjtransa){\ ymm1 = _mm256_mul_pd(ymm1, ymm0);\ }\ - ymm12 = _mm256_mul_pd(ymm1, ymm0);\ - /*perform decomplex multiplication*/\ - /* Switch the real and imaginary elements of vec2 */\ - ymm14 = _mm256_permute_pd(ymm12, 0x5);\ - /* Negate the imaginary elements of vec2 */\ - ymm14 = _mm256_mul_pd(ymm14, ymm0);\ - /* Multiply vec1 and vec2 */ \ - ymm13 = _mm256_mul_pd(vec1, ymm12); /*vec3*/\ - /* Multiply vec1 and the modified vec2 */\ - ymm14 = _mm256_mul_pd(vec1, ymm14); /*vec4*/\ - /* Horizontally subtract the elements in vec3 and vec4 */\ - vec1 = _mm256_hsub_pd(ymm13, ymm14);\ - \ - ymm14 = _mm256_permute_pd(ymm12, 0x5);\ - /* Negate the imaginary elements of vec2 */\ - ymm14 = _mm256_mul_pd(ymm14, ymm0);\ - ymm13 = _mm256_mul_pd(vec2, ymm12);\ - ymm14 = _mm256_mul_pd(vec2, ymm14);\ - vec2 = _mm256_hsub_pd(ymm13, ymm14);\ - /*dcomplex multiplication is done*/\ - /*Swapping real & imaginary component position for addition with respective - * components*/\ - ymm12 = _mm256_mul_pd(ymm1, ymm1);\ - ymm13 = _mm256_permute4x64_pd(ymm12, 0xb1);\ - ymm14 = _mm256_add_pd(ymm12, ymm13);\ - \ - /*Finally dividing numerator by denominator*/\ - vec1 = _mm256_div_pd(vec1, ymm14);\ - vec2 = _mm256_div_pd(vec2, ymm14);\ +\ + dcomplex b_data[4];\ + dcomplex d11_data[2];\ +\ + _mm256_storeu_pd((double *)(b_data), vec1);\ + _mm256_storeu_pd((double *)(b_data + 2), vec2);\ + _mm256_storeu_pd((double *)(d11_data), ymm1);\ +\ + for(dim_t i = 0; i < 4; i++)\ + {\ + bli_zinvscals(d11_data[0],b_data[i]);\ + }\ +\ + vec1 = _mm256_loadu_pd((double *)b_data);\ + vec2 = _mm256_loadu_pd((double *)(b_data+2));\ +\ }\ } /** * Performs dcomplex division of vec1 with ymm1. * ymm1 holds diagonal element from buffer. - * Function gets called while performing TRSM. + * Using bli_zinvscals() to avoid overflow and underflow + * scenarios. Function gets called while performing TRSM. */ #define BLIS_ZTRSM_DIV(vec1) {\ if(!is_unitdiag){\ if(conjtransa){\ ymm1 = _mm256_mul_pd(ymm1, ymm0);\ }\ - ymm12 = _mm256_mul_pd(ymm1, ymm0); /*vec2 and ymm8 is vec1*/\ - ymm14 = _mm256_permute_pd(ymm12, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm0);\ - ymm13 = _mm256_mul_pd(vec1, ymm12); /*vec3*/\ - ymm14 = _mm256_mul_pd(vec1, ymm14); /*vec4*/\ - vec1 = _mm256_hsub_pd(ymm13, ymm14);\ - \ - ymm12 = _mm256_mul_pd(ymm1, ymm1);\ - ymm13 = _mm256_permute4x64_pd(ymm12, 0xb1);\ - ymm14 = _mm256_add_pd(ymm12, ymm13);\ - \ - /*Finally dividing numerator by denominator*/\ - vec1 = _mm256_div_pd(vec1, ymm14);\ +\ + dcomplex b_data[2];\ + dcomplex d11_data[2];\ +\ + _mm256_storeu_pd((double *)(b_data), vec1);\ + _mm256_storeu_pd((double *)(d11_data), ymm1);\ +\ + for(dim_t i = 0; i < 2; i++)\ + {\ + bli_zinvscals(d11_data[0],b_data[i]);\ + }\ +\ + vec1 = _mm256_loadu_pd((double *)b_data);\ +\ }\ } @@ -5813,7 +6124,6 @@ BLIS_INLINE void bli_ztrsm_small_pack } - BLIS_INLINE void ztrsm_small_pack_diag_element ( bool is_unitdiag, @@ -5824,64 +6134,31 @@ BLIS_INLINE void ztrsm_small_pack_diag_element ) { #ifdef BLIS_ENABLE_TRSM_PREINVERSION - __m256d ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8; - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); -#else - __m256d ymm1, ymm2, ymm3; -#endif - bool is_four = (size == 4) ? 
1 : 0; - dcomplex ones = {1.0, 1.0}; - ymm2 = ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - if(!is_unitdiag) + // If Preinversion is enabled, inverse the diaganol + // elements from A and pack into diagonal buffer. + // In order to avoid the overflow and underflow scenarios, + // bli_zinvscals is used + for( dim_t i = 0; i < size; i++) { - //broadcast diagonal elements of A11 - ymm1 = _mm256_broadcast_pd((__m128d const *)a11); - ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a +1); - /*Pick one element frome each column and create 3 element vector - and store it*/ - ymm1 = _mm256_permute2f128_pd(ymm1, ymm2, 0x20); - ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*2 + 2); - - if(is_four) - { - ymm3 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*2 + 2); - ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*3 + 3); - ymm2 = _mm256_permute2f128_pd(ymm3, ymm2, 0x20); - } + dim_t d = ((i*cs_a) + i); + dcomplex ones = {1.0, 0.0}; + bli_zinvscals(a11[d], ones) + d11_pack[i].real = ones.real; + d11_pack[i].imag = ones.imag; + } -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - /*Taking denomerator multiplication of real & imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - ymm5 = _mm256_mul_pd(ymm2,ymm2); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - ymm8 = _mm256_permute4x64_pd(ymm5, 0xb1); - - ymm5 = _mm256_add_pd(ymm5, ymm8); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - ymm2 = _mm256_mul_pd(ymm2, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); - ymm2 = _mm256_div_pd(ymm2, ymm5); -#endif +#else //BLIS_ENABLE_TRSM_PREINVERSION - } - _mm256_store_pd((double *)d11_pack, ymm1); - if(is_four) + // If Preinversion is disabled, pack the diaganol + // elements from A into diagonal buffer. 
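The rewritten ztrsm_small_pack_diag_element reduces to a short scalar loop in both configurations: with BLIS_ENABLE_TRSM_PREINVERSION the reciprocal of each diagonal element is stored (via bli_zinvscals on {1.0, 0.0}) so the solve can multiply, and without it the element is copied as-is (bli_zcopys) so the solve divides. A compact sketch of the two modes follows; dcplx_t is a stand-in and the textbook reciprocal is for illustration only, since the kernel uses bli_zinvscals for a more robust division.

typedef struct { double real, imag; } dcplx_t;   /* stand-in for dcomplex */

/* Scalar model of the diagonal packing: reciprocal when preinversion is
 * enabled, plain copy otherwise. */
static void pack_diag_model(int preinvert, const dcplx_t *a11, long cs_a,
                            dcplx_t *d11_pack, long size)
{
    for (long i = 0; i < size; ++i)
    {
        dcplx_t d = a11[i * cs_a + i];        /* i-th diagonal element */
        if (preinvert)
        {
            double den = d.real * d.real + d.imag * d.imag;
            d11_pack[i].real =  d.real / den;
            d11_pack[i].imag = -d.imag / den;
        }
        else
        {
            d11_pack[i] = d;                  /* bli_zcopys equivalent */
        }
    }
}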
+ for( dim_t i = 0; i < size; i++) { - _mm256_store_pd((double *)(d11_pack + 2), ymm2); + dim_t d = ((i*cs_a) + i); + bli_zcopys(a11[d],d11_pack[i]); } - else - { - _mm_store_pd((double *)(d11_pack + 2), - _mm256_extractf128_pd(ymm2,0)); - } +#endif //BLIS_ENABLE_TRSM_PREINVERSION } - /*implements TRSM for the case XA = alpha * B *A is lower triangular, non-unit diagonal/unit diagonal, transpose *dimensions: X:mxn A:nxn B: mxn @@ -5955,7 +6232,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks @@ -6450,25 +6727,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); + _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); + _mm_storel_pd((double *)(b11 + cs_b*2 + 2), _mm256_extractf128_pd(ymm7,1)); + _mm_storel_pd((double *)(b11 + cs_b*3 + 2), _mm256_extractf128_pd(ymm9,1)); + _mm_storel_pd((double *)(b11 + cs_b*4 + 2), _mm256_extractf128_pd(ymm11,1)); + _mm_storel_pd((double *)(b11 + cs_b*5 + 2), _mm256_extractf128_pd(ymm13,1)); m_remainder -= 3; i += 3; @@ -6580,25 +6851,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = 
_mm256_blend_pd(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); m_remainder -= 2; i += 2; @@ -6710,25 +6968,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storel_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storel_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storel_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storel_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storel_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storel_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); m_remainder -= 1; i += 1; @@ -7120,23 +7365,15 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + _mm_storeu_pd((double 
*)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); + _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); + _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); + _mm_storel_pd((double *)(b11 + cs_b*2 + 2), _mm256_extractf128_pd(ymm7,1)); + _mm_storel_pd((double *)(b11 + cs_b*3 + 2), _mm256_extractf128_pd(ymm9,1)); m_remainder -= 3; i += 3; @@ -7217,21 +7454,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); m_remainder -= 2; i += 2; @@ -7311,15 +7537,6 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); @@ -8415,7 +8632,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks @@ -8888,25 +9105,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] 
B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); + _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); + _mm_storel_pd((double *)(b11 + cs_b*2 + 2), _mm256_extractf128_pd(ymm7,1)); + _mm_storel_pd((double *)(b11 + cs_b*3 + 2), _mm256_extractf128_pd(ymm9,1)); + _mm_storel_pd((double *)(b11 + cs_b*4 + 2), _mm256_extractf128_pd(ymm11,1)); + _mm_storel_pd((double *)(b11 + cs_b*5 + 2), _mm256_extractf128_pd(ymm13,1)); m_remainder -=3; } @@ -9009,25 +9220,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); m_remainder -=2; } @@ -9130,25 
+9328,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storel_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storel_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storel_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storel_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storel_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storel_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); m_remainder -=1; } @@ -9529,23 +9714,15 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); + _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); + _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); + _mm_storel_pd((double *)(b11 + cs_b*2 + 2), _mm256_extractf128_pd(ymm7,1)); + _mm_storel_pd((double *)(b11 + cs_b*3 + 2), _mm256_extractf128_pd(ymm9,1)); m_remainder -=3; } @@ -9621,21 +9798,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = 
_mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); m_remainder -=2; } @@ -9708,15 +9874,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); @@ -10759,7 +10917,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B //pointers that point to blocks for GEMM and TRSM double *a10, *a11, *b01, *b11; @@ -12739,7 +12897,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM @@ -14754,9 +14912,12 @@ BLIS_INLINE void strsm_small_pack_diag_element __m256 ymm0, ymm1, ymm2, ymm3; __m256 ymm4, ymm5, ymm6, ymm7; __m256 ymm8, ymm9, ymm10,ymm11; - __m256 ymm14, ymm15, ymm12,ymm13; + __m256 ymm14, ymm15, ymm12; float ones = 1.0; - ymm13 = ymm14 = ymm15 = _mm256_broadcast_ss((float const *)&ones); + ymm14 = ymm15 = _mm256_broadcast_ss((float const *)&ones); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + __m256 ymm13 = _mm256_broadcast_ss((float const *)&ones); +#endif if(side=='L'||side=='l') { if(!is_unitdiag) @@ -31940,75 +32101,160 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB if(m_rem == 3) { dim_t p_lda = 4; - if(transa) - { - for(dim_t x = 0; x < i; x += p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm10 = _mm256_loadu_pd((double const *) - (a10 + 2)); - ymm1 = _mm256_loadu_pd((double const *) - (a10 + cs_a)); - ymm11 = _mm256_loadu_pd((double const *) - (a10 + 2 + cs_a)); + 
if(transa) + { + dim_t x = 0; + for(x = 0; (x+3) < i; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm10 = _mm256_loadu_pd((double const *) + (a10 + 2)); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); + ymm11 = _mm256_loadu_pd((double const *) + (a10 + 2 + cs_a)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a)); + ymm10 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a + 2)); + ymm1 = _mm256_set_pd(1, 1, 1, 1); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm1,0x31); + + + _mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda + 2), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2 + 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*3 + 2), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + for(; (x+2) < i; x += 3) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + xmm4 = _mm_loadu_pd((double const *) + (a10 + 2)); + ymm10 = _mm256_insertf128_pd(ymm10, xmm4, 0); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); + xmm4 = _mm_loadu_pd((double const *) + (a10 + 2 + cs_a)); + ymm11 = _mm256_insertf128_pd(ymm11, xmm4, 0); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2), ymm8); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a)); + xmm4 = _mm_loadu_pd((double const *)(a10 + + 2 * cs_a + 2)); + ymm10 = _mm256_insertf128_pd(ymm10, xmm4, 0); + ymm1 = _mm256_set_pd(1, 1, 1, 1); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm1,0x20); + + + _mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda + 2), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2 + 2), ymm8); + + a10 += 3; + ptr_a10_dup += p_lda * p_lda; + } + for(; (x+1) < i; x += 2) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*3), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); - ymm0 = _mm256_loadu_pd((double const *)(a10 - + 2 * cs_a)); - ymm10 
= _mm256_loadu_pd((double const *)(a10 - + 2 * cs_a + 2)); + ymm0 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a)); + ymm1 = _mm256_set_pd(1, 1, 1, 1); - ymm1 = _mm256_loadu_pd((double const *)(a10 - + 3 * cs_a)); - ymm11 = _mm256_loadu_pd((double const *)(a10 - + 3 * cs_a + 2)); + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup + 2), - ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda + 2), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*2 + 2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*3 + 2), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda + 2), ymm7); - a10 += p_lda; - ptr_a10_dup += p_lda * p_lda; - } + a10 += 2; + ptr_a10_dup += p_lda * p_lda; + } + for(; x < i; x += 1) + { + xmm4 = _mm_loadu_pd((double const *)(a10)); + xmm5 = _mm_loadu_pd((double const *) + (a10 + cs_a)); - } - else - { - for(dim_t x=0;x 0; j -= d_nr) { @@ -33426,37 +33758,38 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB } else if(m_remainder == 1) { - dim_t p_lda = 2; // packed leading dimension - if(transa) - { - for(dim_t x = 0; x < m-m_remainder; x += p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *) - (a10 + cs_a)); - - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda), ymm7); - - a10 += p_lda; - ptr_a10_dup += p_lda * p_lda; - } - - } - else - { - for(dim_t x=0;x 0; j -= d_nr) { @@ -34470,38 +34803,21 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB { if(transa) { - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+cs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+rs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); } - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif } else { ymm1 = _mm256_broadcast_pd((__m128d const*)&ones); + _mm256_storeu_pd((double *)(d11_pack), ymm1); } - _mm256_storeu_pd((double *)(d11_pack), ymm1); + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; @@ -34888,30 +35204,23 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB } if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef 
BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif + if(transa) + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); + } + else + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); + } } else { ymm1 = _mm256_broadcast_pd((__m128d const*)&ones); + _mm256_storeu_pd((double *)(d11_pack), ymm1); } - _mm256_storeu_pd((double *)(d11_pack), ymm1); + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; @@ -35922,39 +36231,20 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+cs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+rs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); } - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif } else { ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); + _mm256_storeu_pd((double *)(d11_pack), ymm1); } - _mm256_storeu_pd((double *)(d11_pack), ymm1); for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction { @@ -36341,30 +36631,22 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB } if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif + if(transa) + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); + } + else + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); + } } else { ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - } - _mm256_storeu_pd((double *)(d11_pack), ymm1); + _mm256_storeu_pd((double *)(d11_pack), ymm1); + } for(i = 0; 
(i+d_mr-1) < m; i += d_mr) //loop along 'M' direction { @@ -36527,33 +36809,19 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB */ #define SCOMPLEX_INV(a, b) {\ - a.real = b.real;\ - a.imag = (b.imag * -1.0);\ - /*Compute denominator eliminating imaginary component*/\ - float dnm = (b.real * b.real);\ - /*multiply two times with -1 for correct result as - * dcomplex number with positive imaginary part will - * invert the sign if not multiplied twice with -1*/\ - dnm += ((-1.0 * (b.imag * b.imag)) * -1.0);\ - /*Compute the final result by dividing real and imag part by dnm*/\ - a.real /= dnm;\ - a.imag /= dnm;\ + a.real = 1.0;\ + a.imag = 0.0;\ + bli_cinvscals(b, a);\ } #define SCOMPLEX_MUL(a, b, c) {\ - float real = a.real * b.real;\ - real += ((a.imag * b.imag) * -1.0);\ - float imag = (a.real * b.imag);\ - imag += (a.imag * b.real);\ - c.real = real;\ - c.imag = imag;\ + c.real = b.real;\ + c.imag = b.imag;\ + bli_cscals(a,c);\ } #define SCOMPLEX_DIV(a, b){\ - float dnm = b.real * b.real;\ - dnm += (-1.0 * (b.imag * (b.imag * -1.0) ));\ - a.real /= dnm;\ - a.imag /= dnm;\ + bli_cinvscals(b,a); \ } #ifdef BLIS_ENABLE_TRSM_PREINVERSION @@ -36579,13 +36847,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB #ifdef BLIS_DISABLE_TRSM_PREINVERSION #define CTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\ - if(!is_unitdiag)\ - {\ - a.real = b.real;\ - a.imag = (b.imag * -1.0);\ - SCOMPLEX_MUL(c, a, c)\ - SCOMPLEX_DIV(c, b)\ - }\ + if(!is_unitdiag)\ + {\ + bli_cinvscals(b, c);\ + }\ } #endif @@ -36981,72 +37246,30 @@ BLIS_INLINE void ctrsm_small_pack_diag_element dim_t size ) { - __m256 ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm8; - bool is_eight = (size == 8) ? 1 : 0; - scomplex ones = {1.0, 1.0}; - ymm2 = ymm1 = _mm256_broadcast_ps((__m128 const *)&ones); #ifdef BLIS_ENABLE_TRSM_PREINVERSION - __m256 ymm7; - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); -#endif - - if(!is_unitdiag) + // If Preinversion is disabled, inverse the diaganol + // elements from A and pack into diagonal buffer. + // In order to avoid the overflow and underflow scenarios, + // bli_cinvscals is used. 
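The comment above is the heart of this change: forming 1/a_ii naively as conj(a_ii)/(re^2 + im^2) squares the magnitude and can overflow or underflow even when the quotient itself is representable. Below is a minimal scalar sketch of the scaled-division idea (Smith's algorithm). It is illustrative only; the type name scplx_t and the function cinvscal_sketch are hypothetical stand-ins, and the actual bli_cinvscals macro in BLIS may be implemented differently.

    #include <math.h>

    typedef struct { float real, imag; } scplx_t;   /* stand-in for BLIS scomplex */

    /* Compute x := x / d without forming d.real^2 + d.imag^2 directly.
     * Scaling by the larger-magnitude component of d keeps the intermediates
     * close to the size of the result, which avoids spurious overflow and
     * underflow. Sketch only, not the BLIS implementation. */
    static void cinvscal_sketch(scplx_t d, scplx_t *x)
    {
        float r, t, xr, xi;
        if (fabsf(d.real) >= fabsf(d.imag))
        {
            r  = d.imag / d.real;                 /* |r| <= 1                */
            t  = 1.0f / (d.real + d.imag * r);
            xr = (x->real + x->imag * r) * t;
            xi = (x->imag - x->real * r) * t;
        }
        else
        {
            r  = d.real / d.imag;
            t  = 1.0f / (d.real * r + d.imag);
            xr = (x->real * r + x->imag) * t;
            xi = (x->imag * r - x->real) * t;
        }
        x->real = xr;
        x->imag = xi;
    }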
+	for( dim_t i = 0; i < size; i++)
	{
-		//broadcast diagonal elements of A11
-		ymm1 = _mm256_broadcast_ps((__m128 const *)a11);
-		ymm2 = _mm256_broadcast_ps((__m128 const *) (a11+ cs_a +1));
-		ymm3 = _mm256_broadcast_ps((__m128 const *) (a11+ cs_a*2 +2));
-
-		ymm1 = _mm256_shuffle_ps(ymm1, ymm2, 0x44);
-
-		if(is_eight) {
-			ymm4 = _mm256_broadcast_ps((__m128 const *)(a11 + 4 + cs_a*4));
-			ymm5 = _mm256_broadcast_ps((__m128 const *)(a11 + 5 + cs_a*5));
-			ymm6 = _mm256_shuffle_ps(ymm4, ymm5, 0x44);
-
-			ymm4 = _mm256_broadcast_ps((__m128 const *)(a11 + 6 + cs_a*6));
-			ymm5 = _mm256_broadcast_ps((__m128 const *)(a11 + 7 + cs_a*7));
-			ymm8 = _mm256_shuffle_ps(ymm4, ymm5, 0x44);
-
-			ymm2 = _mm256_blend_ps(ymm6, ymm8, 0xF0);
-
-			ymm4 = _mm256_broadcast_ps((__m128 const *)(a11 + 3 + cs_a*3));
-			ymm3 = _mm256_shuffle_ps(ymm3, ymm4, 0x44);
-		}
-
-		ymm1 = _mm256_blend_ps(ymm1, ymm3, 0xF0);
-
-#ifdef BLIS_ENABLE_TRSM_PREINVERSION
-		/*Taking denomerator multiplication of real & imaginary components*/
-		ymm4 = _mm256_mul_ps(ymm1, ymm1);
-		ymm5 = _mm256_mul_ps(ymm2, ymm2);
-		/*Swapping real & imaginary component position for addition with
-		 * respective components*/
-		//BEFORE
-		//a[0] a[1] a[2] a[3]
-		//AFTER
-		//a[1] a[0] a[3] a[2]
-		//MESS
-		ymm6 = _mm256_permute_ps(ymm4, 0xB1);
-		ymm8 = _mm256_permute_ps(ymm5, 0xB1);
-		ymm4 = _mm256_add_ps(ymm4, ymm6);
-		ymm5 = _mm256_add_ps(ymm5, ymm8);
-
-		/*Negating imaginary component of numerator*/
-		ymm1 = _mm256_mul_ps(ymm1, ymm7);
-		ymm2 = _mm256_mul_ps(ymm2, ymm7);
-
-		/*Dividing numerator by denominator*/
-		ymm1 = _mm256_div_ps(ymm1, ymm4);
-		ymm2 = _mm256_div_ps(ymm2, ymm5);
-
-#endif
+		dim_t d = ((i*cs_a) + i);
+		scomplex ones = {1.0, 0.0};
+		bli_cinvscals(a11[d], ones)
+		d11_pack[i].real = ones.real;
+		d11_pack[i].imag = ones.imag;
	}
-	_mm256_store_ps((float *)d11_pack, ymm1);
-	if(is_eight)
+
+#else //BLIS_ENABLE_TRSM_PREINVERSION
+	// If Preinversion is disabled, pack the diagonal
+	// elements from A into the diagonal buffer.
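For contrast with the pre-inverted path above, here is a condensed scalar sketch of both packing modes side by side. The names are hypothetical and plain C99 complex arithmetic is used instead of the BLIS scomplex macros; the vectorized routine in this patch keeps the two modes behind the BLIS_ENABLE_TRSM_PREINVERSION guard, as the loop that follows shows.

    #include <complex.h>

    /* Illustrative sketch: pack the diagonal of a size-by-size block a11
     * (stride cs_a between columns) into d11_pack. */
    static void pack_diag_sketch(const float complex *a11, long cs_a,
                                 float complex *d11_pack, long size,
                                 int preinvert)
    {
        for (long i = 0; i < size; i++)
        {
            float complex aii = a11[i * cs_a + i];   /* a11[i][i]                 */
            /* preinversion on : store 1/aii, the solve multiplies by it          */
            /* preinversion off: store aii itself, the solve divides by it later  */
            d11_pack[i] = preinvert ? (1.0f / aii) : aii;
        }
    }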
+ for( dim_t i = 0; i < size; i++) { - _mm256_store_ps((float *)(d11_pack + 4), ymm2); + dim_t d = ((i*cs_a) + i); + bli_ccopys(a11[d],d11_pack[i]); } + +#endif //BLIS_ENABLE_TRSM_PREINVERSION } /** @@ -37294,26 +37517,19 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ ymm1 = _mm256_mul_ps(ymm1, ymm2);\ }\ - /*Negating imaginary component of numerator*/\ - ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ - ymm1 = _mm256_mul_ps(ymm1, ymm2);\ - /*BLIS_CTRSM_MUL(vec1)*/\ - /*BLIS_CTRSM_MUL(vec2)*/\ - /*vec1 * ymm1*/\ - ymm3 = _mm256_shuffle_ps(ymm1, ymm1, 0x11);\ - ymm2 = _mm256_shuffle_ps(vec1, vec1, 0xA0);\ - ymm16 = _mm256_shuffle_ps(vec1, vec1,0xF5);\ - ymm16 = _mm256_mul_ps(ymm16, ymm3);\ - vec1 = _mm256_fmaddsub_ps(ymm2, ymm1, ymm16);\ - /*vec1 * ymm1*/\ - ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ - ymm1 = _mm256_mul_ps(ymm1, ymm2);\ - /*Taking denomerator multiplication of real & imaginary components*/\ - ymm3 = _mm256_mul_ps(ymm1, ymm1);\ - ymm2 = _mm256_permute_ps(ymm3, 0xB1);\ - ymm3 = _mm256_add_ps(ymm2, ymm3);\ - /*Dividing numerator by denominator*/\ - vec1 = _mm256_div_ps(vec1, ymm3);\ + scomplex b_data[4];\ + scomplex d11_data[4];\ + \ + _mm256_storeu_ps((float *)(b_data), vec1);\ + _mm256_storeu_ps((float *)(d11_data), ymm1);\ + \ + for(dim_t i = 0; i < 4; i++)\ + {\ + bli_cinvscals(d11_data[0],b_data[i]);\ + }\ + \ + vec1 = _mm256_loadu_ps((float *)b_data);\ + \ }\ } @@ -37324,32 +37540,21 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ ymm1 = _mm256_mul_ps(ymm1, ymm2);\ }\ - /*Negating imaginary component of numerator*/\ - ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ - ymm1 = _mm256_mul_ps(ymm1, ymm2);\ - /*BLIS_CTRSM_MUL(vec1)*/\ - /*BLIS_CTRSM_MUL(vec2)*/\ - /*vec1 * ymm1*/\ - ymm3 = _mm256_shuffle_ps(ymm1, ymm1, 0x11);\ - ymm2 = _mm256_shuffle_ps(vec1, vec1, 0xA0);\ - ymm16 = _mm256_shuffle_ps(vec1, vec1,0xF5);\ - ymm16 = _mm256_mul_ps(ymm16, ymm3);\ - vec1 = _mm256_fmaddsub_ps(ymm2, ymm1, ymm16);\ - /*vec1 * ymm1*/\ - ymm2 = _mm256_shuffle_ps(vec2, vec2, 0xA0);\ - ymm16 = _mm256_shuffle_ps(vec2, vec2,0xF5);\ - ymm16 = _mm256_mul_ps(ymm16, ymm3);\ - vec2 = _mm256_fmaddsub_ps(ymm2, ymm1, ymm16);\ - /*done*/\ - ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ - ymm1 = _mm256_mul_ps(ymm1, ymm2);\ - /*Taking denomerator multiplication of real & imaginary components*/\ - ymm3 = _mm256_mul_ps(ymm1, ymm1);\ - ymm2 = _mm256_permute_ps(ymm3, 0xB1);\ - ymm3 = _mm256_add_ps(ymm2, ymm3);\ - /*Dividing numerator by denominator*/\ - vec1 = _mm256_div_ps(vec1, ymm3);\ - vec2 = _mm256_div_ps(vec2, ymm3);\ + scomplex b_data[8];\ + scomplex d11_data[4];\ + \ + _mm256_storeu_ps((float *)(b_data), vec1);\ + _mm256_storeu_ps((float *)(b_data + 4), vec2);\ + _mm256_storeu_ps((float *)(d11_data), ymm1);\ + \ + for(dim_t i = 0; i < 8; i++)\ + {\ + bli_cinvscals(d11_data[0],b_data[i]);\ + }\ + \ + vec1 = _mm256_loadu_ps((float *)b_data);\ + vec2 = _mm256_loadu_ps((float *)(b_data+4));\ + \ }\ } @@ -39983,43 +40188,13 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*3 + 3)); - ymm0 = 
_mm256_permute_ps(ymm0, 0x44); - ymm1 = _mm256_permute_ps(ymm1, 0x44); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - ymm3 = _mm256_permute_ps(ymm3, 0x44); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,m_rem); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*1 + 1)); - ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*2 + 2)); - ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*3 + 3)); - ymm0 = _mm256_permute_ps(ymm0, 0x44); - ymm1 = _mm256_permute_ps(ymm1, 0x44); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - ymm3 = _mm256_permute_ps(ymm3, 0x44); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,m_rem); } - - ymm1 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); - ymm2 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); - ymm1 = _mm256_blend_ps(ymm1, ymm2, 0xF0); - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif } - _mm256_storeu_ps((float *)(d11_pack), ymm1); for(j = 0; (j+d_nr-1) < n; j += d_nr) { @@ -42230,43 +42405,13 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*3 + 3)); - ymm0 = _mm256_permute_ps(ymm0, 0x44); - ymm1 = _mm256_permute_ps(ymm1, 0x44); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - ymm3 = _mm256_permute_ps(ymm3, 0x44); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,m_rem); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*1 + 1)); - ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*2 + 2)); - ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*3 + 3)); - ymm0 = _mm256_permute_ps(ymm0, 0x44); - ymm1 = _mm256_permute_ps(ymm1, 0x44); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - ymm3 = _mm256_permute_ps(ymm3, 0x44); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,m_rem); } - - ymm1 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); - ymm2 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); - ymm1 = _mm256_blend_ps(ymm1, ymm2, 0xF0); - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif } - _mm256_storeu_ps((float *)(d11_pack), ymm1); for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) { @@ -43822,30 +43967,13 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *) - (a11+cs_a*1 + 1)); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,n_rem); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *) - (a11+rs_a*1 + 1)); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,n_rem); } - ymm1 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); -#ifdef 
BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif } - _mm_store_ps((float *)(d11_pack), - _mm256_extractf128_ps(ymm1,0)); for(i = (m-d_mr); (i+1) > 0; i -= d_mr) { @@ -44301,25 +44429,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB } - - ymm1 = _mm256_broadcast_ps((__m128 const *)&ones); - ymm1 = _mm256_permute_ps(ymm1, 0x44); if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_blend_ps(ymm0, ymm1, 0xC0); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,n_rem); } - _mm_store_ps((float *)(d11_pack), - _mm256_extractf128_ps(ymm1,0)); for(i = (m-d_mr); (i+1) > 0; i -= d_mr) { @@ -44574,7 +44687,6 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB scomplex *a01, *a11, *b10, *b11; //pointers that point to blocks for GEMM and TRSM - scomplex ones = {1.0, 1.0}; bool is_unitdiag = bli_obj_has_unit_diag(a); //scratch registers @@ -45333,37 +45445,17 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB } } - - ymm1 = _mm256_broadcast_ps((__m128 const *)&ones); - ymm1 = _mm256_permute_ps(ymm1, 0x44); if(!is_unitdiag) { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *) - (a11+cs_a*1 + 1)); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,n_rem); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *) - (a11+rs_a*1 + 1)); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,n_rem); } - ymm1 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif } - _mm_store_ps((float *)(d11_pack), - _mm256_extractf128_ps(ymm1,0)); for(i = 0; (i+d_mr-1) < m; i += d_mr) { @@ -45828,25 +45920,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB } } - - ymm1 = _mm256_broadcast_ps((__m128 const *)&ones); - ymm1 = _mm256_permute_ps(ymm1, 0x44); if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_blend_ps(ymm0, ymm1, 0xC0); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,n_rem); } - _mm_store_ps((float *)(d11_pack), - _mm256_extractf128_ps(ymm1,0)); for(i = 0; (i+d_mr-1) < m; i += d_mr) { diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c index a21c9b5ed1..77f0348561 100644 --- 
a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c @@ -6,7 +6,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2021, Advanced Micro Devices, Inc. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -138,6 +138,8 @@ void bli_cgemmsup_rv_zen_asm_3x8n for (n_iter = 0; n_iter < n0 / 8; n_iter++) { // clear scratch registers. + xmm0 = _mm_setzero_ps(); + xmm3 = _mm_setzero_ps(); ymm4 = _mm256_setzero_ps(); ymm5 = _mm256_setzero_ps(); ymm6 = _mm256_setzero_ps(); @@ -572,6 +574,8 @@ void bli_cgemmsup_rv_zen_asm_2x8n for (n_iter = 0; n_iter < n0 / 8; n_iter++) { // clear scratch registers. + xmm0 = _mm_setzero_ps(); + xmm3 = _mm_setzero_ps(); ymm4 = _mm256_setzero_ps(); ymm5 = _mm256_setzero_ps(); ymm6 = _mm256_setzero_ps(); @@ -919,6 +923,8 @@ void bli_cgemmsup_rv_zen_asm_1x8n for (n_iter = 0; n_iter < n0 / 8; n_iter++) { // clear scratch registers. + xmm0 = _mm_setzero_ps(); + xmm3 = _mm_setzero_ps(); ymm4 = _mm256_setzero_ps(); ymm5 = _mm256_setzero_ps(); ymm6 = _mm256_setzero_ps(); diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 73104f817d..f2edd993ce 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -32,7 +32,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*/ - // -- level-1m -- PACKM_KER_PROT(double, d, packm_8xk_gen_zen) PACKM_KER_PROT(double, d, packm_6xk_gen_zen) @@ -46,6 +45,16 @@ PACKM_KER_PROT(double, d, packm_6xk_nn_zen) AMAXV_KER_PROT( float, s, amaxv_zen_int ) AMAXV_KER_PROT( double, d, amaxv_zen_int ) +// axpbyv (intrinsics) +AXPBYV_KER_PROT( float, s, axpbyv_zen_int ) +AXPBYV_KER_PROT( double, d, axpbyv_zen_int ) +AXPBYV_KER_PROT( scomplex, c, axpbyv_zen_int ) +AXPBYV_KER_PROT( dcomplex, z, axpbyv_zen_int ) + +// axpbyv (intrinsics, unrolled x10) +AXPBYV_KER_PROT( float, s, axpbyv_zen_int10 ) +AXPBYV_KER_PROT( double, d, axpbyv_zen_int10 ) + // axpyv (intrinsics) AXPYV_KER_PROT( float, s, axpyv_zen_int ) AXPYV_KER_PROT( double, d, axpyv_zen_int ) @@ -69,6 +78,8 @@ DOTV_KER_PROT( dcomplex, z, dotv_zen_int5 ) // dotxv (intrinsics) DOTXV_KER_PROT( float, s, dotxv_zen_int ) DOTXV_KER_PROT( double, d, dotxv_zen_int ) +DOTXV_KER_PROT( dcomplex, z, dotxv_zen_int ) +DOTXV_KER_PROT( scomplex, c, dotxv_zen_int ) // scalv (intrinsics) SCALV_KER_PROT( float, s, scalv_zen_int ) @@ -104,10 +115,21 @@ AXPYF_KER_PROT( scomplex, c, axpyf_zen_int_5 ) AXPYF_KER_PROT( scomplex, c, axpyf_zen_int_4 ) AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_5 ) AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_4 ) +// axpy2v (intrinsics) +AXPY2V_KER_PROT(double, d, axpy2v_zen_int ) +AXPY2V_KER_PROT(dcomplex, z, axpy2v_zen_int ) // dotxf (intrinsics) DOTXF_KER_PROT( float, s, dotxf_zen_int_8 ) DOTXF_KER_PROT( double, d, dotxf_zen_int_8 ) +DOTXF_KER_PROT( double, d, dotxf_zen_int_4 ) +DOTXF_KER_PROT( double, d, dotxf_zen_int_2 ) +DOTXF_KER_PROT( dcomplex, z, dotxf_zen_int_6 ) +DOTXF_KER_PROT( scomplex, c, dotxf_zen_int_6 ) +// dotxaxpyf (intrinsics) +DOTXAXPYF_KER_PROT( double, d, dotxaxpyf_zen_int_8 ) +DOTXAXPYF_KER_PROT( scomplex, c, dotxaxpyf_zen_int_8 ) +DOTXAXPYF_KER_PROT( dcomplex, z, dotxaxpyf_zen_int_8 ) // -- level-2 ---------------------------------------------------------------- @@ -241,6 +263,28 @@ err_t bli_dgemm_small_At cntl_t* cntl ); +err_t bli_zgemm_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); + +err_t bli_zgemm_small_At + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); + // gemm square matrix size friendly implementation err_t bli_gemm_sqp ( @@ -265,7 +309,7 @@ void bli_dgemm_ref_k1_nn double* c, const inc_t ldc ); - err_t bli_trsm_small +err_t bli_trsm_small ( side_t side, obj_t* alpha, @@ -275,6 +319,34 @@ void bli_dgemm_ref_k1_nn cntl_t* cntl ); +#ifdef BLIS_ENABLE_OPENMP +err_t bli_trsm_small_mt + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); + +void bli_multi_sgemv_4x2 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + float* restrict alpha, + float* restrict a, inc_t inca, inc_t lda, + float* restrict x, inc_t incx, + float* restrict beta, + float* restrict y, inc_t incy, + cntx_t* restrict cntx, + dim_t n_threads + ); + +#endif + // threshold functions bool bli_cntx_gemmtsup_thresh_is_met_zen ( @@ -301,3 +373,4 @@ void bli_dnorm2fv_unb_var1 cntx_t* cntx ); #endif + diff --git a/kernels/zen/util/bli_thresh_funcs_zen.c b/kernels/zen/util/bli_thresh_funcs_zen.c index 1b5fc86998..2786f00e43 100644 --- a/kernels/zen/util/bli_thresh_funcs_zen.c +++ b/kernels/zen/util/bli_thresh_funcs_zen.c @@ -37,16 +37,31 @@ // -- gemmt specific function bool bli_cntx_gemmtsup_thresh_is_met_zen( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx ) { - num_t dt = bli_obj_dt( c ); + num_t dt = 
bli_obj_dt( c ); + dim_t n = bli_obj_length( c ); + dim_t k = bli_obj_width_after_trans( a ); + rntm_t rntm; - dim_t n = bli_obj_length( c ); - dim_t k = bli_obj_width_after_trans( a ); + bli_rntm_init_from_global( &rntm ); + + // Query the number of threads from rntm object. + const dim_t n_threads = bli_rntm_num_threads( &rntm ); if( bli_is_double( dt )) { - if ( n < 300 ) return TRUE; - if ( (k / n ) > 50 ) return TRUE; - + if( n_threads == 16) + { + /*Push sizes for n<1200 into SUP path*/ + if ( n < 1200 ) return TRUE; + /*For 12005 , With packing , Native path performance is better */ + if ( n < 1600 && (n / k) < 5) return TRUE; + } + else + { + if ( n < 800 ) return TRUE; + if ( (k / n ) > 50 ) return TRUE; + } return FALSE; } else if ( bli_is_dcomplex( dt ) ) diff --git a/so_version b/so_version index a831c0e579..8efd5969fe 100644 --- a/so_version +++ b/so_version @@ -1,2 +1,2 @@ 3 -1.0 +2.0 diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3b0315c9ae..d116e942d0 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,172 +1,172 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## add_definitions(-DBLAS="AOCL") add_executable(TestAminv test_aminv.c) target_link_libraries(TestAminv debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestAminv "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestAminv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestAminv optimized "${LIB_NAME}.lib") add_executable(TestAxpyv test_axpyv.c) target_link_libraries(TestAxpyv debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestAxpyv "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestAxpyv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestAxpyv optimized "${LIB_NAME}.lib") add_executable(TestAxpbyv test_axpbyv.c) target_link_libraries(TestAxpbyv debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestAxpbyv "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestAxpbyv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestAxpbyv optimized "${LIB_NAME}.lib") add_executable(TestCopyv test_copyv.c) target_link_libraries(TestCopyv debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestCopyv "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestCopyv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestCopyv optimized "${LIB_NAME}.lib") add_executable(TestCabs1 test_cabs1.c) target_link_libraries(TestCabs1 debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestCabs1 "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestCabs1 OpenMP::OpenMP_CXX) endif() target_link_libraries(TestCabs1 optimized "${LIB_NAME}.lib") add_executable(TestDotv test_dotv.c) target_link_libraries(TestDotv debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestDotv "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestDotv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestDotv optimized "${LIB_NAME}.lib") add_executable(TestGemm test_gemm.c) target_link_libraries(TestGemm debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestGemm "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestGemm OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGemm optimized "${LIB_NAME}.lib") add_executable(TestGemmBatch test_gemm_batch.c) target_link_libraries(TestGemmBatch debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestGemmBatch "${OPENMP_PATH}/libomp.lib") + 
target_link_libraries(TestGemmBatch OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGemmBatch optimized "${LIB_NAME}.lib") add_executable(TestGemm3m test_gemm3m.c) target_link_libraries(TestGemm3m debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestGemm3m "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestGemm3m OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGemm3m optimized "${LIB_NAME}.lib") add_executable(TestGemmt test_gemmt.c) target_link_libraries(TestGemmt debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestGemmt "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestGemmt OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGemmt optimized "${LIB_NAME}.lib") add_executable(TestGemv test_gemv.c) target_link_libraries(TestGemv debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestGemv "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestGemv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGemv optimized "${LIB_NAME}.lib") add_executable(TestGer test_ger.c) target_link_libraries(TestGer debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestGer "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestGer OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGer optimized "${LIB_NAME}.lib") add_executable(TestHemm test_hemm.c) target_link_libraries(TestHemm debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestHemm "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestHemm OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHemm optimized "${LIB_NAME}.lib") add_executable(TestHemv test_hemv.c) target_link_libraries(TestHemv debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestHemv "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestHemv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHemv optimized "${LIB_NAME}.lib") add_executable(TestHer test_her.c) target_link_libraries(TestHer debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestHer "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestHer OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHer optimized "${LIB_NAME}.lib") add_executable(TestHer2 test_her2.c) target_link_libraries(TestHer2 debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestHer2 "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestHer2 OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHer2 optimized "${LIB_NAME}.lib") add_executable(TestHer2k test_her2k.c) target_link_libraries(TestHer2k debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestHer2k "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestHer2k OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHer2k optimized "${LIB_NAME}.lib") add_executable(TestHerk test_herk.c) target_link_libraries(TestHerk debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestHerk "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestHerk OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHerk optimized "${LIB_NAME}.lib") add_executable(TestScalv test_scalv.c) target_link_libraries(TestScalv debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestScalv "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestScalv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestScalv optimized "${LIB_NAME}.lib") add_executable(TestSwapv test_swapv.c) target_link_libraries(TestSwapv debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestSwapv 
"${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestSwapv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestSwapv optimized "${LIB_NAME}.lib") add_executable(TestTrmm test_trmm.c) target_link_libraries(TestTrmm debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestTrmm "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestTrmm OpenMP::OpenMP_CXX) endif() target_link_libraries(TestTrmm optimized "${LIB_NAME}.lib") add_executable(TestTrmv test_trmv.c) target_link_libraries(TestTrmv debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestTrmv "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestTrmv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestTrmv optimized "${LIB_NAME}.lib") add_executable(TestTrsm test_trsm.c) target_link_libraries(TestTrsm debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestTrsm "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestTrsm OpenMP::OpenMP_CXX) endif() target_link_libraries(TestTrsm optimized "${LIB_NAME}.lib") add_executable(TestTrsv test_trsv.c) target_link_libraries(TestTrsv debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(TestTrsv "${OPENMP_PATH}/libomp.lib") + target_link_libraries(TestTrsv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestTrsv optimized "${LIB_NAME}.lib") diff --git a/test/test_gemm.c b/test/test_gemm.c index 772d73c7b1..81b7e36616 100644 --- a/test/test_gemm.c +++ b/test/test_gemm.c @@ -5,19 +5,19 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -47,428 +47,471 @@ // uncomment to enable cblas interface //#define CBLAS -int main( int argc, char** argv ) +// Uncomment to enable progress printing. 
+//#define PROGRESS_ENABLED + +#ifdef PROGRESS_ENABLED +dim_t AOCL_progress(char *api, + dim_t lapi, + dim_t progress, + dim_t current_thread, + dim_t total_threads) +{ + printf("\n%s, len = %ld, nt = %ld, tid = %ld, Processed %ld Elements", + api, lapi, total_threads, current_thread, progress); + + return 0; +} +#endif + +int main(int argc, char **argv) { - obj_t a, b, c; - obj_t c_save; - obj_t alpha, beta; - dim_t m, n, k; - inc_t lda, ldb, ldc; - num_t dt, dt_a; - inc_t r, n_repeats; - trans_t transa; - trans_t transb; - f77_char f77_transa; - f77_char f77_transb; - - double dtime; - double dtime_save; - double gflops; - - //bli_init(); - //bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); - - n_repeats = 300; - - //dt = BLIS_FLOAT; - dt = BLIS_DOUBLE; - //dt = BLIS_SCOMPLEX; - //dt = BLIS_DCOMPLEX; - - if( bli_is_real( dt ) || bli_is_scomplex( dt ) ) + obj_t a, b, c; + obj_t c_save; + obj_t alpha, beta; + dim_t m, n, k; + inc_t lda, ldb, ldc; + num_t dt, dt_a; + inc_t r, n_repeats; + trans_t transa; + trans_t transb; + f77_char f77_transa; + f77_char f77_transb; + + double dtime; + double dtime_save; + double gflops; + +#ifdef PROGRESS_ENABLED + AOCL_BLIS_set_progress(AOCL_progress); +#endif + + // bli_init(); + // bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); + + n_repeats = 300; + + // dt = BLIS_FLOAT; + dt = BLIS_DOUBLE; + // dt = BLIS_SCOMPLEX; + // dt = BLIS_DCOMPLEX; + + if (bli_is_real(dt) || bli_is_scomplex(dt)) dt_a = dt; else { dt_a = dt; // Enable the following to call - // dzgemm - //dt_a = BLIS_DOUBLE; + // dzgemm + // dt_a = BLIS_DOUBLE; } const char stor_scheme = 'C'; - transa = BLIS_NO_TRANSPOSE; - transb = BLIS_NO_TRANSPOSE; - - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); - bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); + transa = BLIS_NO_TRANSPOSE; + transb = BLIS_NO_TRANSPOSE; + bli_param_map_blis_to_netlib_trans(transa, &f77_transa); + bli_param_map_blis_to_netlib_trans(transb, &f77_transb); printf("BLIS Library version is : %s\n", bli_info_get_version_str()); #ifdef FILE_IN_OUT - FILE* fin = NULL; - FILE* fout = NULL; - if (argc < 3){ - printf("Usage: ./test_gemm_XX.x input.csv output.csv\n"); - exit(1); - } - fin = fopen(argv[1], "r"); - if (fin == NULL){ - printf("Error opening the file %s\n", argv[1]); - exit(1); - } - fout = fopen(argv[2], "w"); - if (fout == NULL){ - printf("Error opening output file %s\n", argv[2]); - exit(1); - } - fprintf(fout, "m\t k\t n\t cs_a\t cs_b\t cs_c\t gflops\n"); - printf("~~~~~~~~~~_BLAS\t m\t k\t n\t cs_a\t cs_b\t cs_c \t gflops\n"); - - while (fscanf(fin, "%lld %lld %lld %lld %lld %lld\n", &m, &k, &n, &lda, &ldb, &ldc) == 6) - { - // dimensions should not be greater than leading dimensions - // These are valid only when Op(A) = n and op(B) = n - if( (stor_scheme == 'C') || (stor_scheme == 'c') ) { - if ((m > lda) || (k > ldb) || (m > ldc)) continue; - }else if( (stor_scheme == 'R') || (stor_scheme == 'r') ) { - // leading dimension should be greater than number of cols - if ((k > lda) || (n > ldb) || (n > ldc)) continue; - }else { - printf("Invalid Storage type\n"); - continue; - } + FILE *fin = NULL; + FILE *fout = NULL; + if (argc < 3) + { + printf("Usage: ./test_gemm_XX.x input.csv output.csv\n"); + exit(1); + } + fin = fopen(argv[1], "r"); + if (fin == NULL) + { + printf("Error opening the file %s\n", argv[1]); + exit(1); + } + fout = fopen(argv[2], "w"); + if (fout == NULL) + { + printf("Error opening output file %s\n", argv[2]); + exit(1); + } + fprintf(fout, "m\t k\t n\t cs_a\t 
cs_b\t cs_c\t gflops\n"); + printf("~~~~~~~~~~_BLAS\t m\t k\t n\t cs_a\t cs_b\t cs_c \t gflops\n"); + + while (fscanf(fin, "%ld %ld %ld %ld %ld %ld\n", &m, &k, &n, &lda, &ldb, &ldc) == 6) + { + // dimensions should not be greater than leading dimensions + // These are valid only when Op(A) = n and op(B) = n + if ((stor_scheme == 'C') || (stor_scheme == 'c')) + { + if ((m > lda) || (k > ldb) || (m > ldc)) + continue; + } + else if ((stor_scheme == 'R') || (stor_scheme == 'r')) + { + // leading dimension should be greater than number of cols + if ((k > lda) || (n > ldb) || (n > ldc)) + continue; + } + else + { + printf("Invalid Storage type\n"); + continue; + } #else - dim_t p, p_begin, p_end, p_inc; - dim_t m_input, n_input, k_input; - p_begin = 200; - p_end = 2000; - p_inc = 200; - - m_input = n_input = k_input = -1; - for ( p = p_begin; p <= p_end; p += p_inc ) - { - if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); - else m = ( dim_t ) m_input; - if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); - else n = ( dim_t ) n_input; - if ( k_input < 0 ) k = p * ( dim_t )abs(k_input); - else k = ( dim_t ) k_input; - - if( (stor_scheme == 'C') || (stor_scheme == 'c') ) { - lda = m; ldb = k, ldc = m; - }else if( (stor_scheme == 'R') || (stor_scheme == 'r') ) { - lda = k; ldb = n, ldc = n; - } + dim_t p, p_begin, p_end, p_inc; + dim_t m_input, n_input, k_input; + p_begin = 200; + p_end = 2000; + p_inc = 200; + + m_input = n_input = k_input = -1; + for (p = p_begin; p <= p_end; p += p_inc) + { + if (m_input < 0) + m = p * (dim_t)labs(m_input); + else + m = (dim_t)m_input; + if (n_input < 0) + n = p * (dim_t)labs(n_input); + else + n = (dim_t)n_input; + if (k_input < 0) + k = p * (dim_t)labs(k_input); + else + k = (dim_t)k_input; + + if ((stor_scheme == 'C') || (stor_scheme == 'c')) + { + lda = m; + ldb = k, ldc = m; + } + else if ((stor_scheme == 'R') || (stor_scheme == 'r')) + { + lda = k; + ldb = n, ldc = n; + } #endif - bli_obj_create( dt, 1, 1, 0, 0, &alpha); - bli_obj_create( dt, 1, 1, 0, 0, &beta ); - - siz_t elem_size = bli_dt_size( dt ); - - lda = bli_align_dim_to_size( lda, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE ); - ldb = bli_align_dim_to_size( ldb, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE ); - ldc = bli_align_dim_to_size( ldc, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE ); - - // Will verify the leading dimension is powers of 2 and add 64bytes. - inc_t n_bytes = lda*sizeof(dt_a); - - if((n_bytes!=0) && !(n_bytes&(n_bytes-1)))// check whether n_bytes is power of 2. - lda += BLIS_SIMD_ALIGN_SIZE/sizeof(dt_a); - - n_bytes = ldb*sizeof(dt); - if((n_bytes!=0) && !(n_bytes&(n_bytes-1)))// check whether n_bytes is power of 2. - ldb += BLIS_SIMD_ALIGN_SIZE/sizeof(dt); - - n_bytes = ldc*sizeof(dt); - if((n_bytes!=0) && !(n_bytes&(n_bytes-1)))// check whether n_bytes is power of 2. 
- ldc += BLIS_SIMD_ALIGN_SIZE/sizeof(dt); - - if( (stor_scheme == 'C') || (stor_scheme == 'c') ) - { - // Col-major Order - bli_obj_create( dt_a, m, k, 1, lda, &a ); - bli_obj_create( dt, k, n, 1, ldb, &b ); - bli_obj_create( dt, m, n, 1, ldc, &c ); - bli_obj_create( dt, m, n, 1, ldc, &c_save ); - } - else if( (stor_scheme == 'R') || (stor_scheme == 'r') ) - { - // Row-major Order - bli_obj_create( dt_a, m, k, lda, 1, &a ); - bli_obj_create( dt, k, n, ldb, 1, &b ); - bli_obj_create( dt, m, n, ldc, 1, &c ); - bli_obj_create( dt, m, n, ldc, 1, &c_save ); - } - else - { - printf("Invalid Storage type\n"); - continue; - } + bli_obj_create(dt, 1, 1, 0, 0, &alpha); + bli_obj_create(dt, 1, 1, 0, 0, &beta); + + siz_t elem_size = bli_dt_size(dt); + + lda = bli_align_dim_to_size(lda, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE); + ldb = bli_align_dim_to_size(ldb, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE); + ldc = bli_align_dim_to_size(ldc, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE); + + // Will verify the leading dimension is powers of 2 and add 64bytes. + inc_t n_bytes = lda * sizeof(dt_a); + + if ((n_bytes != 0) && !(n_bytes & (n_bytes - 1))) // check whether n_bytes is power of 2. + lda += BLIS_SIMD_ALIGN_SIZE / sizeof(dt_a); + + n_bytes = ldb * sizeof(dt); + if ((n_bytes != 0) && !(n_bytes & (n_bytes - 1))) // check whether n_bytes is power of 2. + ldb += BLIS_SIMD_ALIGN_SIZE / sizeof(dt); + + n_bytes = ldc * sizeof(dt); + if ((n_bytes != 0) && !(n_bytes & (n_bytes - 1))) // check whether n_bytes is power of 2. + ldc += BLIS_SIMD_ALIGN_SIZE / sizeof(dt); + + if ((stor_scheme == 'C') || (stor_scheme == 'c')) + { + // Col-major Order + bli_obj_create(dt_a, m, k, 1, lda, &a); + bli_obj_create(dt, k, n, 1, ldb, &b); + bli_obj_create(dt, m, n, 1, ldc, &c); + bli_obj_create(dt, m, n, 1, ldc, &c_save); + } + else if ((stor_scheme == 'R') || (stor_scheme == 'r')) + { + // Row-major Order + bli_obj_create(dt_a, m, k, lda, 1, &a); + bli_obj_create(dt, k, n, ldb, 1, &b); + bli_obj_create(dt, m, n, ldc, 1, &c); + bli_obj_create(dt, m, n, ldc, 1, &c_save); + } + else + { + printf("Invalid Storage type\n"); + continue; + } #ifdef MATRIX_INITIALISATION - bli_randm( &a ); - bli_randm( &b ); - bli_randm( &c ); + bli_randm(&a); + bli_randm(&b); + bli_randm(&c); #endif - bli_obj_set_conjtrans( transa, &a); - bli_obj_set_conjtrans( transb, &b); - bli_setsc( (0.9/1.0), 0.2, &alpha ); - bli_setsc( -(1.1/1.0), 0.3, &beta ); - - bli_copym( &c, &c_save ); - dtime_save = DBL_MAX; - for ( r = 0; r < n_repeats; ++r ) - { - bli_copym( &c_save, &c ); - dtime = bli_clock(); + bli_obj_set_conjtrans(transa, &a); + bli_obj_set_conjtrans(transb, &b); + bli_setsc((0.9 / 1.0), 0.2, &alpha); + bli_setsc(-(1.1 / 1.0), 0.3, &beta); + + bli_copym(&c, &c_save); + dtime_save = DBL_MAX; + for (r = 0; r < n_repeats; ++r) + { + bli_copym(&c_save, &c); + dtime = bli_clock(); #ifdef BLIS - bli_gemm( &alpha, - &a, - &b, - &beta, - &c ); + bli_gemm(&alpha, + &a, + &b, + &beta, + &c); #else - f77_int lda, ldb, ldc; - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int nn = bli_obj_width( &c ); + f77_int lda, ldb, ldc; + f77_int mm = bli_obj_length(&c); + f77_int kk = bli_obj_width_after_trans(&a); + f77_int nn = bli_obj_width(&c); #ifdef CBLAS - enum CBLAS_ORDER cblas_order; - enum CBLAS_TRANSPOSE cblas_transa; - enum CBLAS_TRANSPOSE cblas_transb; - - if ( bli_obj_row_stride( &c ) == 1 ){ - cblas_order = CblasColMajor; - }else{ - cblas_order = CblasRowMajor; - } - - if( bli_is_trans( transa ) ) - cblas_transa = 
CblasTrans; - else if( bli_is_conjtrans( transa ) ) - cblas_transa = CblasConjTrans; - else - cblas_transa = CblasNoTrans; - - if( bli_is_trans( transb ) ) - cblas_transb = CblasTrans; - else if( bli_is_conjtrans( transb ) ) - cblas_transb = CblasConjTrans; - else - cblas_transb = CblasNoTrans; + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_transa; + enum CBLAS_TRANSPOSE cblas_transb; + + if (bli_obj_row_stride(&c) == 1) + { + cblas_order = CblasColMajor; + } + else + { + cblas_order = CblasRowMajor; + } + + if (bli_is_trans(transa)) + cblas_transa = CblasTrans; + else if (bli_is_conjtrans(transa)) + cblas_transa = CblasConjTrans; + else + cblas_transa = CblasNoTrans; + + if (bli_is_trans(transb)) + cblas_transb = CblasTrans; + else if (bli_is_conjtrans(transb)) + cblas_transb = CblasConjTrans; + else + cblas_transb = CblasNoTrans; #else - f77_char f77_transa; - f77_char f77_transb; - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); - bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); + f77_char f77_transa; + f77_char f77_transb; + bli_param_map_blis_to_netlib_trans(transa, &f77_transa); + bli_param_map_blis_to_netlib_trans(transb, &f77_transb); #endif - if( (stor_scheme == 'C') || (stor_scheme == 'c') ){ - lda = bli_obj_col_stride( &a ); - ldb = bli_obj_col_stride( &b ); - ldc = bli_obj_col_stride( &c ); - } else { - lda = bli_obj_row_stride( &a ); - ldb = bli_obj_row_stride( &b ); - ldc = bli_obj_row_stride( &c ); - } - - if ( bli_is_float( dt ) ) - { - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* bp = bli_obj_buffer( &b ); - float* betap = bli_obj_buffer( &beta ); - float* cp = bli_obj_buffer( &c ); + if ((stor_scheme == 'C') || (stor_scheme == 'c')) + { + lda = bli_obj_col_stride(&a); + ldb = bli_obj_col_stride(&b); + ldc = bli_obj_col_stride(&c); + } + else + { + lda = bli_obj_row_stride(&a); + ldb = bli_obj_row_stride(&b); + ldc = bli_obj_row_stride(&c); + } + + if (bli_is_float(dt)) + { + float *alphap = bli_obj_buffer(&alpha); + float *ap = bli_obj_buffer(&a); + float *bp = bli_obj_buffer(&b); + float *betap = bli_obj_buffer(&beta); + float *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_sgemm( cblas_order, - cblas_transa, - cblas_transb, - mm, - nn, - kk, - *alphap, - ap, lda, - bp, ldb, - *betap, - cp, ldc - ); + cblas_sgemm(cblas_order, + cblas_transa, + cblas_transb, + mm, + nn, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc); #else - sgemm_( &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, (f77_int*)&lda, - bp, (f77_int*)&ldb, - betap, - cp, (f77_int*)&ldc ); + sgemm_(&f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, (f77_int *)&lda, + bp, (f77_int *)&ldb, + betap, + cp, (f77_int *)&ldc); #endif - }else if ( bli_is_double( dt ) ) - { - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* bp = bli_obj_buffer( &b ); - double* betap = bli_obj_buffer( &beta ); - double* cp = bli_obj_buffer( &c ); + } + else if (bli_is_double(dt)) + { + double *alphap = bli_obj_buffer(&alpha); + double *ap = bli_obj_buffer(&a); + double *bp = bli_obj_buffer(&b); + double *betap = bli_obj_buffer(&beta); + double *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_dgemm( cblas_order, - cblas_transa, - cblas_transb, - mm, - nn, - kk, - *alphap, - ap, lda, - bp, ldb, - *betap, - cp, ldc - ); + cblas_dgemm(cblas_order, + cblas_transa, + cblas_transb, + mm, + nn, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc); #else - dgemm_( &f77_transa, - 
&f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, (f77_int*)&lda, - bp, (f77_int*)&ldb, - betap, - cp, (f77_int*)&ldc ); + dgemm_(&f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, (f77_int *)&lda, + bp, (f77_int *)&ldb, + betap, + cp, (f77_int *)&ldc); #endif - }else if ( bli_is_scomplex( dt ) ) - { - scomplex* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - scomplex* bp = bli_obj_buffer( &b ); - scomplex* betap = bli_obj_buffer( &beta ); - scomplex* cp = bli_obj_buffer( &c ); + } + else if (bli_is_scomplex(dt)) + { + scomplex *alphap = bli_obj_buffer(&alpha); + scomplex *ap = bli_obj_buffer(&a); + scomplex *bp = bli_obj_buffer(&b); + scomplex *betap = bli_obj_buffer(&beta); + scomplex *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_cgemm( cblas_order, - cblas_transa, - cblas_transb, - mm, - nn, - kk, - alphap, - ap, lda, - bp, ldb, - betap, - cp, ldc - ); + cblas_cgemm(cblas_order, + cblas_transa, + cblas_transb, + mm, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc); #else - cgemm_( &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, (f77_int*)&lda, - bp, (f77_int*)&ldb, - betap, - cp, (f77_int*)&ldc ); + cgemm_(&f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, (f77_int *)&lda, + bp, (f77_int *)&ldb, + betap, + cp, (f77_int *)&ldc); #endif - }else if ( bli_is_dcomplex( dt ) ) - { - dcomplex* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* bp = bli_obj_buffer( &b ); - dcomplex* betap = bli_obj_buffer( &beta ); - dcomplex* cp = bli_obj_buffer( &c ); + } + else if (bli_is_dcomplex(dt)) + { + dcomplex *alphap = bli_obj_buffer(&alpha); + dcomplex *ap = bli_obj_buffer(&a); + dcomplex *bp = bli_obj_buffer(&b); + dcomplex *betap = bli_obj_buffer(&beta); + dcomplex *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_zgemm( cblas_order, - cblas_transa, - cblas_transb, - mm, - nn, - kk, - alphap, - ap, lda, - bp, ldb, - betap, - cp, ldc - ); + cblas_zgemm(cblas_order, + cblas_transa, + cblas_transb, + mm, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc); #else -//Disabled dzgemm function temporarily. 
-#if 0 - if( bli_is_double( dt_a ) ) - { - dzgemm_( - &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - (double*)ap, (f77_int*)&lda, - bp, (f77_int*)&ldb, - betap, - cp, (f77_int*)&ldc - ); - } - else - { -#else - zgemm_( &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, (f77_int*)&lda, - bp, (f77_int*)&ldb, - betap, - cp, (f77_int*)&ldc ); -// } +#if 1 + if (bli_is_double(dt_a)) + { + dzgemm_( + &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + (double *)ap, (f77_int *)&lda, + bp, (f77_int *)&ldb, + betap, + cp, (f77_int *)&ldc); + } + else + { + zgemm_(&f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, (f77_int *)&lda, + bp, (f77_int *)&ldb, + betap, + cp, (f77_int *)&ldc); + } #endif #endif - } + } #endif #ifdef PRINT - bli_printm( "a", &a, "%4.1f", "" ); - bli_printm( "b", &b, "%4.1f", "" ); - bli_printm( "c", &c, "%4.1f", "" ); - bli_printm( "c after", &c, "%4.1f", "" ); - exit(1); + bli_printm("a", &a, "%4.1f", ""); + bli_printm("b", &b, "%4.1f", ""); + bli_printm("c", &c, "%4.1f", ""); + bli_printm("c after", &c, "%4.1f", ""); + exit(1); #endif - dtime_save = bli_clock_min_diff( dtime_save, dtime ); - }//nrepeats + dtime_save = bli_clock_min_diff(dtime_save, dtime); + } // nrepeats - gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 ); - if (bli_is_dcomplex(dt) && (bli_is_double(dt_a))) - gflops *= 2.0; - else if ( bli_is_complex( dt ) ) gflops *= 4.0; + gflops = (2.0 * m * k * n) / (dtime_save * 1.0e9); + if (bli_is_dcomplex(dt) && (bli_is_double(dt_a))) + gflops *= 2.0; + else if (bli_is_complex(dt)) + gflops *= 4.0; #ifdef BLIS - printf("data_gemm_blis" ); + printf("data_gemm_blis"); #else - printf("data_gemm_%s", BLAS ); + printf("data_gemm_%s", BLAS); #endif - #ifdef FILE_IN_OUT - printf("%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", \ - ( unsigned long )m,( unsigned long )k,( unsigned long )n, - (unsigned long)lda,(unsigned long)ldb,(unsigned long)ldc,gflops); + printf("%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", + (unsigned long)m, (unsigned long)k, (unsigned long)n, + (unsigned long)lda, (unsigned long)ldb, (unsigned long)ldc, gflops); - fprintf(fout, "%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", \ - ( unsigned long )m,( unsigned long )k,( unsigned long )n, - (unsigned long)lda,(unsigned long)ldb,(unsigned long)ldc,gflops); - fflush(fout); + fprintf(fout, "%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", + (unsigned long)m, (unsigned long)k, (unsigned long)n, + (unsigned long)lda, (unsigned long)ldb, (unsigned long)ldc, gflops); + fflush(fout); #else - printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin)/p_inc + 1, - ( unsigned long )m,( unsigned long )k, - ( unsigned long )n, gflops ); + printf("( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + (unsigned long)(p - p_begin) / p_inc + 1, + (unsigned long)m, (unsigned long)k, + (unsigned long)n, gflops); #endif - bli_obj_free( &alpha ); - bli_obj_free( &beta ); + bli_obj_free(&alpha); + bli_obj_free(&beta); - bli_obj_free( &a ); - bli_obj_free( &b ); - bli_obj_free( &c ); - bli_obj_free( &c_save ); - }//while + bli_obj_free(&a); + bli_obj_free(&b); + bli_obj_free(&c); + bli_obj_free(&c_save); + } // while - //bli_finalize(); + // bli_finalize(); #ifdef FILE_IN_OUT - fclose(fin); - fclose(fout); + fclose(fin); + fclose(fout); #endif - return 0; + return 0; } diff --git a/test/test_trsm.c b/test/test_trsm.c index 72156d92fe..f6709f5d7f 100644 --- a/test/test_trsm.c +++ b/test/test_trsm.c @@ -5,19 
+5,19 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2021, Advanced Micro Devices, Inc. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -50,14 +50,31 @@ #define CACHE_LINE_SIZE 64 -int main( int argc, char** argv ) +// Uncomment to enable progress printing. +//#define PROGRESS_ENABLED + +#ifdef PROGRESS_ENABLED +dim_t AOCL_progress(char *api, + dim_t lapi, + dim_t progress, + dim_t current_thread, + dim_t total_threads) +{ + printf("\n%s, len = %ld, nt = %ld, tid = %ld, Processed %ld Elements", + api, lapi, total_threads, current_thread, progress); + + return 0; +} +#endif + +int main(int argc, char **argv) { obj_t a, c; obj_t c_save; obj_t alpha; dim_t m, n; num_t dt; - int r, n_repeats; + int r, n_repeats; side_t side; uplo_t uploa; trans_t transa; @@ -72,16 +89,20 @@ int main( int argc, char** argv ) double gflops; #ifdef FILE_IN_OUT - FILE* fin = NULL; - FILE* fout = NULL; + FILE *fin = NULL; + FILE *fout = NULL; #else dim_t p; dim_t p_begin, p_end, p_inc; - int m_input, n_input; + int m_input, n_input; - //bli_init(); +#ifdef PROGRESS_ENABLED + AOCL_BLIS_set_progress(AOCL_progress); +#endif + + // bli_init(); - //bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); + // bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); #ifndef PRINT p_begin = 200; @@ -102,26 +123,26 @@ int main( int argc, char** argv ) n_repeats = 3; - //dt = BLIS_FLOAT; + // dt = BLIS_FLOAT; dt = BLIS_DOUBLE; - //dt = BLIS_SCOMPLEX; - //dt = BLIS_DCOMPLEX; + // dt = BLIS_SCOMPLEX; + // dt = BLIS_DCOMPLEX; #ifdef FILE_IN_OUT - if(argc < 3) + if (argc < 3) { printf("Usage: ./test_trsm_XX.x input.csv output.csv\n"); exit(1); } fin = fopen(argv[1], "r"); - if(fin == NULL) + if (fin == NULL) { printf("Error opening the file %s\n", argv[1]); exit(1); } fout = fopen(argv[2], "w"); - if(fout == NULL) + if (fout == NULL) { printf("Error opening the file %s\n", argv[2]); exit(1); @@ -130,425 +151,421 @@ int main( int argc, char** argv ) inc_t cs_b; #ifdef READ_ALL_PARAMS_FROM_FILE char side_c, uploa_c, transa_c, diaga_c; - + fprintf(fout, "side, uploa, transa, diaga, m\t n\t cs_a\t cs_b\t gflops\n"); printf("~~~~~~~_BLAS\t 
side, uploa, transa, diaga, m\t n\t cs_a\t cs_b\t gflops\n"); - while(fscanf(fin, "%c %c %c %c %ld %ld %ld %ld\n", &side_c, &uploa_c, &transa_c, &diaga_c, &m, &n, &cs_a, &cs_b) == 8) + while (fscanf(fin, "%c %c %c %c %ld %ld %ld %ld\n", &side_c, &uploa_c, &transa_c, &diaga_c, &m, &n, &cs_a, &cs_b) == 8) { - if( 'l' == side_c|| 'L' == side_c) - side = BLIS_LEFT; - else if('r' == side_c || 'R' == side_c) - side = BLIS_RIGHT; - else - { - printf("Invalid entry for the argument 'side':%c\n",side_c); - continue; - } + if ('l' == side_c || 'L' == side_c) + side = BLIS_LEFT; + else if ('r' == side_c || 'R' == side_c) + side = BLIS_RIGHT; + else + { + printf("Invalid entry for the argument 'side':%c\n", side_c); + continue; + } - if('l' == uploa_c || 'L' == uploa_c) - uploa = BLIS_LOWER; - else if('u' == uploa_c || 'U' == uploa_c) - uploa = BLIS_UPPER; - else - { - printf("Invalid entry for the argument 'uplo':%c\n",uploa_c); - continue; - } + if ('l' == uploa_c || 'L' == uploa_c) + uploa = BLIS_LOWER; + else if ('u' == uploa_c || 'U' == uploa_c) + uploa = BLIS_UPPER; + else + { + printf("Invalid entry for the argument 'uplo':%c\n", uploa_c); + continue; + } - if('t' == transa_c || 'T' == transa_c) - transa = BLIS_TRANSPOSE; - else if('n' == transa_c || 'N' == transa_c) - transa = BLIS_NO_TRANSPOSE; - else - { - printf("Invalid entry for the argument 'transa':%c\n",transa_c); - continue; - } - - if('u' == diaga_c || 'U' == diaga_c) - diaga = BLIS_UNIT_DIAG; - else if('n' == diaga_c || 'N' == diaga_c) - diaga = BLIS_NONUNIT_DIAG; - else - { - printf("Invalid entry for the argument 'diaga':%c\n", diaga_c); - continue; - } + if ('t' == transa_c || 'T' == transa_c) + transa = BLIS_TRANSPOSE; + else if ('n' == transa_c || 'N' == transa_c) + transa = BLIS_NO_TRANSPOSE; + else + { + printf("Invalid entry for the argument 'transa':%c\n", transa_c); + continue; + } + + if ('u' == diaga_c || 'U' == diaga_c) + diaga = BLIS_UNIT_DIAG; + else if ('n' == diaga_c || 'N' == diaga_c) + diaga = BLIS_NONUNIT_DIAG; + else + { + printf("Invalid entry for the argument 'diaga':%c\n", diaga_c); + continue; + } #else - + fprintf(fout, "m\t n\t cs_a\t cs_b\t gflops\n"); printf("~~~~~~~_BLAS\t m\t n\t cs_a\t cs_b\t gflops\n"); - while(fscanf(fin, "%ld %ld %ld %ld\n", &m, &n, &cs_a, &cs_b) == 4) + while (fscanf(fin, "%ld %ld %ld %ld\n", &m, &n, &cs_a, &cs_b) == 4) { - - side = BLIS_LEFT; - //side = BLIS_RIGHT; - uploa = BLIS_LOWER; - //uploa = BLIS_UPPER; + side = BLIS_LEFT; + // side = BLIS_RIGHT; - transa = BLIS_NO_TRANSPOSE; + uploa = BLIS_LOWER; + // uploa = BLIS_UPPER; - diaga = BLIS_NONUNIT_DIAG; + transa = BLIS_NO_TRANSPOSE; + diaga = BLIS_NONUNIT_DIAG; #endif - bli_param_map_blis_to_netlib_side( side, &f77_side ); - bli_param_map_blis_to_netlib_uplo( uploa, &f77_uploa ); - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); - bli_param_map_blis_to_netlib_diag( diaga, &f77_diaga ); + bli_param_map_blis_to_netlib_side(side, &f77_side); + bli_param_map_blis_to_netlib_uplo(uploa, &f77_uploa); + bli_param_map_blis_to_netlib_trans(transa, &f77_transa); + bli_param_map_blis_to_netlib_diag(diaga, &f77_diaga); + siz_t elem_size = bli_dt_size(dt); - siz_t elem_size = bli_dt_size( dt ); + cs_a = bli_align_dim_to_size(cs_a, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE); + cs_b = bli_align_dim_to_size(cs_b, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE); - cs_a = bli_align_dim_to_size( cs_a, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE ); - cs_b = bli_align_dim_to_size( cs_b, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE ); + // Will verify the 
leading dimension is a power of 2 and, if it is, add 64 bytes (one cache line) of padding. + inc_t n_bytes = cs_a * sizeof(dt); - //Will verify the leading dimension is powers of 2 and add 64bytes. - inc_t n_bytes = cs_a*sizeof(dt); + if ((n_bytes != 0) && !(n_bytes & (n_bytes - 1))) // check whether n_bytes is a power of 2. + cs_a += CACHE_LINE_SIZE / sizeof(dt); - if((n_bytes!=0) && !(n_bytes&(n_bytes-1)))// check whether n_bytes is power of 2. - cs_a += CACHE_LINE_SIZE/sizeof(dt); + n_bytes = cs_b * sizeof(dt); + if ((n_bytes != 0) && !(n_bytes & (n_bytes - 1))) // check whether n_bytes is a power of 2. + cs_b += CACHE_LINE_SIZE / sizeof(dt); - n_bytes = cs_b*sizeof(dt); - if((n_bytes!=0) && !(n_bytes&(n_bytes-1)))// check whether n_bytes is power of 2. - cs_b += CACHE_LINE_SIZE/sizeof(dt); + if (bli_is_left(side) && ((m > cs_a) || (m > cs_b))) + continue; // skip if a leading dimension is smaller than the corresponding number of rows + if (bli_is_right(side) && ((n > cs_a) || (m > cs_b))) + continue; // skip if a leading dimension is smaller than the corresponding number of rows - if(bli_is_left(side) && ((m > cs_a) || (m > cs_b))) continue; //leading dimension should be greater than number of rows - - if(bli_is_right(side) && ((n > cs_a) || (m > cs_b))) continue; //leading dimension should be greater than number of rows - - if ( bli_is_left( side ) ) - bli_obj_create( dt, m, m, 1, m, &a ); + if (bli_is_left(side)) + bli_obj_create(dt, m, m, 1, m, &a); else - bli_obj_create( dt, n, n, 1, n, &a ); - bli_obj_create( dt, m, n, 1, m, &c ); - bli_obj_create( dt, m, n, 1, m, &c_save ); + bli_obj_create(dt, n, n, 1, n, &a); + bli_obj_create(dt, m, n, 1, m, &c); + bli_obj_create(dt, m, n, 1, m, &c_save); #else - for ( p = p_end; p >= p_begin; p -= p_inc ) + for (p = p_end; p >= p_begin; p -= p_inc) { - if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); - else m = ( dim_t ) m_input; - if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); - else n = ( dim_t ) n_input; + if (m_input < 0) + m = p * (dim_t)abs(m_input); + else + m = (dim_t)m_input; + if (n_input < 0) + n = p * (dim_t)abs(n_input); + else + n = (dim_t)n_input; - - side = BLIS_LEFT; - //side = BLIS_RIGHT; + side = BLIS_LEFT; + // side = BLIS_RIGHT; - uploa = BLIS_LOWER; - //uploa = BLIS_UPPER; + uploa = BLIS_LOWER; + // uploa = BLIS_UPPER; - transa = BLIS_NO_TRANSPOSE; + transa = BLIS_NO_TRANSPOSE; - diaga = BLIS_NONUNIT_DIAG; + diaga = BLIS_NONUNIT_DIAG; - bli_param_map_blis_to_netlib_side( side, &f77_side ); - bli_param_map_blis_to_netlib_uplo( uploa, &f77_uploa ); - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); - bli_param_map_blis_to_netlib_diag( diaga, &f77_diaga ); + bli_param_map_blis_to_netlib_side(side, &f77_side); + bli_param_map_blis_to_netlib_uplo(uploa, &f77_uploa); + bli_param_map_blis_to_netlib_trans(transa, &f77_transa); + bli_param_map_blis_to_netlib_diag(diaga, &f77_diaga); - if ( bli_is_left( side ) ) - bli_obj_create( dt, m, m, 0, 0, &a ); + if (bli_is_left(side)) + bli_obj_create(dt, m, m, 0, 0, &a); else - bli_obj_create( dt, n, n, 0, 0, &a ); - bli_obj_create( dt, m, n, 0, 0, &c ); - bli_obj_create( dt, m, n, 0, 0, &c_save ); + bli_obj_create(dt, n, n, 0, 0, &a); + bli_obj_create(dt, m, n, 0, 0, &c); + bli_obj_create(dt, m, n, 0, 0, &c_save); #endif - bli_randm( &a ); - bli_randm( &c ); + bli_randm(&a); + bli_randm(&c); - bli_obj_set_struc( BLIS_TRIANGULAR, &a ); - bli_obj_set_uplo( uploa, &a ); - bli_obj_set_conjtrans( transa, &a ); - bli_obj_set_diag( diaga, &a ); + bli_obj_set_struc(BLIS_TRIANGULAR, &a); + bli_obj_set_uplo(uploa, &a); + bli_obj_set_conjtrans(transa, &a); + 
bli_obj_set_diag(diaga, &a); // Randomize A and zero the unstored triangle to ensure the // implementation reads only from the stored region. - bli_randm( &a ); - bli_mktrim( &a ); + bli_randm(&a); + bli_mktrim(&a); // Load the diagonal of A to make it more likely to be invertible. - bli_shiftd( &BLIS_TWO, &a ); + bli_shiftd(&BLIS_TWO, &a); - bli_obj_create( dt, 1, 1, 0, 0, &alpha ); - bli_setsc( (2.0/1.0), 1.0, &alpha ); + bli_obj_create(dt, 1, 1, 0, 0, &alpha); + bli_setsc((2.0 / 1.0), 1.0, &alpha); + bli_copym(&c, &c_save); - bli_copym( &c, &c_save ); - dtime_save = DBL_MAX; - for ( r = 0; r < n_repeats; ++r ) + for (r = 0; r < n_repeats; ++r) { - bli_copym( &c_save, &c ); - + bli_copym(&c_save, &c); dtime = bli_clock(); - #ifdef PRINT - bli_invertd( &a ); - bli_printm( "a", &a, "%4.1f", "" ); - bli_invertd( &a ); - bli_printm( "c", &c, "%4.1f", "" ); + bli_invertd(&a); + bli_printm("a", &a, "%4.1f", ""); + bli_invertd(&a); + bli_printm("c", &c, "%4.1f", ""); #endif #ifdef BLIS - bli_trsm( side, - &alpha, - &a, - &c ); + bli_trsm(side, + &alpha, + &a, + &c); #else #ifdef CBLAS - enum CBLAS_ORDER cblas_order; - enum CBLAS_TRANSPOSE cblas_transa; - enum CBLAS_UPLO cblas_uplo; - enum CBLAS_SIDE cblas_side; - enum CBLAS_DIAG cblas_diag; - - if ( bli_obj_row_stride( &c ) == 1 ) - cblas_order = CblasColMajor; - else - cblas_order = CblasRowMajor; - - if( bli_is_trans( transa ) ) - cblas_transa = CblasTrans; - else if( bli_is_conjtrans( transa ) ) - cblas_transa = CblasConjTrans; - else - cblas_transa = CblasNoTrans; - - if(bli_is_upper(uploa)) - cblas_uplo = CblasUpper; - else - cblas_uplo = CblasLower; - - if(bli_is_left(side)) - cblas_side = CblasLeft; - else - cblas_side = CblasRight; - - if(bli_is_unit_diag(diaga)) - cblas_diag = CblasUnit; - else - cblas_diag = CblasNonUnit; + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_transa; + enum CBLAS_UPLO cblas_uplo; + enum CBLAS_SIDE cblas_side; + enum CBLAS_DIAG cblas_diag; + + if (bli_obj_row_stride(&c) == 1) + cblas_order = CblasColMajor; + else + cblas_order = CblasRowMajor; + + if (bli_is_trans(transa)) + cblas_transa = CblasTrans; + else if (bli_is_conjtrans(transa)) + cblas_transa = CblasConjTrans; + else + cblas_transa = CblasNoTrans; + + if (bli_is_upper(uploa)) + cblas_uplo = CblasUpper; + else + cblas_uplo = CblasLower; + + if (bli_is_left(side)) + cblas_side = CblasLeft; + else + cblas_side = CblasRight; + + if (bli_is_unit_diag(diaga)) + cblas_diag = CblasUnit; + else + cblas_diag = CblasNonUnit; #else - f77_char f77_transa; - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); + f77_char f77_transa; + bli_param_map_blis_to_netlib_trans(transa, &f77_transa); #endif - if ( bli_is_float( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); + if (bli_is_float(dt)) + { + f77_int mm = bli_obj_length(&c); + f77_int nn = bli_obj_width(&c); + f77_int lda = bli_obj_col_stride(&a); + f77_int ldc = bli_obj_col_stride(&c); - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* cp = bli_obj_buffer( &c ); + float *alphap = bli_obj_buffer(&alpha); + float *ap = bli_obj_buffer(&a); + float *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_strsm( cblas_order, - cblas_side, - cblas_uplo, - cblas_transa, - cblas_diag, - mm, - nn, - *alphap, - ap, lda, - cp, ldc - ); + cblas_strsm(cblas_order, + cblas_side, + cblas_uplo, + cblas_transa, + cblas_diag, + mm, + nn, + *alphap, + ap, lda, + 
cp, ldc); #else - strsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &nn, - alphap, - ap, &lda, - cp, &ldc ); + strsm_(&f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc); #endif - } - else if ( bli_is_double( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* cp = bli_obj_buffer( &c ); + } + else if (bli_is_double(dt)) + { + f77_int mm = bli_obj_length(&c); + f77_int nn = bli_obj_width(&c); + f77_int lda = bli_obj_col_stride(&a); + f77_int ldc = bli_obj_col_stride(&c); + double *alphap = bli_obj_buffer(&alpha); + double *ap = bli_obj_buffer(&a); + double *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_dtrsm( cblas_order, - cblas_side, - cblas_uplo, - cblas_transa, - cblas_diag, - mm, - nn, - *alphap, - ap, lda, - cp, ldc - ); -#else - dtrsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &nn, - alphap, - ap, &lda, - cp, &ldc ); + cblas_dtrsm(cblas_order, + cblas_side, + cblas_uplo, + cblas_transa, + cblas_diag, + mm, + nn, + *alphap, + ap, lda, + cp, ldc); +#else + dtrsm_(&f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc); #endif - - } - else if ( bli_is_scomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - scomplex* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - scomplex* cp = bli_obj_buffer( &c ); + } + else if (bli_is_scomplex(dt)) + { + f77_int mm = bli_obj_length(&c); + f77_int nn = bli_obj_width(&c); + f77_int lda = bli_obj_col_stride(&a); + f77_int ldc = bli_obj_col_stride(&c); + scomplex *alphap = bli_obj_buffer(&alpha); + scomplex *ap = bli_obj_buffer(&a); + scomplex *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_ctrsm( cblas_order, - cblas_side, - cblas_uplo, - cblas_transa, - cblas_diag, - mm, - nn, - alphap, - ap, lda, - cp, ldc - ); + cblas_ctrsm(cblas_order, + cblas_side, + cblas_uplo, + cblas_transa, + cblas_diag, + mm, + nn, + alphap, + ap, lda, + cp, ldc); #else - ctrsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &nn, - alphap, - ap, &lda, - cp, &ldc ); + ctrsm_(&f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc); #endif - } - else if ( bli_is_dcomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - dcomplex* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* cp = bli_obj_buffer( &c ); + } + else if (bli_is_dcomplex(dt)) + { + f77_int mm = bli_obj_length(&c); + f77_int nn = bli_obj_width(&c); + f77_int lda = bli_obj_col_stride(&a); + f77_int ldc = bli_obj_col_stride(&c); + dcomplex *alphap = bli_obj_buffer(&alpha); + dcomplex *ap = bli_obj_buffer(&a); + dcomplex *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_ztrsm( cblas_order, - cblas_side, - cblas_uplo, - cblas_transa, - cblas_diag, - mm, - nn, - alphap, - ap, lda, - cp, ldc - ); + cblas_ztrsm(cblas_order, + cblas_side, + cblas_uplo, + cblas_transa, + cblas_diag, + mm, + nn, + alphap, + ap, lda, + cp, ldc); #else - ztrsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &nn, - alphap, - 
ap, &lda, - cp, &ldc ); + ztrsm_(&f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc); #endif - }else{ - printf("Invalid data type! Exiting!\n"); - exit(1); - } + } + else + { + printf("Invalid data type! Exiting!\n"); + exit(1); + } #endif - dtime_save = bli_clock_min_diff( dtime_save, dtime ); + dtime_save = bli_clock_min_diff(dtime_save, dtime); } - if ( bli_is_left( side ) ) - gflops = ( 1.0 * m * m * n ) / ( dtime_save * 1.0e9 ); + if (bli_is_left(side)) + gflops = (1.0 * m * m * n) / (dtime_save * 1.0e9); else - gflops = ( 1.0 * m * n * n ) / ( dtime_save * 1.0e9 ); + gflops = (1.0 * m * n * n) / (dtime_save * 1.0e9); - if ( bli_is_complex( dt ) ) gflops *= 4.0; + if (bli_is_complex(dt)) + gflops *= 4.0; #ifdef BLIS - printf( "data_trsm_blis" ); + printf("data_trsm_blis"); #else - printf( "data_trsm_%s", BLAS ); + printf("data_trsm_%s", BLAS); #endif #ifdef FILE_IN_OUT #ifdef READ_ALL_PARAMS_FROM_FILE - printf("%c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n",side_c, uploa_c, transa_c, diaga_c, - (unsigned long )m, (unsigned long ) n, - (unsigned long )cs_a, (unsigned long )cs_b, - gflops); + printf("%c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", side_c, uploa_c, transa_c, diaga_c, + (unsigned long)m, (unsigned long)n, + (unsigned long)cs_a, (unsigned long)cs_b, + gflops); - fprintf(fout,"%c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", side_c, uploa_c, transa_c, diaga_c, - (unsigned long )m, (unsigned long ) n, - (unsigned long )cs_a, (unsigned long )cs_b, - gflops); + fprintf(fout, "%c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", side_c, uploa_c, transa_c, diaga_c, + (unsigned long)m, (unsigned long)n, + (unsigned long)cs_a, (unsigned long)cs_b, + gflops); #else - printf("%4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", (unsigned long )m, (unsigned long ) n, - (unsigned long )cs_a, (unsigned long )cs_b, - gflops); - fprintf(fout,"%4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", (unsigned long )m, (unsigned long ) n, - (unsigned long )cs_a, (unsigned long )cs_b, - gflops); + printf("%4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", (unsigned long)m, (unsigned long)n, + (unsigned long)cs_a, (unsigned long)cs_b, + gflops); + fprintf(fout, "%4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", (unsigned long)m, (unsigned long)n, + (unsigned long)cs_a, (unsigned long)cs_b, + gflops); #endif -fflush(fout); + fflush(fout); #else - printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin)/p_inc + 1, - ( unsigned long )m, - ( unsigned long )n, gflops ); + printf("( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", + (unsigned long)(p - p_begin) / p_inc + 1, + (unsigned long)m, + (unsigned long)n, gflops); #endif - bli_obj_free( &alpha ); + bli_obj_free(&alpha); - bli_obj_free( &a ); - bli_obj_free( &c ); - bli_obj_free( &c_save ); + bli_obj_free(&a); + bli_obj_free(&c); + bli_obj_free(&c_save); } #ifdef FILE_IN_OUT - fclose(fin); - fclose(fout); + fclose(fin); + fclose(fout); #endif - //bli_finalize(); + // bli_finalize(); return 0; } - diff --git a/testsuite/CMakeLists.txt b/testsuite/CMakeLists.txt index f03d094782..85866926dd 100644 --- a/testsuite/CMakeLists.txt +++ b/testsuite/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. 
All rights reserved.## include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) @@ -8,7 +8,7 @@ add_subdirectory(src) target_link_libraries(test_libblis debug "${LIB_NAME}.lib") if(ENABLE_OPENMP) - target_link_libraries(test_libblis "${OPENMP_PATH}/libomp.lib") + target_link_libraries(test_libblis OpenMP::OpenMP_CXX) endif() target_link_libraries(test_libblis optimized "${LIB_NAME}.lib") diff --git a/version b/version index 0c6173b5f1..944880fa15 100644 --- a/version +++ b/version @@ -1,2 +1 @@ -3.1.0 - +3.2.0
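
For reference, both test_gemm.c and test_trsm.c above apply the same leading-dimension padding rule before creating their matrices: if the leading dimension, measured in bytes, is an exact power of two, one cache line's worth of elements is added to it. The minimal standalone C sketch below restates that rule in isolation; the helper name pad_leading_dim, the plain long/size_t types, and the hard-coded 64-byte pad are illustrative assumptions, whereas the drivers themselves use BLIS types (inc_t, siz_t) and express the pad as BLIS_SIMD_ALIGN_SIZE / sizeof(dt) in test_gemm.c and CACHE_LINE_SIZE / sizeof(dt) in test_trsm.c. The usual motivation for such padding is to keep power-of-two column strides from aliasing in the cache, though the patch itself does not state a reason.

#include <stdio.h>
#include <stddef.h>

/* One cache line, matching the 64-byte pad used by the test drivers above. */
#define PAD_BYTES 64

/* If ld * elem_size is an exact power of two, grow ld by one cache line's
   worth of elements; otherwise leave it unchanged. */
static long pad_leading_dim(long ld, size_t elem_size)
{
    size_t n_bytes = (size_t)ld * elem_size;

    /* A non-zero value is a power of two iff it has exactly one bit set. */
    if ((n_bytes != 0) && ((n_bytes & (n_bytes - 1)) == 0))
        ld += (long)(PAD_BYTES / elem_size);

    return ld;
}

int main(void)
{
    /* 512 doubles = 4096 bytes (a power of two): padded to 520. */
    printf("%ld\n", pad_leading_dim(512, sizeof(double)));
    /* 500 doubles = 4000 bytes (not a power of two): left at 500. */
    printf("%ld\n", pad_leading_dim(500, sizeof(double)));
    return 0;
}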