diff --git a/lib/mpifx_comm.fpp b/lib/mpifx_comm.fpp index ac30042..57ab8d9 100644 --- a/lib/mpifx_comm.fpp +++ b/lib/mpifx_comm.fpp @@ -1,6 +1,6 @@ !> Contains the extended MPI communicator. module mpifx_comm_module - use mpi + use mpi_f08 use mpifx_helper_module, only : getoptarg, handle_errorflag implicit none private @@ -9,14 +9,19 @@ module mpifx_comm_module !> MPI communicator with some additional information. type mpifx_comm - integer :: id !< Communicator id. - integer :: size !< Nr. of processes (size). - integer :: rank !< Rank of the current process. - integer :: leadrank !< Index of the lead node. - logical :: lead !< True if current process is the lead (rank == 0). + integer :: id !< Communicator id. + type(mpi_comm) :: comm !< MPI communicator handle. + integer :: size !< Nr. of processes (size). + integer :: rank !< Rank of the current process. + integer :: leadrank !< Index of the lead node. + logical :: lead !< True if current process is the lead (rank == 0). contains + !> Initializes the MPI environment. - procedure :: init => mpifx_comm_init + procedure, private :: mpifx_comm_from_id + procedure, private :: mpifx_comm_from_type + + generic :: init => mpifx_comm_from_id, mpifx_comm_from_type !> Creates a new communicator by splitting the old one. procedure :: split => mpifx_comm_split @@ -31,27 +36,32 @@ module mpifx_comm_module contains - !> Initializes a communicator to contain all processes. + !> Initializes a communicator from mpi_comm type !! !! \param self Initialized instance on exit. - !! \param commid MPI Communicator ID (default: \c MPI_COMM_WORLD) - !! \param error Error flag on return containing the first error occuring + !! \param comm MPI Communicator (default: \c MPI_COMM_WORLD) + !! \param error Error flag on return containing the first error occurring !! during the calls mpi_comm_size and mpi_comm_rank. !! - subroutine mpifx_comm_init(self, commid, error) + subroutine mpifx_comm_from_type(self, comm, error) class(mpifx_comm), intent(out) :: self - integer, intent(in), optional :: commid + type(mpi_comm), intent(in), optional :: comm integer, intent(out), optional :: error integer :: error0 - call getoptarg(MPI_COMM_WORLD, self%id, commid) - call mpi_comm_size(self%id, self%size, error0) + if (present(comm)) then + self%comm = comm + else + self%comm = MPI_COMM_WORLD + end if + self%id = self%comm%mpi_val + call mpi_comm_size(self%comm, self%size, error0) call handle_errorflag(error0, "mpi_comm_size() in mpifx_comm_init()", error) if (error0 /= 0) then return end if - call mpi_comm_rank(self%id, self%rank, error0) + call mpi_comm_rank(self%comm, self%rank, error0) call handle_errorflag(error0, "mpi_comm_rank() in mpifx_comm_init()", error) if (error0 /= 0) then return @@ -59,7 +69,27 @@ contains self%leadrank = 0 self%lead = (self%rank == self%leadrank) - end subroutine mpifx_comm_init + end subroutine mpifx_comm_from_type + + + !> Initializes a communicator from a numerical id. + !! + !! \param self Initialized instance on exit. + !! \param commid Numerical MPI Communicator ID + !! \param error Error flag on return containing the first error occurring + !! during the calls mpi_comm_size and mpi_comm_rank. + !! + subroutine mpifx_comm_from_id(self, commid, error) + class(mpifx_comm), intent(out) :: self + integer, intent(in) :: commid + integer, intent(out), optional :: error + + type(mpi_comm) :: newcomm + + newcomm%mpi_val = commid + call self%mpifx_comm_from_type(newcomm, error) + + end subroutine mpifx_comm_from_id !> Creates a new communicators by splitting the old one. @@ -102,14 +132,15 @@ contains class(mpifx_comm), intent(out) :: newcomm integer, intent(out), optional :: error - integer :: error0, newcommid + integer :: error0 + type(mpi_comm) :: newmpicomm - call mpi_comm_split(self%id, splitkey, rankkey, newcommid, error0) + call mpi_comm_split(self%comm, splitkey, rankkey, newmpicomm, error0) call handle_errorflag(error0, "mpi_comm_split() in mpifx_comm_split()", error) if (error0 /= 0) then return end if - call newcomm%init(newcommid, error) + call newcomm%init(newmpicomm, error) end subroutine mpifx_comm_split @@ -150,14 +181,15 @@ contains class(mpifx_comm), intent(out) :: newcomm integer, intent(out), optional :: error - integer :: error0, newcommid + integer :: error0 + type(mpi_comm) :: newmpicomm - call mpi_comm_split_type(self%id, splittype, rankkey, MPI_INFO_NULL, newcommid, error0) + call mpi_comm_split_type(self%comm, splittype, rankkey, MPI_INFO_NULL, newmpicomm, error0) call handle_errorflag(error0, "mpi_comm_split_type() in mpifx_comm_split_type()", error) if (error0 /= 0) then return end if - call newcomm%init(newcommid, error) + call newcomm%init(newmpicomm, error) end subroutine mpifx_comm_split_type @@ -173,7 +205,8 @@ contains integer :: error - call mpi_comm_free(self%id, error) + call mpi_comm_free(self%comm, error) + self%id = self%comm%mpi_val end subroutine mpifx_comm_free