Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add batched dgemm to CCSD(T) for openmp offload and openacc implementations #980

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/ccsd/GNUmakefile
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,11 @@ ifdef USE_IMAX_OPENMP_TRPDRV

OBJ_OPTIMIZE += ccsd_trpdrv_omp_reduce_f.o

FOPTIONS += -O3 -fiopenmp -fopenmp-targets=spir64="-mllvm -vpo-paropt-opt-data-sharing-for-reduction=false -mllvm -vpo-paropt-atomic-free-reduction-par-global=false" -switch offload_modvars -mllvm -vpo-paropt-atomic-free-reduction-slm=true -qmkl -DMKL_ILP64 -I"${MKLROOT}/include" -I ${NWCHEM_TOP}/src/ccsd/module -mllvm -vpo-paropt-dispatch-codegen-version=1 -switch -use-host-usm-for-implicit-reduction-map
FOPTIONS += -O3 -fiopenmp -fopenmp-targets=spir64="-mllvm -vpo-paropt-atomic-free-reduction-par-global=false" -switch offload_modvars -mllvm -vpo-paropt-atomic-free-reduction-slm=true -qmkl

COPTIONS:=$(filter-out -fopenmp,$(COPTIONS))
COPTIONS:=$(filter-out -O1,$(COPTIONS))
DEFINES += -DMKL_ILP64

COPTIONS += -O3 -fiopenmp -fopenmp-targets=spir64="-mllvm -vpo-paropt-opt-data-sharing-for-reduction=false -mllvm -vpo-paropt-atomic-free-reduction-par-global=false" -mllvm -vpo-paropt-atomic-free-reduction-slm=true -qmkl -DMKL_ILP64 -I"${MKLROOT}/include" -mllvm -vpo-paropt-dispatch-codegen-version=1
INCLUDES += -I"${MKLROOT}/include" -I ${NWCHEM_TOP}/src/ccsd/module

endif

Expand All @@ -148,6 +147,9 @@ ifdef USE_OPENACC_TRPDRV
endif
endif

ifdef USE_BATCHDGEMM_TRPDRV
DEFINES += -DUSE_BATCHDGEMM_TRPDRV
endif

ifeq ($(ARMCI_NETWORK),MPI-PR)
LIB_DEFINES += -DACC_STRIPS
Expand Down
182 changes: 182 additions & 0 deletions src/ccsd/ccsd_trpdrv_openmp_imax.F
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,30 @@ subroutine ccsd_trpdrv_offload_xe(t1,xeorb,
external nxtask
integer :: dgemm_flops, tengy_flops
double precision agg_flops

#ifdef USE_BATCHDGEMM_TRPDRV
integer, parameter :: batch_size=8
integer(KIND=C_SIZE_T) :: a_array1(batch_size)
integer(KIND=C_SIZE_T) :: b_array1(batch_size)
integer(KIND=C_SIZE_T) :: c_array1(batch_size)
integer(KIND=C_SIZE_T) :: a_array2(batch_size)
integer(KIND=C_SIZE_T) :: b_array2(batch_size)
integer(KIND=C_SIZE_T) :: c_array2(batch_size)

integer, parameter :: grp_count=2
integer :: group_size(grp_count)
character*1 :: transa_array(grp_count)
character*1 :: transb_array(grp_count)
integer :: m_array(grp_count)
integer :: n_array(grp_count)
integer :: k_array(grp_count)
integer :: lda_array(grp_count)
integer :: ldb_array(grp_count)
integer :: ldc_array(grp_count)
double precision :: alpha_array(grp_count)
double precision :: beta_array(grp_count)
integer :: group_count
#endif
!
! Dependencies (global array, local array, handle):
!
Expand Down Expand Up @@ -638,6 +662,163 @@ subroutine ccsd_trpdrv_offload_xe(t1,xeorb,
tc0 = util_wallsec()
t_dgemm0 = util_wallsec()

#ifdef USE_BATCHDGEMM_TRPDRV
! Update the arrays of pointers to arrays (a_array1, b_array1,
! c_array1, a_array2, b_array2, and c_array2) on the device.
!
!$omp target
a_array1(1) = LOC(Jia(1))
a_array1(2) = LOC(Kia(1))
a_array1(3) = LOC(Jka(1+(k-klo)*lnvv))
a_array1(4) = LOC(Kka(1+(k-klo)*lnvv))
a_array1(5) = LOC(Jia(1))
a_array1(6) = LOC(Kia(1))
a_array1(7) = LOC(Jka(1+(k-klo)*lnvv))
a_array1(8) = LOC(Kka(1+(k-klo)*lnvv))

b_array1(1) = LOC(Tkj(1+(k-klo)*lnvv))
b_array1(2) = LOC(Tkj(1+(k-klo)*lnvv))
b_array1(3) = LOC(Tij(1))
b_array1(4) = LOC(Tij(1))
b_array1(5) = LOC(Tkj(1+(k-klo)*lnvv))
b_array1(6) = LOC(Tkj(1+(k-klo)*lnvv))
b_array1(7) = LOC(Tij(1))
b_array1(8) = LOC(Tij(1))

c_array1(1) = LOC(f1n(1,1))
c_array1(2) = LOC(f2n(1,1))
c_array1(3) = LOC(f1t(1,1))
c_array1(4) = LOC(f2t(1,1))
c_array1(5) = LOC(f3n(1,1))
c_array1(6) = LOC(f4n(1,1))
c_array1(7) = LOC(f3t(1,1))
c_array1(8) = LOC(f4t(1,1))

a_array2(1) = LOC(Tia(1))
a_array2(2) = LOC(Xia(1))
a_array2(3) = LOC(Tia(1))
a_array2(4) = LOC(Xia(1))
a_array2(5) = LOC(Tka(1+(k-klo)*lnov))
a_array2(6) = LOC(Xka(1+(k-klo)*lnov))
a_array2(7) = LOC(Tka(1+(k-klo)*lnov))
a_array2(8) = LOC(Xka(1+(k-klo)*lnov))

b_array2(1) = LOC(Kkj(1+(k-klo)*lnov))
b_array2(2) = LOC(Kkj(1+(k-klo)*lnov))
b_array2(3) = LOC(Jkj(1+(k-klo)*lnov))
b_array2(4) = LOC(Jkj(1+(k-klo)*lnov))
b_array2(5) = LOC(Kij(1))
b_array2(6) = LOC(Kij(1))
b_array2(7) = LOC(Jij(1))
b_array2(8) = LOC(Jij(1))

c_array2(1) = LOC(f1n(1,1))
c_array2(2) = LOC(f2n(1,1))
c_array2(3) = LOC(f3n(1,1))
c_array2(4) = LOC(f4n(1,1))
c_array2(5) = LOC(f1t(1,1))
c_array2(6) = LOC(f2t(1,1))
c_array2(7) = LOC(f3t(1,1))
c_array2(8) = LOC(f4t(1,1))
!$omp end target

!
! First call to dgemm_batch (2 groups, each of size 4)
!

! First group in first batch call.
! 4 matrix multiplications. So group size is 4.
group_count = 2
group_size(1) = 4

transa_array(1) = 'n'
transb_array(1) = 't'

m_array(1) = nvir
n_array(1) = nvir
k_array(1) = nvir

lda_array(1) = nvir
ldb_array(1) = nvir
ldc_array(1) = nvir

alpha_array(1) = 1.0d0
beta_array(1) = 0.0d0

! Second group in first batch call.
! 4 matrix multiplications. So group size is 4.
group_size(2) = 4

transa_array(2) = 'n'
transb_array(2) = 'n'

m_array(2) = nvir
n_array(2) = nvir
k_array(2) = nvir

lda_array(2) = nvir
ldb_array(2) = nvir
ldc_array(2) = nvir

alpha_array(2) = 1.0d0
beta_array(2) = 0.0d0

!$omp dispatch
call dgemm_batch(transa_array,
& transb_array,
& m_array,
& n_array,
& k_array,
& alpha_array,
& a_array1,
& lda_array,
& b_array1,
& ldb_array,
& beta_array,
& c_array1,
& ldc_array,
& group_count,
& group_size)

!
! Second call to dgemm_batch (1 group of size 8)
!

! 8 matrix multiplications. So group size is 8.
group_count = 1
group_size(1) = 8

transa_array(1) = 'n'
transb_array(1) = 'n'

m_array(1) = nvir
n_array(1) = nvir
k_array(1) = nocc

lda_array(1) = nvir
ldb_array(1) = nocc
ldc_array(1) = nvir

alpha_array(1) = -1.0d0
beta_array(1) = 1.0d0

!$omp dispatch
call dgemm_batch(transa_array,
& transb_array,
& m_array,
& n_array,
& k_array,
& alpha_array,
& a_array2,
& lda_array,
& b_array2,
& ldb_array,
& beta_array,
& c_array2,
& ldc_array,
& group_count,
& group_size)
#else
!$omp parallel sections num_threads(8)

!$omp section
Expand Down Expand Up @@ -744,6 +925,7 @@ subroutine ccsd_trpdrv_offload_xe(t1,xeorb,
1 Xka(1+(k-klo)*lnov),nvir,Jij,nocc,1.0d0,
2 f4t,nvir)
!$omp end parallel sections
#endif

#if 0
!$omp interop use(obj0)
Expand Down
Loading