From c6974349536d13f43e689d5d418544cc9127f406 Mon Sep 17 00:00:00 2001 From: Omar Khalil Ahmed Date: Wed, 20 Dec 2023 14:47:28 -0800 Subject: [PATCH] Update with batched gemm for openacc ccsd implementation --- src/ccsd/GNUmakefile | 3 + src/ccsd/ccsd_trpdrv_openacc.F | 143 ++++++++++++++++++++++- src/ccsd/ccsd_trpdrv_openmp_imax.F | 182 +++++++++++++++++++++++++++++ 3 files changed, 323 insertions(+), 5 deletions(-) diff --git a/src/ccsd/GNUmakefile b/src/ccsd/GNUmakefile index 056dc4521e..6433cf7d7d 100644 --- a/src/ccsd/GNUmakefile +++ b/src/ccsd/GNUmakefile @@ -148,6 +148,9 @@ ifdef USE_OPENACC_TRPDRV endif endif +ifdef USE_BATCHDGEMM_TRPDRV + FOPTIONS += -DUSE_BATCHDGEMM_TRPDRV +endif ifeq ($(ARMCI_NETWORK),MPI-PR) LIB_DEFINES += -DACC_STRIPS diff --git a/src/ccsd/ccsd_trpdrv_openacc.F b/src/ccsd/ccsd_trpdrv_openacc.F index 95859b5d80..b854f5efdb 100644 --- a/src/ccsd/ccsd_trpdrv_openacc.F +++ b/src/ccsd/ccsd_trpdrv_openacc.F @@ -110,9 +110,22 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb, integer(INT32) :: shi type(cublasHandle) :: handle(8) integer(kind=cuda_stream_kind) :: stream(8) + +#ifdef USE_BATCHDGEMM_TRPDRV integer(INT32) :: nv4, no4 ! cublasDgemm requires 32-bit integers - integer(INT32), parameter :: cu_op_n = CUBLAS_OP_N - integer(INT32), parameter :: cu_op_t = CUBLAS_OP_T + integer(INT32), parameter :: batch1=4, batch2=4, batch3=8 + integer, parameter :: cu_op_n = CUBLAS_OP_N + integer, parameter :: cu_op_t = CUBLAS_OP_T + type(c_devptr), device, dimension(batch1) :: a_array1 + type(c_devptr), device, dimension(batch1) :: b_array1 + type(c_devptr), device, dimension(batch1) :: c_array1 + type(c_devptr), device, dimension(batch2) :: a_array2 + type(c_devptr), device, dimension(batch2) :: b_array2 + type(c_devptr), device, dimension(batch2) :: c_array2 + type(c_devptr), device, dimension(batch3) :: a_array3 + type(c_devptr), device, dimension(batch3) :: b_array3 + type(c_devptr), device, dimension(batch3) :: c_array3 +#endif ! nodes = ga_nnodes() me = ga_nodeid() @@ -133,6 +146,7 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb, err = cublasSetStream(handle(shi), stream(shi)) if (err.ne.0) call errquit('cublasSetStream',err,UNKNOWN_ERR) end do + ! ! device-only temp arrays ! produced by DGEMM, consumed by TENGY @@ -404,7 +418,6 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb, if (err.ne.0) then call errquit('cudaMemcpyAsync',err,UNKNOWN_ERR) endif - ! arrays and thus copies contribute to more than one CUBLAS call ! but the copies on streams 1:4 and 5:8 are separable. do shi=1,4 @@ -443,7 +456,127 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb, else call qexit('accwait',0) endif +#ifdef USE_BATCHDGEMM_TRPDRV + + + ! first two batches can be executed + ! 
simultaneously + + + a_array1(1) = c_devloc(xJia(:)) + a_array1(2) = c_devloc(xKia(:)) + a_array1(3) = c_devloc(xJka(1+(k-klo)*lnvv:)) + a_array1(4) = c_devloc(xKka(1+(k-klo)*lnvv:)) + a_array2(1) = c_devloc(xJia(:)) + a_array2(2) = c_devloc(xKia(:)) + a_array2(3) = c_devloc(xJka(1+(k-klo)*lnvv:)) + a_array2(4) = c_devloc(xKka(1+(k-klo)*lnvv:)) + + b_array1(1) = c_devloc(xTkj(1+(k-klo)*lnvv:)) + b_array1(2) = c_devloc(xTkj(1+(k-klo)*lnvv:)) + b_array1(3) = c_devloc(xTij(:)) + b_array1(4) = c_devloc(xTij(:)) + + b_array2(1) = c_devloc(xTkj(1+(k-klo)*lnvv:)) + b_array2(2) = c_devloc(xTkj(1+(k-klo)*lnvv:)) + b_array2(3) = c_devloc(xTij(:)) + b_array2(4) = c_devloc(xTij(:)) + + c_array1(1) = c_devloc(f1n(:,:)) + c_array1(2) = c_devloc(f2n(:,:)) + c_array1(3) = c_devloc(f1t(:,:)) + c_array1(4) = c_devloc(f2t(:,:)) + + c_array2(1) = c_devloc(f3n(:,:)) + c_array2(2) = c_devloc(f4n(:,:)) + c_array2(3) = c_devloc(f3t(:,:)) + c_array2(4) = c_devloc(f4t(:,:)) + + ! third batch executed afterwards + + a_array3(1) = c_devloc(xTia(:)) + a_array3(2) = c_devloc(xXia(:)) + a_array3(3) = c_devloc(xTia(:)) + a_array3(4) = c_devloc(xXia(:)) + a_array3(5) = c_devloc(xTka(1+(k-klo)*lnov:)) + a_array3(6) = c_devloc(xXka(1+(k-klo)*lnov:)) + a_array3(7) = c_devloc(xTka(1+(k-klo)*lnov:)) + a_array3(8) = c_devloc(xXka(1+(k-klo)*lnov:)) + + b_array3(1) = c_devloc(xKkj(1+(k-klo)*lnov:)) + b_array3(2) = c_devloc(xKkj(1+(k-klo)*lnov:)) + b_array3(3) = c_devloc(xJkj(1+(k-klo)*lnov:)) + b_array3(4) = c_devloc(xJkj(1+(k-klo)*lnov:)) + b_array3(5) = c_devloc(xKij(:)) + b_array3(6) = c_devloc(xKij(:)) + b_array3(7) = c_devloc(xJij(:)) + b_array3(8) = c_devloc(xJij(:)) + + c_array3(1) = c_devloc(f1n(:,:)) + c_array3(2) = c_devloc(f2n(:,:)) + c_array3(3) = c_devloc(f3n(:,:)) + c_array3(4) = c_devloc(f4n(:,:)) + c_array3(5) = c_devloc(f1t(:,:)) + c_array3(6) = c_devloc(f2t(:,:)) + c_array3(7) = c_devloc(f3t(:,:)) + c_array3(8) = c_devloc(f4t(:,:)) + + err = cublasDgemmBatched_v2( + & handle(1), + & CUBLAS_OP_N, + & CUBLAS_OP_T, + & nv4, nv4, nv4, + & 1.0d0, + & a_array1, nv4, + & b_array1, nv4, + & 0.0d0, + & c_array1, nv4, + & batch1) + if (err.ne.0) then + call errquit('cublasDgemmBatched_v2',err, + & UNKNOWN_ERR) + endif + err = cublasDgemmBatched_v2( + & handle(2), + & CUBLAS_OP_N, + & CUBLAS_OP_N, + & nv4, nv4, nv4, + & 1.0d0, + & a_array2, nv4, + & b_array2, nv4, + & 0.0d0, + & c_array2, nv4, + & batch2) + + if (err.ne.0) then + call errquit('cublasDgemmBatched_v2',err, + & UNKNOWN_ERR) + endif + + err = cudaDeviceSynchronize() + if (err.ne.0) then + call errquit('cudaDeviceSync',err,UNKNOWN_ERR) + endif + + err = cublasDgemmBatched_v2( + & handle(1), + & CUBLAS_OP_N, + & CUBLAS_OP_N, + & nv4, nv4, no4, + & -1.0d0, + & a_array3, nv4, + & b_array3, no4, + & 1.0d0, + & c_array3, nv4, + & batch3) + if (err.ne.0) then + call errquit('cublasGemmBatched_v2', + & err,UNKNOWN_ERR) + endif + + +#else err = cublasDgemm_v2(handle(1), & cu_op_n,cu_op_t, & nv4,nv4,nv4,1.0d0, @@ -588,12 +721,11 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb, if (err.ne.0) then call errquit('cublasDgemm_v2',err,UNKNOWN_ERR) endif - +#endif err = cudaDeviceSynchronize() if (err.ne.0) then call errquit('cudaDeviceSync',err,UNKNOWN_ERR) endif - ! 
8 pairs of DGEMM w/ VVV and VVO cost, 2 for FMA dgemm_flops = 8*nvir*nvir*(nocc+nvir)*2 agg_flops = agg_flops + dgemm_flops @@ -761,6 +893,7 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb, deallocate( Tij, Tkj, Tia, Tka, Xia, Xka, & Jia, Jka, Kia, Kka, Jij, Jkj, Kij, Kkj, & Dja, Djka, Djia, stat=alloc_error) + if (alloc_error.ne.0) call errquit('free TKJKD',1,MA_ERR) deallocate( xTij, xTkj, xTia, xTka, xXia, xXka, & xJia, xJka, xKia, xKka, xJij, xJkj, xKij, xKkj, diff --git a/src/ccsd/ccsd_trpdrv_openmp_imax.F b/src/ccsd/ccsd_trpdrv_openmp_imax.F index fdd34d3cd4..ab58e0426d 100644 --- a/src/ccsd/ccsd_trpdrv_openmp_imax.F +++ b/src/ccsd/ccsd_trpdrv_openmp_imax.F @@ -103,6 +103,30 @@ subroutine ccsd_trpdrv_offload_xe(t1,xeorb, external nxtask integer :: dgemm_flops, tengy_flops double precision agg_flops + +#ifdef USE_BATCHDGEMM_TRPDRV + integer, parameter :: batch_size=8 + integer(KIND=C_SIZE_T) :: a_array1(batch_size) + integer(KIND=C_SIZE_T) :: b_array1(batch_size) + integer(KIND=C_SIZE_T) :: c_array1(batch_size) + integer(KIND=C_SIZE_T) :: a_array2(batch_size) + integer(KIND=C_SIZE_T) :: b_array2(batch_size) + integer(KIND=C_SIZE_T) :: c_array2(batch_size) + + integer, parameter :: grp_count=2 + integer :: group_size(grp_count) + character*1 :: transa_array(grp_count) + character*1 :: transb_array(grp_count) + integer :: m_array(grp_count) + integer :: n_array(grp_count) + integer :: k_array(grp_count) + integer :: lda_array(grp_count) + integer :: ldb_array(grp_count) + integer :: ldc_array(grp_count) + double precision :: alpha_array(grp_count) + double precision :: beta_array(grp_count) + integer :: group_count +#endif ! ! Dependencies (global array, local array, handle): ! @@ -638,6 +662,163 @@ subroutine ccsd_trpdrv_offload_xe(t1,xeorb, tc0 = util_wallsec() t_dgemm0 = util_wallsec() +#ifdef USE_BATCHDGEMM_TRPDRV + ! Update the arrays of pointers to arrays (a_array1, b_array1, + ! c_array1, a_array2, b_array2, and c_array2) on the device. + ! 
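+         ! (The LOC() calls below are placed inside the TARGET region so
+         ! that they return device addresses of the mapped tiles; the
+         ! filled pointer arrays are then consumed by the two dgemm_batch
+         ! calls dispatched to the device further down.)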
+ !$omp target + a_array1(1) = LOC(Jia(1)) + a_array1(2) = LOC(Kia(1)) + a_array1(3) = LOC(Jka(1+(k-klo)*lnvv)) + a_array1(4) = LOC(Kka(1+(k-klo)*lnvv)) + a_array1(5) = LOC(Jia(1)) + a_array1(6) = LOC(Kia(1)) + a_array1(7) = LOC(Jka(1+(k-klo)*lnvv)) + a_array1(8) = LOC(Kka(1+(k-klo)*lnvv)) + + b_array1(1) = LOC(Tkj(1+(k-klo)*lnvv)) + b_array1(2) = LOC(Tkj(1+(k-klo)*lnvv)) + b_array1(3) = LOC(Tij(1)) + b_array1(4) = LOC(Tij(1)) + b_array1(5) = LOC(Tkj(1+(k-klo)*lnvv)) + b_array1(6) = LOC(Tkj(1+(k-klo)*lnvv)) + b_array1(7) = LOC(Tij(1)) + b_array1(8) = LOC(Tij(1)) + + c_array1(1) = LOC(f1n(1,1)) + c_array1(2) = LOC(f2n(1,1)) + c_array1(3) = LOC(f1t(1,1)) + c_array1(4) = LOC(f2t(1,1)) + c_array1(5) = LOC(f3n(1,1)) + c_array1(6) = LOC(f4n(1,1)) + c_array1(7) = LOC(f3t(1,1)) + c_array1(8) = LOC(f4t(1,1)) + + a_array2(1) = LOC(Tia(1)) + a_array2(2) = LOC(Xia(1)) + a_array2(3) = LOC(Tia(1)) + a_array2(4) = LOC(Xia(1)) + a_array2(5) = LOC(Tka(1+(k-klo)*lnov)) + a_array2(6) = LOC(Xka(1+(k-klo)*lnov)) + a_array2(7) = LOC(Tka(1+(k-klo)*lnov)) + a_array2(8) = LOC(Xka(1+(k-klo)*lnov)) + + b_array2(1) = LOC(Kkj(1+(k-klo)*lnov)) + b_array2(2) = LOC(Kkj(1+(k-klo)*lnov)) + b_array2(3) = LOC(Jkj(1+(k-klo)*lnov)) + b_array2(4) = LOC(Jkj(1+(k-klo)*lnov)) + b_array2(5) = LOC(Kij(1)) + b_array2(6) = LOC(Kij(1)) + b_array2(7) = LOC(Jij(1)) + b_array2(8) = LOC(Jij(1)) + + c_array2(1) = LOC(f1n(1,1)) + c_array2(2) = LOC(f2n(1,1)) + c_array2(3) = LOC(f3n(1,1)) + c_array2(4) = LOC(f4n(1,1)) + c_array2(5) = LOC(f1t(1,1)) + c_array2(6) = LOC(f2t(1,1)) + c_array2(7) = LOC(f3t(1,1)) + c_array2(8) = LOC(f4t(1,1)) + !$omp end target + + ! + ! First call to dgemm_batch (2 groups, each of size 4) + ! + + ! First group in first batch call. + ! 4 matrix multiplications. So group size is 4. + group_count = 2 + group_size(1) = 4 + + transa_array(1) = 'n' + transb_array(1) = 't' + + m_array(1) = nvir + n_array(1) = nvir + k_array(1) = nvir + + lda_array(1) = nvir + ldb_array(1) = nvir + ldc_array(1) = nvir + + alpha_array(1) = 1.0d0 + beta_array(1) = 0.0d0 + + ! Second group in first batch call. + ! 4 matrix multiplications. So group size is 4. + group_size(2) = 4 + + transa_array(2) = 'n' + transb_array(2) = 'n' + + m_array(2) = nvir + n_array(2) = nvir + k_array(2) = nvir + + lda_array(2) = nvir + ldb_array(2) = nvir + ldc_array(2) = nvir + + alpha_array(2) = 1.0d0 + beta_array(2) = 0.0d0 + + !$omp dispatch + call dgemm_batch(transa_array, + & transb_array, + & m_array, + & n_array, + & k_array, + & alpha_array, + & a_array1, + & lda_array, + & b_array1, + & ldb_array, + & beta_array, + & c_array1, + & ldc_array, + & group_count, + & group_size) + + ! + ! Second call to dgemm_batch (1 group of size 8) + ! + + ! 8 matrix multiplications. So group size is 8. + group_count = 1 + group_size(1) = 8 + + transa_array(1) = 'n' + transb_array(1) = 'n' + + m_array(1) = nvir + n_array(1) = nvir + k_array(1) = nocc + + lda_array(1) = nvir + ldb_array(1) = nocc + ldc_array(1) = nvir + + alpha_array(1) = -1.0d0 + beta_array(1) = 1.0d0 + + !$omp dispatch + call dgemm_batch(transa_array, + & transb_array, + & m_array, + & n_array, + & k_array, + & alpha_array, + & a_array2, + & lda_array, + & b_array2, + & ldb_array, + & beta_array, + & c_array2, + & ldc_array, + & group_count, + & group_size) +#else !$omp parallel sections num_threads(8) !$omp section @@ -744,6 +925,7 @@ subroutine ccsd_trpdrv_offload_xe(t1,xeorb, 1 Xka(1+(k-klo)*lnov),nvir,Jij,nocc,1.0d0, 2 f4t,nvir) !$omp end parallel sections +#endif #if 0 !$omp interop use(obj0)
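
Build note (illustrative; the exact invocation depends on the local NWChem configuration): both new code paths are compile-time gated by the make variable introduced in the GNUmakefile hunk above, so existing builds are unchanged unless it is defined when this directory is compiled, for example

    make USE_OPENACC_TRPDRV=1 USE_BATCHDGEMM_TRPDRV=1

With the variable set, -DUSE_BATCHDGEMM_TRPDRV is added to FOPTIONS, which selects the cublasDgemmBatched_v2 path in ccsd_trpdrv_openacc.F and the grouped dgemm_batch path in ccsd_trpdrv_openmp_imax.F; without it, both files fall back to the original individual DGEMM calls under the #else branches.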