diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index 986b53a..f373dc2 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -18,6 +18,7 @@ set(sources-fpp
   mpifx_recv.fpp
   mpifx_reduce.fpp
   mpifx_scatter.fpp
+  mpifx_scatterv.fpp
   mpifx_send.fpp)
 
 set(sources-f90-preproc)
diff --git a/lib/make.deps b/lib/make.deps
index 34e8f62..6a29eb7 100644
--- a/lib/make.deps
+++ b/lib/make.deps
@@ -36,8 +36,8 @@ mpifx_constants.o: $$(_modobj_mpi)
 mpifx_constants.o = mpifx_constants.o $($(_modobj_mpi))
 _modobj_mpifx_constants_module = mpifx_constants.o
 
-module.o: $$(_modobj_mpifx_send_module) $$(_modobj_mpifx_scatter_module) $$(_modobj_mpifx_allgather_module) $$(_modobj_mpifx_allgatherv_module) $$(_modobj_mpifx_finalize_module) $$(_modobj_mpifx_barrier_module) $$(_modobj_mpifx_get_processor_name_module) $$(_modobj_mpifx_abort_module) $$(_modobj_mpifx_init_module) $$(_modobj_mpifx_constants_module) $$(_modobj_mpifx_recv_module) $$(_modobj_mpifx_bcast_module) $$(_modobj_mpifx_gather_module) $$(_modobj_mpifx_gatherv_module) $$(_modobj_mpifx_allreduce_module) $$(_modobj_mpifx_reduce_module) $$(_modobj_mpifx_comm_module)
-module.o = module.o $($(_modobj_mpifx_send_module)) $($(_modobj_mpifx_scatter_module)) $($(_modobj_mpifx_allgather_module)) $($(_modobj_mpifx_allgatherv_module)) $($(_modobj_mpifx_finalize_module)) $($(_modobj_mpifx_barrier_module)) $($(_modobj_mpifx_get_processor_name_module)) $($(_modobj_mpifx_abort_module)) $($(_modobj_mpifx_init_module)) $($(_modobj_mpifx_constants_module)) $($(_modobj_mpifx_recv_module)) $($(_modobj_mpifx_bcast_module)) $($(_modobj_mpifx_gather_module)) $($(_modobj_mpifx_gatherv_module)) $($(_modobj_mpifx_allreduce_module)) $($(_modobj_mpifx_reduce_module)) $($(_modobj_mpifx_comm_module))
+module.o: $$(_modobj_mpifx_send_module) $$(_modobj_mpifx_scatter_module) $$(_modobj_mpifx_scatterv_module) $$(_modobj_mpifx_allgather_module) $$(_modobj_mpifx_allgatherv_module) $$(_modobj_mpifx_finalize_module) $$(_modobj_mpifx_barrier_module) $$(_modobj_mpifx_get_processor_name_module) $$(_modobj_mpifx_abort_module) $$(_modobj_mpifx_init_module) $$(_modobj_mpifx_constants_module) $$(_modobj_mpifx_recv_module) $$(_modobj_mpifx_bcast_module) $$(_modobj_mpifx_gather_module) $$(_modobj_mpifx_gatherv_module) $$(_modobj_mpifx_allreduce_module) $$(_modobj_mpifx_reduce_module) $$(_modobj_mpifx_comm_module)
+module.o = module.o $($(_modobj_mpifx_send_module)) $($(_modobj_mpifx_scatter_module)) $($(_modobj_mpifx_scatterv_module)) $($(_modobj_mpifx_allgather_module)) $($(_modobj_mpifx_allgatherv_module)) $($(_modobj_mpifx_finalize_module)) $($(_modobj_mpifx_barrier_module)) $($(_modobj_mpifx_get_processor_name_module)) $($(_modobj_mpifx_abort_module)) $($(_modobj_mpifx_init_module)) $($(_modobj_mpifx_constants_module)) $($(_modobj_mpifx_recv_module)) $($(_modobj_mpifx_bcast_module)) $($(_modobj_mpifx_gather_module)) $($(_modobj_mpifx_gatherv_module)) $($(_modobj_mpifx_allreduce_module)) $($(_modobj_mpifx_reduce_module)) $($(_modobj_mpifx_comm_module))
 _modobj_libmpifx_module = module.o
 
 
@@ -69,6 +69,10 @@ mpifx_scatter.o: $$(_modobj_mpifx_common_module)
 mpifx_scatter.o = mpifx_scatter.o $($(_modobj_mpifx_common_module))
 _modobj_mpifx_scatter_module = mpifx_scatter.o
 
+mpifx_scatterv.o: $$(_modobj_mpifx_common_module)
+mpifx_scatterv.o = mpifx_scatterv.o $($(_modobj_mpifx_common_module))
+_modobj_mpifx_scatterv_module = mpifx_scatterv.o
+
 mpifx_abort.o: $$(_modobj_mpifx_common_module)
 mpifx_abort.o = mpifx_abort.o $($(_modobj_mpifx_common_module))
 _modobj_mpifx_abort_module = mpifx_abort.o
diff --git a/lib/module.fpp b/lib/module.fpp
index bdf5def..a121129 100644
--- a/lib/module.fpp
+++ b/lib/module.fpp
@@ -29,6 +29,7 @@ module libmpifx_module
   use mpifx_allgather_module
   use mpifx_allgatherv_module
   use mpifx_scatter_module
+  use mpifx_scatterv_module
   implicit none
   public
 
diff --git a/lib/mpifx_scatterv.fpp b/lib/mpifx_scatterv.fpp
new file mode 100644
index 0000000..ebe1a6e
--- /dev/null
+++ b/lib/mpifx_scatterv.fpp
@@ -0,0 +1,238 @@
+#:include 'mpifx.fypp'
+#:set TYPES = ALL_TYPES
+#:set RANKS = range(1, MAX_RANK + 1)
+
+!> Contains wrapper for \c MPI_SCATTERV
+module mpifx_scatterv_module
+  use mpifx_common_module
+  implicit none
+  private
+
+  public :: mpifx_scatterv
+
+  !> Scatters scalars/arrays of different lengths from a given node.
+  !!
+  !! \details All functions have the same argument list only differing in the
+  !! type and rank of the send and the receive buffer (second and fourth
+  !! argument). They can be of type integer (i), real (s), double precision
+  !! (d), complex (c), double complex (z) and logical (l). Their rank can vary
+  !! from zero (scalars) up to the maximum rank. Both buffers must be of the
+  !! same type. The receive buffer must have either the same rank as the send
+  !! buffer or one rank less. The third argument contains for every process
+  !! the number of items it receives; the receive buffer on each process must
+  !! hold exactly that many items. The send buffer, the counts and the
+  !! displacements are only evaluated on the root process.
+  !!
+  !! \see MPI documentation (\c MPI_SCATTERV)
+  !!
+  !! Example:
+  !!
+  !!     program test_scatterv
+  !!       use libmpifx_module
+  !!       implicit none
+  !!
+  !!       type(mpifx_comm) :: mycomm
+  !!       integer, allocatable :: send1(:)
+  !!       integer, allocatable :: recv1(:)
+  !!       integer, allocatable :: sendcounts(:)
+  !!       integer :: ii, nsend
+  !!
+  !!       call mpifx_init()
+  !!       call mycomm%init()
+  !!
+  !!       ! I1 -> I1
+  !!       allocate(recv1(mycomm%rank+1))
+  !!       recv1 = 0
+  !!       if (mycomm%master) then
+  !!         ! send1 size is 1+2+3+...+mycomm%size
+  !!         nsend = mycomm%size*(mycomm%size+1)/2
+  !!         allocate(send1(nsend))
+  !!         do ii = 1, nsend
+  !!           send1(ii) = ii
+  !!         end do
+  !!         allocate(sendcounts(mycomm%size))
+  !!         do ii = 1, mycomm%size
+  !!           sendcounts(ii) = ii
+  !!         end do
+  !!       else
+  !!         allocate(send1(0))
+  !!         allocate(sendcounts(0))
+  !!       end if
+  !!
+  !!       if (mycomm%master) then
+  !!         write(*, *) mycomm%rank, "Send1 buffer:", send1(:)
+  !!       end if
+  !!       call mpifx_scatterv(mycomm, send1, sendcounts, recv1)
+  !!       write(*, *) mycomm%rank, "Recv1 buffer:", recv1
+  !!
+  !!       call mpifx_finalize()
+  !!
+  !!     end program test_scatterv
+  !!
+  interface mpifx_scatterv
+#:for TYPE in TYPES
+  #:for RANK in RANKS
+    #:set TYPEABBREV = TYPE_ABBREVS[TYPE]
+    module procedure mpifx_scatterv_${TYPEABBREV}$${RANK}$${TYPEABBREV}$${RANK}$
+    module procedure mpifx_scatterv_${TYPEABBREV}$${RANK}$${TYPEABBREV}$${RANK - 1}$
+  #:endfor
+#:endfor
+  end interface mpifx_scatterv
+
+contains
+
+  !> Sets up the displacements to be handed over to \c MPI_SCATTERV.
+  !!
+  !! \param mycomm  MPI communicator.
+  !! \param root0  Rank of the root process.
+  !! \param sendcounts  Number of items sent to each process (only used on root).
+  !! \param displs0  Displacements on the root process, array of size zero elsewhere.
+  !! \param displs  Displacements provided by the caller
+  !!     (default: computed from sendcounts assuming order with rank)
+  !!
+  subroutine get_scatterv_displs(mycomm, root0, sendcounts, displs0, displs)
+    type(mpifx_comm), intent(in) :: mycomm
+    integer, intent(in) :: root0
+    integer, intent(in) :: sendcounts(:)
+    integer, allocatable, intent(out) :: displs0(:)
+    integer, intent(in), optional :: displs(:)
+
+    integer :: ii
+
+    if (mycomm%rank /= root0) then
+      ! MPI ignores the displacements on non-root processes, but an unallocated
+      ! array must not be passed on as actual argument.
+      allocate(displs0(0))
+      return
+    end if
+
+    @:ASSERT(size(sendcounts) == mycomm%size)
+    allocate(displs0(mycomm%size))
+    if (present(displs)) then
+      @:ASSERT(size(displs) == mycomm%size)
+      displs0(:) = displs
+    else
+      ! Segments are stored contiguously in the send buffer in rank order
+      displs0(1) = 0
+      do ii = 2, mycomm%size
+        displs0(ii) = displs0(ii-1) + sendcounts(ii-1)
+      end do
+    end if
+
+  end subroutine get_scatterv_displs
+
+
+#:def mpifx_scatterv_dr0_template(SUFFIX, TYPE, MPITYPE, RANK, HASLENGTH)
+
+  #:assert RANK > 0
+
+  !> Scatters object of variable length from one process (type ${SUFFIX}$).
+  !!
+  !! \param mycomm  MPI communicator.
+  !! \param send  Quantity to be sent for scattering (only used on root).
+  !! \param sendcounts  Number of items sent to each process (only used on root).
+  !! \param recv  Received data on each process.
+  !! \param displs  Entry i specifies where to take data to send to rank i
+  !!     (default: computed from sendcounts assuming order with rank)
+  !! \param root  Root process sending the data (default: mycomm%masterrank)
+  !! \param error  Error code on exit.
+  !!
+  subroutine mpifx_scatterv_${SUFFIX}$(mycomm, send, sendcounts, recv, displs, root, error)
+    type(mpifx_comm), intent(in) :: mycomm
+    ${TYPE}$, intent(in) :: send${RANKSUFFIX(RANK)}$
+    integer, intent(in) :: sendcounts(:)
+    ${TYPE}$, intent(out) :: recv${RANKSUFFIX(RANK)}$
+    integer, intent(in), optional :: displs(:)
+    integer, intent(in), optional :: root
+    integer, intent(out), optional :: error
+
+    integer :: root0, error0
+    integer, allocatable :: displs0(:)
+
+    #:set SIZE = 'size(recv)'
+    #:set COUNT = ('len(recv) * ' + SIZE if HASLENGTH else SIZE)
+    #:set SENDSIZE = ('len(send) * size(send)' if HASLENGTH else 'size(send)')
+
+    call getoptarg(mycomm%masterrank, root0, root)
+    call get_scatterv_displs(mycomm, root0, sendcounts, displs0, displs)
+    if (mycomm%rank == root0) then
+      ! Every segment to be sent must lie within the send buffer
+      @:ASSERT(all(displs0 >= 0 .and. displs0 + sendcounts <= ${SENDSIZE}$))
+    end if
+
+    call mpi_scatterv(send, sendcounts, displs0, ${MPITYPE}$, recv, ${COUNT}$, ${MPITYPE}$, root0,&
+        & mycomm%id, error0)
+    call handle_errorflag(error0, "MPI_SCATTERV in mpifx_scatterv_${SUFFIX}$", error)
+
+  end subroutine mpifx_scatterv_${SUFFIX}$
+
+#:enddef mpifx_scatterv_dr0_template
+
+
+#:def mpifx_scatterv_dr1_template(SUFFIX, TYPE, MPITYPE, RANK, HASLENGTH)
+
+  #:assert RANK > 0
+
+  !> Scatters object of variable length from one process (type ${SUFFIX}$).
+  !!
+  !! \param mycomm  MPI communicator.
+  !! \param send  Quantity to be sent for scattering (only used on root).
+  !! \param sendcounts  Number of items sent to each process (only used on root).
+  !! \param recv  Received data on each process.
+  !! \param displs  Entry i specifies where to take data to send to rank i
+  !!     (default: computed from sendcounts assuming order with rank)
+  !! \param root  Root process sending the data (default: mycomm%masterrank)
+  !! \param error  Error code on exit.
+  !!
+  subroutine mpifx_scatterv_${SUFFIX}$(mycomm, send, sendcounts, recv, displs, root, error)
+    type(mpifx_comm), intent(in) :: mycomm
+    ${TYPE}$, intent(in) :: send${RANKSUFFIX(RANK)}$
+    integer, intent(in) :: sendcounts(:)
+    ${TYPE}$, intent(out) :: recv${RANKSUFFIX(RANK - 1)}$
+    integer, intent(in), optional :: displs(:)
+    integer, intent(in), optional :: root
+    integer, intent(out), optional :: error
+
+    integer :: root0, error0
+    integer, allocatable :: displs0(:)
+
+    #:set SIZE = '1' if RANK == 1 else 'size(recv)'
+    #:set COUNT = ('len(recv) * ' + SIZE if HASLENGTH else SIZE)
+    #:set SENDSIZE = ('len(send) * size(send)' if HASLENGTH else 'size(send)')
+
+    call getoptarg(mycomm%masterrank, root0, root)
+    call get_scatterv_displs(mycomm, root0, sendcounts, displs0, displs)
+    if (mycomm%rank == root0) then
+      ! Every segment to be sent must lie within the send buffer
+      @:ASSERT(all(displs0 >= 0 .and. displs0 + sendcounts <= ${SENDSIZE}$))
+    #:if HASLENGTH
+      @:ASSERT(len(send) == len(recv))
+    #:endif
+    end if
+
+    call mpi_scatterv(send, sendcounts, displs0, ${MPITYPE}$, recv, ${COUNT}$, ${MPITYPE}$, root0,&
+        & mycomm%id, error0)
+    call handle_errorflag(error0, "MPI_SCATTERV in mpifx_scatterv_${SUFFIX}$", error)
+
+  end subroutine mpifx_scatterv_${SUFFIX}$
+
+#:enddef mpifx_scatterv_dr1_template
+
+
+#:for TYPE in TYPES
+  #:for RANK in RANKS
+
+    #:set FTYPE = FORTRAN_TYPES[TYPE]
+    #:set MPITYPE = MPI_TYPES[TYPE]
+    #:set HASLENGTH = HAS_LENGTH[TYPE]
+
+    #:set SUFFIX = TYPE_ABBREVS[TYPE] + str(RANK) + TYPE_ABBREVS[TYPE] + str(RANK)
+    $:mpifx_scatterv_dr0_template(SUFFIX, FTYPE, MPITYPE, RANK, HASLENGTH)
+
+    #:set SUFFIX = TYPE_ABBREVS[TYPE] + str(RANK) + TYPE_ABBREVS[TYPE] + str(RANK - 1)
+    $:mpifx_scatterv_dr1_template(SUFFIX, FTYPE, MPITYPE, RANK, HASLENGTH)
+
+  #:endfor
+#:endfor
+
+end module mpifx_scatterv_module
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 243a4f0..efaf04f 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -7,7 +7,8 @@ set(targets
   test_gather
   test_gatherv
   test_reduce
-  test_scatter)
+  test_scatter
+  test_scatterv)
 
 foreach(target IN LISTS targets)
   add_executable(${target} ${target}.f90)
diff --git a/test/make.build b/test/make.build
index b85ee87..4a0f8ef 100644
--- a/test/make.build
+++ b/test/make.build
@@ -19,7 +19,8 @@
 .SUFFIXES: .f90 .o
 
 TARGETS = test_bcast test_send_recv test_comm_split test_reduce \
-    test_allreduce test_gather test_allgather test_scatter
+    test_allreduce test_gather test_allgather test_scatter \
+    test_scatterv
 
 all: $(TARGETS)
 
@@ -71,3 +72,6 @@ test_allgather: $(test_allgather.o)
 
 test_scatter: $(test_scatter.o)
 	$(link-target)
+
+test_scatterv: $(test_scatterv.o)
+	$(link-target)
diff --git a/test/make.deps b/test/make.deps
index ffbb367..e2d11be 100644
--- a/test/make.deps
+++ b/test/make.deps
@@ -21,6 +21,9 @@ test_bcast.o = test_bcast.o $($(_modobj_libmpifx_module))
 
 test_scatter.o: $$(_modobj_libmpifx_module)
 test_scatter.o = test_scatter.o $($(_modobj_libmpifx_module))
 
+test_scatterv.o: $$(_modobj_libmpifx_module)
+test_scatterv.o = test_scatterv.o $($(_modobj_libmpifx_module))
+
 test_comm_split.o: $$(_modobj_libmpifx_module)
 test_comm_split.o = test_comm_split.o $($(_modobj_libmpifx_module))
 
diff --git a/test/test_scatterv.f90 b/test/test_scatterv.f90
new file mode 100644
index 0000000..53b08c8
--- /dev/null
+++ b/test/test_scatterv.f90
@@ -0,0 +1,101 @@
+program test_scatterv
+  use libmpifx_module
+  implicit none
+
+  type(mpifx_comm) :: mycomm
+  integer, allocatable :: send1(:), send2(:,:)
+  integer :: recv0
+  integer, allocatable :: recv1(:), sendcount(:), displs(:)
+  character(100) :: formstr
+  character(*), parameter :: label = "(I2.2,'-',I3.3,'|',1X"
+  integer :: ii
+
+  call mpifx_init()
+  call mycomm%init()
+
+  ! I1 -> I0
+  if (mycomm%master) then
+    allocate(send1(mycomm%size))
+    allocate(sendcount(mycomm%size))
+    send1(:) = [ (ii, ii = 1, size(send1)) ]
+    sendcount(:) = 1
+    write(formstr, "(A,I0,A)") "A,", size(send1), "(1X,I0))"
+    write(*, label // formstr) 1, mycomm%rank, "Send1 buffer:", send1
+  else
+    allocate(send1(0))
+    allocate(sendcount(0))
+  end if
+  recv0 = 0
+  call mpifx_scatterv(mycomm, send1, sendcount, recv0)
+  write(formstr, "(A,I0,A)") "A,", 1, "(1X,I0))"
+  write(*, label // formstr) 2, mycomm%rank, "Recv0 buffer:", recv0
+
+  ! I1 -> I1
+  if (mycomm%master) then
+    deallocate(send1)
+    allocate(send1(2 * mycomm%size))
+    sendcount(:) = 2
+    send1(:) = [ (ii, ii = 1, size(send1)) ]
+    write(formstr, "(A,I0,A)") "A,", size(send1), "(1X,I0))"
+    write(*, label // formstr) 3, mycomm%rank, "Send1 buffer:", send1
+  end if
+  allocate(recv1(2))
+  recv1(:) = 0
+  call mpifx_scatterv(mycomm, send1, sendcount, recv1)
+  write(formstr, "(A,I0,A)") "A,", size(recv1), "(1X,I0))"
+  write(*, label // formstr) 4, mycomm%rank, "Recv1 buffer:", recv1
+
+  ! I2 -> I1
+  if (mycomm%master) then
+    allocate(send2(2, mycomm%size))
+    sendcount(:) = 2
+    send2(:,:) = reshape(send1, [ 2, mycomm%size ])
+    write(formstr, "(A,I0,A)") "A,", size(send2), "(1X,I0))"
+    write(*, label // formstr) 5, mycomm%rank, &
+        & "Send2 buffer:", send2
+  else
+    allocate(send2(0,0))
+  end if
+  recv1(:) = 0
+  call mpifx_scatterv(mycomm, send2, sendcount, recv1)
+  write(formstr, "(A,I0,A)") "A,", size(recv1), "(1X,I0))"
+  write(*, label // formstr) 6, mycomm%rank, &
+      & "Recv1 buffer:", recv1
+
+  ! I1 -> I1 with explicit (zero based) displacements: every second item is sent
+  if (mycomm%master) then
+    deallocate(send1)
+    allocate(send1(2 * mycomm%size))
+    send1(:) = [ (ii, ii = 1, size(send1)) ]
+    sendcount(:) = 1
+    allocate(displs(mycomm%size))
+    displs(:) = [ (ii, ii = 1, size(send1), 2) ]
+    write(formstr, "(A,I0,A)") "A,", size(send1), "(1X,I0))"
+    write(*, label // formstr) 7, mycomm%rank, "Send1 buffer:", send1
+  end if
+  deallocate(recv1)
+  allocate(recv1(1))
+  recv1(:) = 0
+  call mpifx_scatterv(mycomm, send1, sendcount, recv1, displs=displs)
+  write(formstr, "(A,I0,A)") "A,", size(recv1), "(1X,I0))"
+  write(*, label // formstr) 8, mycomm%rank, "Recv1 buffer:", recv1
+
+  ! I1 -> I1 with rank dependent counts: rank i receives i + 1 items
+  if (mycomm%master) then
+    deallocate(send1)
+    allocate(send1(mycomm%size * (mycomm%size + 1) / 2))
+    send1(:) = [ (ii, ii = 1, size(send1)) ]
+    sendcount(:) = [ (ii, ii = 1, mycomm%size) ]
+    write(formstr, "(A,I0,A)") "A,", size(send1), "(1X,I0))"
+    write(*, label // formstr) 9, mycomm%rank, "Send1 buffer:", send1
+  end if
+  deallocate(recv1)
+  allocate(recv1(mycomm%rank + 1))
+  recv1(:) = 0
+  call mpifx_scatterv(mycomm, send1, sendcount, recv1)
+  write(formstr, "(A,I0,A)") "A,", size(recv1), "(1X,I0))"
+  write(*, label // formstr) 10, mycomm%rank, "Recv1 buffer:", recv1
+
+  call mpifx_finalize()
+
+end program test_scatterv