diff --git a/src/ccsd/GNUmakefile b/src/ccsd/GNUmakefile index 056dc4521e..03d0391351 100644 --- a/src/ccsd/GNUmakefile +++ b/src/ccsd/GNUmakefile @@ -128,12 +128,11 @@ ifdef USE_IMAX_OPENMP_TRPDRV OBJ_OPTIMIZE += ccsd_trpdrv_omp_reduce_f.o - FOPTIONS += -O3 -fiopenmp -fopenmp-targets=spir64="-mllvm -vpo-paropt-opt-data-sharing-for-reduction=false -mllvm -vpo-paropt-atomic-free-reduction-par-global=false" -switch offload_modvars -mllvm -vpo-paropt-atomic-free-reduction-slm=true -qmkl -DMKL_ILP64 -I"${MKLROOT}/include" -I ${NWCHEM_TOP}/src/ccsd/module -mllvm -vpo-paropt-dispatch-codegen-version=1 -switch -use-host-usm-for-implicit-reduction-map + FOPTIONS += -O3 -fiopenmp -fopenmp-targets=spir64="-mllvm -vpo-paropt-atomic-free-reduction-par-global=false" -switch offload_modvars -mllvm -vpo-paropt-atomic-free-reduction-slm=true -qmkl - COPTIONS:=$(filter-out -fopenmp,$(COPTIONS)) - COPTIONS:=$(filter-out -O1,$(COPTIONS)) + DEFINES += -DMKL_ILP64 - COPTIONS += -O3 -fiopenmp -fopenmp-targets=spir64="-mllvm -vpo-paropt-opt-data-sharing-for-reduction=false -mllvm -vpo-paropt-atomic-free-reduction-par-global=false" -mllvm -vpo-paropt-atomic-free-reduction-slm=true -qmkl -DMKL_ILP64 -I"${MKLROOT}/include" -mllvm -vpo-paropt-dispatch-codegen-version=1 + INCLUDES += -I"${MKLROOT}/include" -I ${NWCHEM_TOP}/src/ccsd/module endif @@ -148,6 +147,9 @@ ifdef USE_OPENACC_TRPDRV endif endif +ifdef USE_BATCHDGEMM_TRPDRV + DEFINES += -DUSE_BATCHDGEMM_TRPDRV +endif ifeq ($(ARMCI_NETWORK),MPI-PR) LIB_DEFINES += -DACC_STRIPS diff --git a/src/ccsd/ccsd_trpdrv_openmp_imax.F b/src/ccsd/ccsd_trpdrv_openmp_imax.F index fdd34d3cd4..ab58e0426d 100644 --- a/src/ccsd/ccsd_trpdrv_openmp_imax.F +++ b/src/ccsd/ccsd_trpdrv_openmp_imax.F @@ -103,6 +103,30 @@ subroutine ccsd_trpdrv_offload_xe(t1,xeorb, external nxtask integer :: dgemm_flops, tengy_flops double precision agg_flops + +#ifdef USE_BATCHDGEMM_TRPDRV + integer, parameter :: batch_size=8 + integer(KIND=C_SIZE_T) :: a_array1(batch_size) + integer(KIND=C_SIZE_T) :: b_array1(batch_size) + integer(KIND=C_SIZE_T) :: c_array1(batch_size) + integer(KIND=C_SIZE_T) :: a_array2(batch_size) + integer(KIND=C_SIZE_T) :: b_array2(batch_size) + integer(KIND=C_SIZE_T) :: c_array2(batch_size) + + integer, parameter :: grp_count=2 + integer :: group_size(grp_count) + character*1 :: transa_array(grp_count) + character*1 :: transb_array(grp_count) + integer :: m_array(grp_count) + integer :: n_array(grp_count) + integer :: k_array(grp_count) + integer :: lda_array(grp_count) + integer :: ldb_array(grp_count) + integer :: ldc_array(grp_count) + double precision :: alpha_array(grp_count) + double precision :: beta_array(grp_count) + integer :: group_count +#endif ! ! Dependencies (global array, local array, handle): ! @@ -638,6 +662,163 @@ subroutine ccsd_trpdrv_offload_xe(t1,xeorb, tc0 = util_wallsec() t_dgemm0 = util_wallsec() +#ifdef USE_BATCHDGEMM_TRPDRV + ! Update the arrays of pointers to arrays (a_array1, b_array1, + ! c_array1, a_array2, b_array2, and c_array2) on the device. + ! + !$omp target + a_array1(1) = LOC(Jia(1)) + a_array1(2) = LOC(Kia(1)) + a_array1(3) = LOC(Jka(1+(k-klo)*lnvv)) + a_array1(4) = LOC(Kka(1+(k-klo)*lnvv)) + a_array1(5) = LOC(Jia(1)) + a_array1(6) = LOC(Kia(1)) + a_array1(7) = LOC(Jka(1+(k-klo)*lnvv)) + a_array1(8) = LOC(Kka(1+(k-klo)*lnvv)) + + b_array1(1) = LOC(Tkj(1+(k-klo)*lnvv)) + b_array1(2) = LOC(Tkj(1+(k-klo)*lnvv)) + b_array1(3) = LOC(Tij(1)) + b_array1(4) = LOC(Tij(1)) + b_array1(5) = LOC(Tkj(1+(k-klo)*lnvv)) + b_array1(6) = LOC(Tkj(1+(k-klo)*lnvv)) + b_array1(7) = LOC(Tij(1)) + b_array1(8) = LOC(Tij(1)) + + c_array1(1) = LOC(f1n(1,1)) + c_array1(2) = LOC(f2n(1,1)) + c_array1(3) = LOC(f1t(1,1)) + c_array1(4) = LOC(f2t(1,1)) + c_array1(5) = LOC(f3n(1,1)) + c_array1(6) = LOC(f4n(1,1)) + c_array1(7) = LOC(f3t(1,1)) + c_array1(8) = LOC(f4t(1,1)) + + a_array2(1) = LOC(Tia(1)) + a_array2(2) = LOC(Xia(1)) + a_array2(3) = LOC(Tia(1)) + a_array2(4) = LOC(Xia(1)) + a_array2(5) = LOC(Tka(1+(k-klo)*lnov)) + a_array2(6) = LOC(Xka(1+(k-klo)*lnov)) + a_array2(7) = LOC(Tka(1+(k-klo)*lnov)) + a_array2(8) = LOC(Xka(1+(k-klo)*lnov)) + + b_array2(1) = LOC(Kkj(1+(k-klo)*lnov)) + b_array2(2) = LOC(Kkj(1+(k-klo)*lnov)) + b_array2(3) = LOC(Jkj(1+(k-klo)*lnov)) + b_array2(4) = LOC(Jkj(1+(k-klo)*lnov)) + b_array2(5) = LOC(Kij(1)) + b_array2(6) = LOC(Kij(1)) + b_array2(7) = LOC(Jij(1)) + b_array2(8) = LOC(Jij(1)) + + c_array2(1) = LOC(f1n(1,1)) + c_array2(2) = LOC(f2n(1,1)) + c_array2(3) = LOC(f3n(1,1)) + c_array2(4) = LOC(f4n(1,1)) + c_array2(5) = LOC(f1t(1,1)) + c_array2(6) = LOC(f2t(1,1)) + c_array2(7) = LOC(f3t(1,1)) + c_array2(8) = LOC(f4t(1,1)) + !$omp end target + + ! + ! First call to dgemm_batch (2 groups, each of size 4) + ! + + ! First group in first batch call. + ! 4 matrix multiplications. So group size is 4. + group_count = 2 + group_size(1) = 4 + + transa_array(1) = 'n' + transb_array(1) = 't' + + m_array(1) = nvir + n_array(1) = nvir + k_array(1) = nvir + + lda_array(1) = nvir + ldb_array(1) = nvir + ldc_array(1) = nvir + + alpha_array(1) = 1.0d0 + beta_array(1) = 0.0d0 + + ! Second group in first batch call. + ! 4 matrix multiplications. So group size is 4. + group_size(2) = 4 + + transa_array(2) = 'n' + transb_array(2) = 'n' + + m_array(2) = nvir + n_array(2) = nvir + k_array(2) = nvir + + lda_array(2) = nvir + ldb_array(2) = nvir + ldc_array(2) = nvir + + alpha_array(2) = 1.0d0 + beta_array(2) = 0.0d0 + + !$omp dispatch + call dgemm_batch(transa_array, + & transb_array, + & m_array, + & n_array, + & k_array, + & alpha_array, + & a_array1, + & lda_array, + & b_array1, + & ldb_array, + & beta_array, + & c_array1, + & ldc_array, + & group_count, + & group_size) + + ! + ! Second call to dgemm_batch (1 group of size 8) + ! + + ! 8 matrix multiplications. So group size is 8. + group_count = 1 + group_size(1) = 8 + + transa_array(1) = 'n' + transb_array(1) = 'n' + + m_array(1) = nvir + n_array(1) = nvir + k_array(1) = nocc + + lda_array(1) = nvir + ldb_array(1) = nocc + ldc_array(1) = nvir + + alpha_array(1) = -1.0d0 + beta_array(1) = 1.0d0 + + !$omp dispatch + call dgemm_batch(transa_array, + & transb_array, + & m_array, + & n_array, + & k_array, + & alpha_array, + & a_array2, + & lda_array, + & b_array2, + & ldb_array, + & beta_array, + & c_array2, + & ldc_array, + & group_count, + & group_size) +#else !$omp parallel sections num_threads(8) !$omp section @@ -744,6 +925,7 @@ subroutine ccsd_trpdrv_offload_xe(t1,xeorb, 1 Xka(1+(k-klo)*lnov),nvir,Jij,nocc,1.0d0, 2 f4t,nvir) !$omp end parallel sections +#endif #if 0 !$omp interop use(obj0)