diff --git a/lib/mpifx_constants.fpp b/lib/mpifx_constants.fpp index 2a0a22f..2266e01 100644 --- a/lib/mpifx_constants.fpp +++ b/lib/mpifx_constants.fpp @@ -7,6 +7,7 @@ module mpifx_constants_module public :: MPI_MAX, MPI_MIN, MPI_SUM, MPI_PROD public :: MPI_LAND, MPI_BAND, MPI_LOR, MPI_BOR, MPI_LXOR ,MPI_BXOR public :: MPI_MAXLOC, MPI_MINLOC + public :: MPI_MODE_NOSTORE, MPI_MODE_NOPUT, MPI_MODE_NOPRECEDE, MPI_MODE_NOSUCCEED public :: MPI_THREAD_SINGLE, MPI_THREAD_FUNNELED, MPI_THREAD_SERIALIZED, MPI_THREAD_MULTIPLE public :: MPI_COMM_TYPE_SHARED public :: MPIFX_UNHANDLED_ERROR, MPIFX_ASSERT_FAILED diff --git a/lib/mpifx_win.fpp b/lib/mpifx_win.fpp index 1d96892..c9f4bc9 100644 --- a/lib/mpifx_win.fpp +++ b/lib/mpifx_win.fpp @@ -25,15 +25,18 @@ module mpifx_win_module procedure, private :: mpifx_win_allocate_shared_${TYPE_ABBREVS[TYPE]}$ #:endfor - !> Locks a shared memory segment. + !> Locks a shared memory segment for remote access. procedure :: lock => mpifx_win_lock !> Unlocks a shared memory segment. procedure :: unlock => mpifx_win_unlock - !> Synchronizes shared memory across MPI ranks. + !> Synchronizes shared memory across MPI ranks after remote access. procedure :: sync => mpifx_win_sync + !> Ensures consistency of stores between fence calls. + procedure :: fence => mpifx_win_fence + !> Deallocates memory associated with a shared memory segment. procedure :: free => mpifx_win_free @@ -47,44 +50,58 @@ contains !! !! \param self Handle of the shared memory window on return. !! \param mycomm MPI communicator. - !! \param length Number of elements of type ${TYPE}$ in the shared memory window. - !! \param shared_data Pointer to the shared data array of length 'length' on return. + !! \param global_length Number of elements of type ${TYPE}$ in the entire shared memory window. + !! \param global_pointer Pointer to the shared data array of length 'global_length' on return. + !! \param local_length Number of elements of type ${TYPE}$ occupied by the current rank. + !! \param local_pointer Pointer to the local chunk of the data array of length 'local_length' on return. !! \param error Optional error code on return. !! !! \see MPI documentation (\c MPI_WIN_ALLOCATE_SHARED) !! - subroutine mpifx_win_allocate_shared_${SUFFIX}$(self, mycomm, length, shared_data, error) + subroutine mpifx_win_allocate_shared_${SUFFIX}$(self, mycomm, global_length, global_pointer,& + & local_length, local_pointer, error) class(mpifx_win), intent(out) :: self class(mpifx_comm), intent(in) :: mycomm - integer, intent(in) :: length - ${TYPE}$, pointer, intent(out) :: shared_data(:) + integer, intent(in) :: global_length + ${TYPE}$, pointer, intent(out) :: global_pointer(:) + integer, intent(in), optional :: local_length + ${TYPE}$, pointer, intent(out), optional :: local_pointer(:) integer, intent(out), optional :: error integer :: disp_unit, error0, error1 - integer(MPI_ADDRESS_KIND) :: local_length - type(c_ptr) :: baseptr + integer(MPI_ADDRESS_KIND) :: global_mem_size, local_mem_size + type(c_ptr) :: global_baseptr, local_baseptr - disp_unit = storage_size(shared_data) / 8 + disp_unit = storage_size(global_pointer) / 8 - local_length = 0 - if (mycomm%lead) then - local_length = int(length, kind=MPI_ADDRESS_KIND) * disp_unit + local_mem_size = 0 + if (present(local_length)) then + local_mem_size = int(local_length, kind=MPI_ADDRESS_KIND) * disp_unit + else if (mycomm%lead) then + local_mem_size = int(global_length, kind=MPI_ADDRESS_KIND) * disp_unit end if - call mpi_win_allocate_shared(local_length, disp_unit, MPI_INFO_NULL, mycomm%id, baseptr, self%id, error0) - call handle_errorflag(error0, "MPI_WIN_ALLOCATE_SHARED in mpifx_win_allocate_shared_${SUFFIX}$", error) + call mpi_win_allocate_shared(local_mem_size, disp_unit, MPI_INFO_NULL, mycomm%id, local_baseptr,& + & self%id, error0) + call handle_errorflag(error0, "MPI_WIN_ALLOCATE_SHARED in mpifx_win_allocate_shared_${SUFFIX}$",& + & error) - call mpi_win_shared_query(self%id, 0, local_length, disp_unit, baseptr, error1) - call handle_errorflag(error1, "MPI_WIN_SHARED_QUERY in mpifx_win_allocate_shared_${SUFFIX}$", error) + call mpi_win_shared_query(self%id, mycomm%leadrank, global_mem_size, disp_unit, global_baseptr,& + & error1) + call handle_errorflag(error1, "MPI_WIN_SHARED_QUERY in mpifx_win_allocate_shared_${SUFFIX}$",& + & error) self%comm_id = mycomm%id - call c_f_pointer(baseptr, shared_data, [length]) + call c_f_pointer(global_baseptr, global_pointer, [global_length]) + if (present(local_pointer)) then + call c_f_pointer(local_baseptr, local_pointer, [local_length]) + end if end subroutine mpifx_win_allocate_shared_${SUFFIX}$ #:enddef mpifx_win_allocate_shared_template - !> Locks a shared memory segment. + !> Locks a shared memory segment for remote access. Starts a remote access epoch. !! !! \param self Handle of the shared memory window. !! \param error Optional error code on return. @@ -102,7 +119,7 @@ contains end subroutine mpifx_win_lock - !> Unlocks a shared memory segment. + !> Unlocks a shared memory segment. Finishes a remote access epoch. !! !! \param self Handle of the shared memory window. !! \param error Optional error code on return. @@ -120,7 +137,8 @@ contains end subroutine mpifx_win_unlock - !> Synchronizes shared memory across MPI ranks. + !> Synchronizes shared memory across MPI ranks after remote access. + !> Completes all memory stores in a remote access epoch. !! !! \param self Handle of the shared memory window. !! \param error Optional error code on return. @@ -141,6 +159,31 @@ contains end subroutine mpifx_win_sync + !> Ensure consistency of stores between fence calls + !! + !! \param self Handle of the shared memory window. + !! \param assert Hint to the MPI library to assume certain condition (e.g., MPI_MODE_NOSTORE). + !! \param error Optional error code on return. + !! + !! \see MPI documentation (\c MPI_WIN_FENCE) + !! + subroutine mpifx_win_fence(self, assert, error) + class(mpifx_win), intent(inout) :: self + integer, intent(in), optional :: assert + integer, intent(out), optional :: error + + integer :: error0, assert_ + + assert_ = 0 + if (present(assert)) then + assert_ = assert + end if + + call mpi_win_fence(assert_, self%id, error0) + call handle_errorflag(error0, "MPI_WIN_FENCE in mpifx_win_fence", error) + + end subroutine mpifx_win_fence + !> Deallocates memory associated with a shared memory segment. !! !! \param self Handle of the shared memory window. diff --git a/test/test_win_shared_mem.f90 b/test/test_win_shared_mem.f90 index 048fda8..fbad591 100644 --- a/test/test_win_shared_mem.f90 +++ b/test/test_win_shared_mem.f90 @@ -4,8 +4,9 @@ program test_win_shared_mem type(mpifx_comm) :: globalcomm, nodecomm type(mpifx_win) :: win - integer, parameter :: length = 7 - integer, pointer :: data_pointer(:) + integer, parameter :: sample_value = 42, size_rank_0 = 7, size_rank_other = 4 + integer :: global_length, local_length, rank, ii + integer, pointer :: global_pointer(:), local_pointer(:) call mpifx_init() call globalcomm%init() @@ -13,20 +14,68 @@ program test_win_shared_mem ! Create a new communicator for all ranks on a node first call globalcomm%split_type(MPI_COMM_TYPE_SHARED, globalcomm%rank, nodecomm) - call win%allocate_shared(nodecomm, length, data_pointer) + if (nodecomm%lead) then + local_length = size_rank_0 + else + local_length = size_rank_other + end if + global_length = size_rank_0 + size_rank_other * (nodecomm%size - 1) + + call win%allocate_shared(nodecomm, global_length, global_pointer) call win%lock() ! Only rank 0 writes data into the array if (nodecomm%lead) then - data_pointer(:) = 42 + global_pointer(:) = sample_value end if call win%sync() call win%unlock() - ! All ranks on the node will read the same value - write(*, "(2(A,1X,I0,1X))") "ID:", nodecomm%rank, "VALUE:", data_pointer(1) + ! All ranks on the node will read the same value in the global array view + if (any(global_pointer(1:global_length) /= sample_value)) then + write(*, "(3(A,1X,I0,1X))") "ERROR! ID:", nodecomm%rank, "VALUE:", global_pointer(1), "EXPECTED:", sample_value + call mpifx_abort(globalcomm) + end if + + call win%free() + + ! Initialize again with specific local length + call win%allocate_shared(nodecomm, global_length, global_pointer, local_length, local_pointer) + + call win%fence(MPI_MODE_NOSTORE + MPI_MODE_NOPRECEDE) + + ! Only rank 0 writes data into the array + if (nodecomm%lead) then + global_pointer(:) = sample_value + end if + + call win%fence() + + ! All ranks on the node will read the same value in their local view + if (any(local_pointer(1:local_length) /= sample_value)) then + write(*, "(2(A,1X,I0,1X))") "ERROR! ID:", nodecomm%rank, "VALUE:", local_pointer(1), "EXPECTED:", sample_value + call mpifx_abort(globalcomm) + end if + + ! Now let all ranks write something into their local chunk + local_pointer(1:local_length) = nodecomm%rank + + call win%fence() + + ! All ranks should now read the correct global values + if (any(global_pointer(1:size_rank_0) /= 0)) then + write(*, "(2(A,1X,I0,1X))") "ERROR! ID:", nodecomm%rank, "VALUE:", global_pointer(1), "EXPECTED:", 0 + call mpifx_abort(globalcomm) + end if + do rank = 1, nodecomm%size - 1 + ii = size_rank_0 + 1 + size_rank_other * (rank - 1) + if (any(global_pointer(ii:ii+size_rank_other-1) /= rank)) then + write(*, "(2(A,1X,I0,1X))") "ERROR! ID:", nodecomm%rank, "VALUE:", global_pointer(ii), "EXPECTED:", rank + call mpifx_abort(globalcomm) + end if + end do call win%free() call mpifx_finalize()