diff --git a/CMakeLists.txt b/CMakeLists.txt index e10f259c..fccfbac2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -455,9 +455,8 @@ else() endif() if (NOT LAPACK_FOUND) - message( "For [cz]rot, [cz]syr, and [cz]symv, BLAS++ requires a LAPACK library and none was found." - " Ensure that it is accessible in environment variables" - " $CPATH, $LIBRARY_PATH, and $LD_LIBRARY_PATH." ) + message( FATAL_ERROR + "BLAS++ requires LAPACK for [cz]rot, [cz]syr, [cz]symv." ) endif() # BLAS++ doesn't need LAPACKConfig.cmake, which checks version, XBLAS, LAPACKE. diff --git a/cmake/BLASFinder.cmake b/cmake/BLASFinder.cmake index 29871c73..0c05cca2 100644 --- a/cmake/BLASFinder.cmake +++ b/cmake/BLASFinder.cmake @@ -84,21 +84,10 @@ endfunction() # Setup. #---------------------------------------- compiler -if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") - set( gnu_compiler true ) -endif() - -if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "IntelLLVM") - set( intelllvm_compiler true ) -endif() - -if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel") - set( intel_compiler true ) -endif() - -if ("${CMAKE_CXX_COMPILER_ID}" MATCHES "XL|XLClang") - set( ibm_compiler true ) -endif() +string( COMPARE EQUAL "${CMAKE_CXX_COMPILER_ID}" "GNU" gnu_compiler) +string( COMPARE EQUAL "${CMAKE_CXX_COMPILER_ID}" "IntelLLVM" intelllvm_compiler ) +string( COMPARE EQUAL "${CMAKE_CXX_COMPILER_ID}" "Intel" intel_compiler ) +string( REGEX MATCH "XL|XLClang" ibm_compiler "${CMAKE_CXX_COMPILER_ID}" ) #---------------------------------------- Fortran manglings to test if (ibm_compiler) @@ -136,37 +125,13 @@ endif() #---------------------------------------- blas string( TOLOWER "${blas}" blas_ ) -if ("${blas_}" MATCHES "auto") - set( test_all true ) -endif() - -if ("${blas_}" MATCHES "acml") - set( test_acml true ) -endif() - -if ("${blas_}" MATCHES "apple|accelerate") - set( test_accelerate true ) -endif() - -if ("${blas_}" MATCHES "cray|libsci|default") - set( test_default true ) -endif() - -if ("${blas_}" MATCHES "ibm|essl") - set( test_essl true ) -endif() - -if ("${blas_}" MATCHES "intel|mkl") - set( test_mkl true ) -endif() - -if ("${blas_}" MATCHES "openblas") - set( test_openblas true ) -endif() - -if ("${blas_}" MATCHES "generic") - set( test_generic true ) -endif() +string( REGEX MATCH "auto|acml" test_acml "${blas_}" ) +string( REGEX MATCH "auto|ibm|essl" test_essl "${blas_}" ) +string( REGEX MATCH "auto|intel|mkl" test_mkl "${blas_}" ) +string( REGEX MATCH "auto|openblas" test_openblas "${blas_}" ) +string( REGEX MATCH "auto|generic" test_generic "${blas_}" ) +string( REGEX MATCH "auto|apple|accelerate" test_accelerate "${blas_}" ) +string( REGEX MATCH "auto|cray|libsci|default" test_default "${blas_}" ) message( DEBUG " BLAS_LIBRARIES = '${BLAS_LIBRARIES}' @@ -179,22 +144,13 @@ test_default = '${test_default}' test_essl = '${test_essl}' test_mkl = '${test_mkl}' test_openblas = '${test_openblas}' -test_generic = '${test_generic}' -test_all = '${test_all}'") +test_generic = '${test_generic}'" ) #---------------------------------------- blas_fortran string( TOLOWER "${blas_fortran}" blas_fortran_ ) -if ("${blas_fortran_}" MATCHES "gfortran") - set( test_gfortran true ) -endif() -if ("${blas_fortran_}" MATCHES "ifort") - set( test_ifort true ) -endif() -if ("${blas_fortran_}" MATCHES "auto") - set( test_gfortran true ) - set( test_ifort true ) -endif() +string( REGEX MATCH "auto|gfortran" test_gfortran "${blas_fortran_}" ) +string( REGEX MATCH "auto|ifort" test_ifort "${blas_fortran_}" ) message( DEBUG " blas_fortran = '${blas_fortran}' @@ -206,16 +162,10 @@ test_ifort = '${test_ifort}'") string( TOLOWER "${blas_int}" blas_int_ ) # This regex is similar to "\b(lp64|int)\b". -if ("${blas_int_}" MATCHES "(^|[^a-zA-Z0-9_])(lp64|int|int32|int32_t)($|[^a-zA-Z0-9_])") - set( test_int true ) -endif() -if ("${blas_int_}" MATCHES "(^|[^a-zA-Z0-9_])(ilp64|int64|int64_t)($|[^a-zA-Z0-9_])") - set( test_int64 true ) -endif() -if ("${blas_int_}" MATCHES "auto") - set( test_int true ) - set( test_int64 true ) -endif() +set( regex_int32 "(^|[^a-zA-Z0-9_])(auto|lp64|int|int32|int32_t)($|[^a-zA-Z0-9_])" ) +set( regex_int64 "(^|[^a-zA-Z0-9_])(auto|ilp64|int64|int64_t)($|[^a-zA-Z0-9_])" ) +string( REGEX MATCH ${regex_int32} test_int "${blas_int_}" ) +string( REGEX MATCH ${regex_int64} test_int64 "${blas_int_}" ) if (CMAKE_CROSSCOMPILING AND test_int AND test_int64) message( FATAL_ERROR " ${red}When cross-compiling, one must define either\n" @@ -232,17 +182,11 @@ test_int64 = '${test_int64}'") #---------------------------------------- blas_threaded string( TOLOWER "${blas_threaded}" blas_threaded_ ) -# This regex is similar to "\b(yes|...)\b". -if ("${blas_threaded_}" MATCHES "(^|[^a-zA-Z0-9_])(y|yes|true|on|1)($|[^a-zA-Z0-9_])") - set( test_threaded true ) -endif() -if ("${blas_threaded_}" MATCHES "(^|[^a-zA-Z0-9_])(n|no|false|off|0)($|[^a-zA-Z0-9_])") - set( test_sequential true ) -endif() -if ("${blas_threaded_}" MATCHES "auto") - set( test_threaded true ) - set( test_sequential true ) -endif() +# These regex are similar to "\b(yes|...)\b". +set( regex_yes "(^|[^a-zA-Z0-9_])(auto|y|yes|true|on|1)($|[^a-zA-Z0-9_])" ) +set( regex_no "(^|[^a-zA-Z0-9_])(auto|n|no|false|off|0)($|[^a-zA-Z0-9_])" ) +string( REGEX MATCH ${regex_yes} test_threaded "${blas_threaded_}" ) +string( REGEX MATCH ${regex_no} test_sequential "${blas_threaded_}" ) message( DEBUG " blas_threaded = '${blas_threaded}' @@ -270,14 +214,14 @@ if (test_blas_libraries) endif() #---------------------------------------- default; Cray libsci -if (test_all OR test_default) +if (test_default) list( APPEND blas_name_list "default (no library)" ) list( APPEND blas_libs_list " " ) # Use space so APPEND works later. debug_print_list( "default" ) endif() #---------------------------------------- Intel MKL -if (test_all OR test_mkl) +if (test_mkl) # todo: MKL_?(ROOT|DIR) if (test_threaded AND OpenMP_CXX_FOUND) if (test_gfortran AND gnu_compiler) @@ -366,7 +310,7 @@ if (test_all OR test_mkl) endif() # MKL #---------------------------------------- IBM ESSL -if (test_all OR test_essl) +if (test_essl) # todo: ESSL_?(ROOT|DIR) if (test_threaded) #message( "essl OpenMP_CXX_FOUND ${OpenMP_CXX_FOUND}" ) @@ -411,7 +355,7 @@ if (test_all OR test_essl) endif() #---------------------------------------- OpenBLAS -if (test_all OR test_openblas) +if (test_openblas) # todo: OPENBLAS_?(ROOT|DIR) list( APPEND blas_name_list "OpenBLAS" ) list( APPEND blas_libs_list "-lopenblas" ) @@ -419,14 +363,14 @@ if (test_all OR test_openblas) endif() #---------------------------------------- Apple Accelerate -if (test_all OR test_accelerate) +if (test_accelerate) list( APPEND blas_name_list "Apple Accelerate" ) list( APPEND blas_libs_list "-framework Accelerate" ) debug_print_list( "accelerate" ) endif() #---------------------------------------- generic -lblas -if (test_all OR test_generic) +if (test_generic) list( APPEND blas_name_list "generic" ) list( APPEND blas_libs_list "-lblas" ) debug_print_list( "generic" ) @@ -434,7 +378,7 @@ endif() #---------------------------------------- AMD ACML # Deprecated libraries last. -if (test_all OR test_acml) +if (test_acml) # todo: ACML_?(ROOT|DIR) if (test_threaded) list( APPEND blas_name_list "AMD ACML threaded" ) diff --git a/cmake/LAPACKFinder.cmake b/cmake/LAPACKFinder.cmake index 4f384243..a0394897 100644 --- a/cmake/LAPACKFinder.cmake +++ b/cmake/LAPACKFinder.cmake @@ -53,17 +53,8 @@ endif() #---------------------------------------- lapack string( TOLOWER "${lapack}" lapack_ ) -if ("${lapack_}" MATCHES "auto") - set( test_all true ) -endif() - -if ("${lapack_}" MATCHES "default") - set( test_default true ) -endif() - -if ("${lapack_}" MATCHES "generic") - set( test_generic true ) -endif() +string( REGEX MATCH "auto|default" test_default "${lapack_}" ) +string( REGEX MATCH "auto|generic" test_generic "${lapack_}" ) message( DEBUG " LAPACK_LIBRARIES = '${LAPACK_LIBRARIES}' @@ -71,8 +62,7 @@ lapack = '${lapack}' lapack_ = '${lapack_}' test_lapack_libraries = '${test_lapack_libraries}' test_default = '${test_default}' -test_generic = '${test_generic}' -test_all = '${test_all}'") +test_generic = '${test_generic}'" ) #------------------------------------------------------------------------------- # Build list of libraries to check. @@ -91,12 +81,12 @@ if (test_lapack_libraries) endif() #---------------------------------------- default (in BLAS library) -if (test_all OR test_default) +if (test_default) list( APPEND lapack_libs_list " " ) endif() #---------------------------------------- generic -llapack -if (test_all OR test_generic) +if (test_generic) list( APPEND lapack_libs_list "-llapack" ) endif() @@ -104,10 +94,10 @@ message( DEBUG "lapack_libs_list ${lapack_libs_list}" ) #------------------------------------------------------------------------------- # Check each LAPACK library. -# BLAS++ needs only a limited subset of LAPACK, so check for potrf (Cholesky). -# LAPACK++ checks for pstrf (Cholesky with pivoting) to make sure it is +# Checks for pstrf (Cholesky with pivoting) to make sure it is # a complete LAPACK library, since some BLAS libraries (ESSL, ATLAS) # contain only an optimized subset of LAPACK routines. +# ESSL lacks [cz]symv, [cz]syr. unset( LAPACK_FOUND CACHE ) unset( lapackpp_defs_ CACHE ) @@ -124,7 +114,7 @@ foreach (lapack_libs IN LISTS lapack_libs_list) try_run( run_result compile_result ${CMAKE_CURRENT_BINARY_DIR} SOURCES - "${CMAKE_CURRENT_SOURCE_DIR}/config/lapack_potrf.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/config/lapack_pstrf.cc" LINK_LIBRARIES # Use blaspp_libraries instead of blaspp, when SLATE includes # blaspp and lapackpp, so the blaspp library doesn't exist yet. @@ -139,11 +129,11 @@ foreach (lapack_libs IN LISTS lapack_libs_list) ) # For cross-compiling, if it links, assume the run is okay. if (CMAKE_CROSSCOMPILING AND compile_result) - message( DEBUG "cross: lapack_potrf" ) + message( DEBUG "cross: lapack_pstrf" ) set( run_result "0" CACHE STRING "" FORCE ) set( run_output "ok" CACHE STRING "" FORCE ) endif() - debug_try_run( "lapack_potrf.cc" "${compile_result}" "${compile_output}" + debug_try_run( "lapack_pstrf.cc" "${compile_result}" "${compile_output}" "${run_result}" "${run_output}" ) if (NOT compile_result) diff --git a/configure.py b/configure.py index 00ba0900..e821b70b 100755 --- a/configure.py +++ b/configure.py @@ -73,8 +73,9 @@ def main(): try: config.lapack.lapack() - except Error: - print_warn( 'BLAS++ needs LAPACK for testers.' ) + except Error as ex: + print_warn( 'BLAS++ requires LAPACK for [cz]rot, [cz]syr, [cz]symv.' ) + raise( ex ) config.gpu_blas() diff --git a/include/blas/fortran.h b/include/blas/fortran.h index 89973302..076fb490 100644 --- a/include/blas/fortran.h +++ b/include/blas/fortran.h @@ -533,46 +533,45 @@ void BLAS_dsymv_base( #endif ); -// [cz]symv moved to LAPACK++ since they are provided by LAPACK. -// #define BLAS_csymv_base BLAS_FORTRAN_NAME( csymv, CSYMV ) -// void BLAS_csymv_base( -// char const *uplo, -// blas_int const *n, -// blas_complex_float const *alpha, -// blas_complex_float const *A, blas_int const *lda, -// blas_complex_float const *x, blas_int const *incx, -// blas_complex_float const *beta, -// blas_complex_float *y, blas_int const *incy -// #ifdef BLAS_FORTRAN_STRLEN_END -// , size_t uplo_len -// #endif -// ); -// -// #define BLAS_zsymv_base BLAS_FORTRAN_NAME( zsymv, ZSYMV ) -// void BLAS_zsymv_base( -// char const *uplo, -// blas_int const *n, -// blas_complex_double const *alpha, -// blas_complex_double const *A, blas_int const *lda, -// blas_complex_double const *x, blas_int const *incx, -// blas_complex_double const *beta, -// blas_complex_double *y, blas_int const *incy -// #ifdef BLAS_FORTRAN_STRLEN_END -// , size_t uplo_len -// #endif -// ); +#define BLAS_csymv_base BLAS_FORTRAN_NAME( csymv, CSYMV ) +void BLAS_csymv_base( + char const *uplo, + blas_int const *n, + blas_complex_float const *alpha, + blas_complex_float const *A, blas_int const *lda, + blas_complex_float const *x, blas_int const *incx, + blas_complex_float const *beta, + blas_complex_float *y, blas_int const *incy + #ifdef BLAS_FORTRAN_STRLEN_END + , size_t uplo_len + #endif + ); + +#define BLAS_zsymv_base BLAS_FORTRAN_NAME( zsymv, ZSYMV ) +void BLAS_zsymv_base( + char const *uplo, + blas_int const *n, + blas_complex_double const *alpha, + blas_complex_double const *A, blas_int const *lda, + blas_complex_double const *x, blas_int const *incx, + blas_complex_double const *beta, + blas_complex_double *y, blas_int const *incy + #ifdef BLAS_FORTRAN_STRLEN_END + , size_t uplo_len + #endif + ); #ifdef BLAS_FORTRAN_STRLEN_END // Pass 1 for string lengths. #define BLAS_ssymv( ... ) BLAS_ssymv_base( __VA_ARGS__, 1 ) #define BLAS_dsymv( ... ) BLAS_dsymv_base( __VA_ARGS__, 1 ) - //#define BLAS_csymv( ... ) BLAS_csymv_base( __VA_ARGS__, 1 ) - //#define BLAS_zsymv( ... ) BLAS_zsymv_base( __VA_ARGS__, 1 ) + #define BLAS_csymv( ... ) BLAS_csymv_base( __VA_ARGS__, 1 ) + #define BLAS_zsymv( ... ) BLAS_zsymv_base( __VA_ARGS__, 1 ) #else #define BLAS_ssymv( ... ) BLAS_ssymv_base( __VA_ARGS__ ) #define BLAS_dsymv( ... ) BLAS_dsymv_base( __VA_ARGS__ ) - //#define BLAS_csymv( ... ) BLAS_csymv_base( __VA_ARGS__ ) - //#define BLAS_zsymv( ... ) BLAS_zsymv_base( __VA_ARGS__ ) + #define BLAS_csymv( ... ) BLAS_csymv_base( __VA_ARGS__ ) + #define BLAS_zsymv( ... ) BLAS_zsymv_base( __VA_ARGS__ ) #endif // ----------------------------------------------------------------------------- @@ -638,42 +637,41 @@ void BLAS_dsyr_base( #endif ); -// conflicts with current prototype in lapacke.h -//#define BLAS_csyr_base BLAS_FORTRAN_NAME( csyr, CSYR ) -//void BLAS_FORTRAN_NAME( csyr, CSYR )( -// char const *uplo, -// blas_int const *n, -// blas_complex_float const *alpha, -// blas_complex_float const *x, blas_int const *incx, -// blas_complex_float *A, blas_int const *lda -// #ifdef BLAS_FORTRAN_STRLEN_END -// , size_t uplo_len -// #endif -// ); -// -//#define BLAS_zsyr_base BLAS_FORTRAN_NAME( zsyr, ZSYR ) -//void BLAS_zsyr_base( -// char const *uplo, -// blas_int const *n, -// blas_complex_double const *alpha, -// blas_complex_double const *x, blas_int const *incx, -// blas_complex_double *A, blas_int const *lda -// #ifdef BLAS_FORTRAN_STRLEN_END -// , size_t uplo_len -// #endif -// ); +#define BLAS_csyr_base BLAS_FORTRAN_NAME( csyr, CSYR ) +void BLAS_FORTRAN_NAME( csyr, CSYR )( + char const *uplo, + blas_int const *n, + blas_complex_float const *alpha, + blas_complex_float const *x, blas_int const *incx, + blas_complex_float *A, blas_int const *lda + #ifdef BLAS_FORTRAN_STRLEN_END + , size_t uplo_len + #endif + ); + +#define BLAS_zsyr_base BLAS_FORTRAN_NAME( zsyr, ZSYR ) +void BLAS_zsyr_base( + char const *uplo, + blas_int const *n, + blas_complex_double const *alpha, + blas_complex_double const *x, blas_int const *incx, + blas_complex_double *A, blas_int const *lda + #ifdef BLAS_FORTRAN_STRLEN_END + , size_t uplo_len + #endif + ); #ifdef BLAS_FORTRAN_STRLEN_END // Pass 1 for string lengths. #define BLAS_ssyr( ... ) BLAS_ssyr_base( __VA_ARGS__, 1 ) #define BLAS_dsyr( ... ) BLAS_dsyr_base( __VA_ARGS__, 1 ) - //#define BLAS_csyr( ... ) BLAS_csyr_base( __VA_ARGS__, 1 ) - //#define BLAS_zsyr( ... ) BLAS_zsyr_base( __VA_ARGS__, 1 ) + #define BLAS_csyr( ... ) BLAS_csyr_base( __VA_ARGS__, 1 ) + #define BLAS_zsyr( ... ) BLAS_zsyr_base( __VA_ARGS__, 1 ) #else #define BLAS_ssyr( ... ) BLAS_ssyr_base( __VA_ARGS__ ) #define BLAS_dsyr( ... ) BLAS_dsyr_base( __VA_ARGS__ ) - //#define BLAS_csyr( ... ) BLAS_csyr_base( __VA_ARGS__ ) - //#define BLAS_zsyr( ... ) BLAS_zsyr_base( __VA_ARGS__ ) + #define BLAS_csyr( ... ) BLAS_csyr_base( __VA_ARGS__ ) + #define BLAS_zsyr( ... ) BLAS_zsyr_base( __VA_ARGS__ ) #endif // ----------------------------------------------------------------------------- diff --git a/src/symv.cc b/src/symv.cc index 0d0db677..64652a9c 100644 --- a/src/symv.cc +++ b/src/symv.cc @@ -48,6 +48,46 @@ inline void symv( &alpha, A, &lda, x, &incx, &beta, y, &incy ); } +//------------------------------------------------------------------------------ +/// Low-level overload wrapper calls Fortran, complex version. +/// @ingroup symv_internal +inline void symv( + char uplo, + blas_int n, + std::complex alpha, + std::complex const* A, blas_int lda, + std::complex const* x, blas_int incx, + std::complex beta, + std::complex* y, blas_int incy ) +{ + BLAS_csymv( &uplo, &n, + (blas_complex_float*) &alpha, + (blas_complex_float*) A, &lda, + (blas_complex_float*) x, &incx, + (blas_complex_float*) &beta, + (blas_complex_float*) y, &incy ); +} + +//------------------------------------------------------------------------------ +/// Low-level overload wrapper calls Fortran, complex version. +/// @ingroup symv_internal +inline void symv( + char uplo, + blas_int n, + std::complex alpha, + std::complex const* A, blas_int lda, + std::complex const* x, blas_int incx, + std::complex beta, + std::complex* y, blas_int incy ) +{ + BLAS_zsymv( &uplo, &n, + (blas_complex_double*) &alpha, + (blas_complex_double*) A, &lda, + (blas_complex_double*) x, &incx, + (blas_complex_double*) &beta, + (blas_complex_double*) y, &incy ); +} + } // namespace internal //============================================================================== @@ -148,4 +188,38 @@ void symv( alpha, A, lda, x, incx, beta, y, incy ); } +//------------------------------------------------------------------------------ +/// CPU, complex version. +/// @ingroup symv +void symv( + blas::Layout layout, + blas::Uplo uplo, + int64_t n, + std::complex alpha, + std::complex const* A, int64_t lda, + std::complex const* x, int64_t incx, + std::complex beta, + std::complex* y, int64_t incy ) +{ + impl::symv( layout, uplo, n, + alpha, A, lda, x, incx, beta, y, incy ); +} + +//------------------------------------------------------------------------------ +/// CPU, complex version. +/// @ingroup symv +void symv( + blas::Layout layout, + blas::Uplo uplo, + int64_t n, + std::complex alpha, + std::complex const* A, int64_t lda, + std::complex const* x, int64_t incx, + std::complex beta, + std::complex* y, int64_t incy ) +{ + impl::symv( layout, uplo, n, + alpha, A, lda, x, incx, beta, y, incy ); +} + } // namespace blas diff --git a/src/syr.cc b/src/syr.cc index 95bb7a00..8768afec 100644 --- a/src/syr.cc +++ b/src/syr.cc @@ -42,6 +42,38 @@ inline void syr( BLAS_dsyr( &uplo, &n, &alpha, x, &incx, A, &lda ); } +//------------------------------------------------------------------------------ +/// Low-level overload wrapper calls Fortran, complex version. +/// @ingroup syr_internal +inline void syr( + char uplo, + blas_int n, + std::complex alpha, + std::complex const* x, blas_int incx, + std::complex* A, blas_int lda ) +{ + BLAS_csyr( &uplo, &n, + (blas_complex_float*) &alpha, + (blas_complex_float*) x, &incx, + (blas_complex_float*) A, &lda ); +} + +//------------------------------------------------------------------------------ +/// Low-level overload wrapper calls Fortran, complex version. +/// @ingroup syr_internal +inline void syr( + char uplo, + blas_int n, + std::complex alpha, + std::complex const* x, blas_int incx, + std::complex* A, blas_int lda ) +{ + BLAS_zsyr( &uplo, &n, + (blas_complex_double*) &alpha, + (blas_complex_double*) x, &incx, + (blas_complex_double*) A, &lda ); +} + } // namespace internal //============================================================================== @@ -133,4 +165,34 @@ void syr( alpha, x, incx, A, lda ); } +//------------------------------------------------------------------------------ +/// CPU, complex version. +/// @ingroup syr +void syr( + blas::Layout layout, + blas::Uplo uplo, + int64_t n, + std::complex alpha, + std::complex const* x, int64_t incx, + std::complex* A, int64_t lda ) +{ + impl::syr( layout, uplo, n, + alpha, x, incx, A, lda ); +} + +//------------------------------------------------------------------------------ +/// CPU, complex version. +/// @ingroup syr +void syr( + blas::Layout layout, + blas::Uplo uplo, + int64_t n, + std::complex alpha, + std::complex const* x, int64_t incx, + std::complex* A, int64_t lda ) +{ + impl::syr( layout, uplo, n, + alpha, x, incx, A, lda ); +} + } // namespace blas diff --git a/test/cblas_wrappers.cc b/test/cblas_wrappers.cc index 83c8b0ac..addf7d87 100644 --- a/test/cblas_wrappers.cc +++ b/test/cblas_wrappers.cc @@ -8,14 +8,15 @@ // get BLAS_FORTRAN_NAME and blas_int #include "blas/fortran.h" -// Including a variant of can cause conflicts in BLAS_*rot[g] -// Fortran prototypes, e.g., on macOS Ventura. So we define these without -// including their prototypes. -//#include "cblas_wrappers.hh" +// Circa 2022-12-22, there was a conflict in BLAS_*rot[g] when including +// both fortran.h and cblas.h (via cblas_wrappers.hh) on macOS Ventura. +// Can't replicate it now, and we need lapack_uplo_const() from +// cblas_wrappers.hh +#include "cblas_wrappers.hh" #include -// ----------------------------------------------------------------------------- +//------------------------------------------------------------------------------ void cblas_rotg( std::complex *a, std::complex *b, @@ -40,7 +41,7 @@ cblas_rotg( (blas_complex_double*) s ); } -// ----------------------------------------------------------------------------- +//------------------------------------------------------------------------------ void cblas_rot( int n, @@ -53,10 +54,8 @@ cblas_rot( blas_int incy_ = incy; BLAS_crot( &n_, - (blas_complex_float*) x, - &incx_, - (blas_complex_float*) y, - &incy_, + (blas_complex_float*) x, &incx_, + (blas_complex_float*) y, &incy_, &c, (blas_complex_float*) &s ); } @@ -73,10 +72,114 @@ cblas_rot( blas_int incy_ = incy; BLAS_zrot( &n_, - (blas_complex_double*) x, - &incx_, - (blas_complex_double*) y, - &incy_, + (blas_complex_double*) x, &incx_, + (blas_complex_double*) y, &incy_, &c, (blas_complex_double*) &s ); } + +//------------------------------------------------------------------------------ +void +cblas_symv( + CBLAS_LAYOUT layout, + CBLAS_UPLO uplo, + int n, + std::complex alpha, + std::complex const* A, int lda, + std::complex const* x, int incx, + std::complex beta, + std::complex* yref, int incy ) +{ + blas_int n_ = blas_int( n ); + blas_int incx_ = blas_int( incx ); + blas_int incy_ = blas_int( incy ); + blas_int lda_ = blas_int( lda ); + char uplo_ = lapack_uplo_const( uplo ); + if (layout == CblasRowMajor) { + uplo_ = (uplo == CblasUpper ? 'l' : 'u'); // switch upper <=> lower + } + BLAS_csymv( + &uplo_, &n_, + (blas_complex_float*) &alpha, + (blas_complex_float*) A, &lda_, + (blas_complex_float*) x, &incx_, + (blas_complex_float*) &beta, + (blas_complex_float*) yref, &incy_ + ); +} + +//------------------------------------------------------------------------------ +void +cblas_symv( + CBLAS_LAYOUT layout, + CBLAS_UPLO uplo, + int n, + std::complex alpha, + std::complex const* A, int lda, + std::complex const* x, int incx, + std::complex beta, + std::complex* yref, int incy ) +{ + blas_int n_ = blas_int( n ); + blas_int incx_ = blas_int( incx ); + blas_int incy_ = blas_int( incy ); + blas_int lda_ = blas_int( lda ); + char uplo_ = lapack_uplo_const( uplo ); + if (layout == CblasRowMajor) { + uplo_ = (uplo == CblasUpper ? 'l' : 'u'); // switch upper <=> lower + } + BLAS_zsymv( + &uplo_, &n_, + (blas_complex_double*) &alpha, + (blas_complex_double*) A, &lda_, + (blas_complex_double*) x, &incx_, + (blas_complex_double*) &beta, + (blas_complex_double*) yref, &incy_ + ); +} + +//------------------------------------------------------------------------------ +void +cblas_syr( + CBLAS_LAYOUT layout, CBLAS_UPLO uplo, int n, + std::complex alpha, + std::complex const *x, int incx, + std::complex* A, int lda ) +{ + blas_int n_ = blas_int( n ); + blas_int incx_ = blas_int( incx ); + blas_int lda_ = blas_int( lda ); + char uplo_ = lapack_uplo_const( uplo ); + if (layout == CblasRowMajor) { + uplo_ = (uplo == CblasUpper ? 'l' : 'u'); // switch upper <=> lower + } + BLAS_csyr( + &uplo_, &n_, + (blas_complex_float*) &alpha, + (blas_complex_float*) x, &incx_, + (blas_complex_float*) A, &lda_ + ); +} + +//------------------------------------------------------------------------------ +void +cblas_syr( + CBLAS_LAYOUT layout, CBLAS_UPLO uplo, int n, + std::complex alpha, + std::complex const *x, int incx, + std::complex* A, int lda ) +{ + blas_int n_ = blas_int( n ); + blas_int incx_ = blas_int( incx ); + blas_int lda_ = blas_int( lda ); + char uplo_ = lapack_uplo_const( uplo ); + if (layout == CblasRowMajor) { + uplo_ = (uplo == CblasUpper ? 'l' : 'u'); // switch upper <=> lower + } + BLAS_zsyr( + &uplo_, &n_, + (blas_complex_double*) &alpha, + (blas_complex_double*) x, &incx_, + (blas_complex_double*) A, &lda_ + ); +} diff --git a/test/cblas_wrappers.hh b/test/cblas_wrappers.hh index 60b9a8bc..d1df105d 100644 --- a/test/cblas_wrappers.hh +++ b/test/cblas_wrappers.hh @@ -503,6 +503,7 @@ cblas_rotg( } // CBLAS lacks [cz]rotg, but they're in Netlib BLAS. +// (Fixed in LAPACK PR #721, 2022-10) // Note c is real. void cblas_rotg( @@ -768,7 +769,27 @@ cblas_symv( } // LAPACK provides [cz]symv, CBLAS lacks them +void +cblas_symv( + CBLAS_LAYOUT layout, + CBLAS_UPLO uplo, + int n, + std::complex alpha, + std::complex const* A, int lda, + std::complex const* x, int incx, + std::complex beta, + std::complex* yref, int incy ); +void +cblas_symv( + CBLAS_LAYOUT layout, + CBLAS_UPLO uplo, + int n, + std::complex alpha, + std::complex const* A, int lda, + std::complex const* x, int incx, + std::complex beta, + std::complex* yref, int incy ); // ----------------------------------------------------------------------------- inline void @@ -1061,6 +1082,21 @@ cblas_syr( cblas_dsyr( layout, uplo, n, alpha, x, incx, A, lda ); } +// LAPACK provides [cz]syr, CBLAS lacks them +void +cblas_syr( + CBLAS_LAYOUT layout, CBLAS_UPLO uplo, int n, + std::complex alpha, + std::complex const *x, int incx, + std::complex* A, int lda ); + +void +cblas_syr( + CBLAS_LAYOUT layout, CBLAS_UPLO uplo, int n, + std::complex alpha, + std::complex const *x, int incx, + std::complex* A, int lda ); + // ----------------------------------------------------------------------------- inline void cblas_her2( diff --git a/test/run_tests.py b/test/run_tests.py index 6aa97d37..88b4223d 100755 --- a/test/run_tests.py +++ b/test/run_tests.py @@ -303,8 +303,8 @@ def filter_csv( values, csv ): [ 'hemv', dtype + layout + align + uplo + n + incx + incy ], [ 'her', dtype + layout + align + uplo + n + incx ], [ 'her2', dtype + layout + align + uplo + n + incx + incy ], - [ 'symv', dtype_real + layout + align + uplo + n + incx + incy ], # complex is in lapack++ - [ 'syr', dtype_real + layout + align + uplo + n + incx ], # complex is in lapack++ + [ 'symv', dtype + layout + align + uplo + n + incx + incy ], + [ 'syr', dtype + layout + align + uplo + n + incx ], [ 'syr2', dtype + layout + align + uplo + n + incx + incy ], [ 'trmv', dtype + layout + align + uplo + trans + diag + n + incx ], [ 'trsv', dtype + layout + align + uplo + trans + diag + n + incx ], diff --git a/test/test_symv.cc b/test/test_symv.cc index 227fd4ef..6ecc6d99 100644 --- a/test/test_symv.cc +++ b/test/test_symv.cc @@ -158,8 +158,13 @@ void test_symv( Params& params, bool run ) break; case testsweeper::DataType::SingleComplex: + test_symv_work< std::complex, std::complex, + std::complex >( params, run ); + break; + case testsweeper::DataType::DoubleComplex: - throw blas::Error( "See symv< complex > in LAPACK++", __func__ ); + test_symv_work< std::complex, std::complex, + std::complex >( params, run ); break; default: diff --git a/test/test_syr.cc b/test/test_syr.cc index bd2c1792..9e7a7797 100644 --- a/test/test_syr.cc +++ b/test/test_syr.cc @@ -146,8 +146,11 @@ void test_syr( Params& params, bool run ) break; case testsweeper::DataType::SingleComplex: + test_syr_work< std::complex, std::complex >( params, run ); + break; + case testsweeper::DataType::DoubleComplex: - throw blas::Error( "See syr< complex > in LAPACK++", __func__ ); + test_syr_work< std::complex, std::complex >( params, run ); break; default: diff --git a/test/test_util.cc b/test/test_util.cc index 2a765bb4..1efb0de7 100644 --- a/test/test_util.cc +++ b/test/test_util.cc @@ -397,7 +397,7 @@ void test_device_routines() { printf( "%s\n", __func__ ); - int repeat = 4; + const int repeat = 4; double t; int device_cnt;