Skip to content

Commit

Permalink
Added option to allocate shared memory on each rank
Browse files Browse the repository at this point in the history
  • Loading branch information
terminationshock committed Feb 17, 2023
1 parent da51073 commit 7fd9ac4
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 27 deletions.
1 change: 1 addition & 0 deletions lib/mpifx_constants.fpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module mpifx_constants_module
public :: MPI_MAX, MPI_MIN, MPI_SUM, MPI_PROD
public :: MPI_LAND, MPI_BAND, MPI_LOR, MPI_BOR, MPI_LXOR ,MPI_BXOR
public :: MPI_MAXLOC, MPI_MINLOC
public :: MPI_MODE_NOSTORE, MPI_MODE_NOPUT, MPI_MODE_NOPRECEDE, MPI_MODE_NOSUCCEED
public :: MPI_THREAD_SINGLE, MPI_THREAD_FUNNELED, MPI_THREAD_SERIALIZED, MPI_THREAD_MULTIPLE
public :: MPI_COMM_TYPE_SHARED
public :: MPIFX_UNHANDLED_ERROR, MPIFX_ASSERT_FAILED
Expand Down
85 changes: 64 additions & 21 deletions lib/mpifx_win.fpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@ module mpifx_win_module
procedure, private :: mpifx_win_allocate_shared_${TYPE_ABBREVS[TYPE]}$
#:endfor

!> Locks a shared memory segment.
!> Locks a shared memory segment for remote access.
procedure :: lock => mpifx_win_lock

!> Unlocks a shared memory segment.
procedure :: unlock => mpifx_win_unlock

!> Synchronizes shared memory across MPI ranks.
!> Synchronizes shared memory across MPI ranks after remote access.
procedure :: sync => mpifx_win_sync

!> Ensures consistency of stores between fence calls.
procedure :: fence => mpifx_win_fence

!> Deallocates memory associated with a shared memory segment.
procedure :: free => mpifx_win_free

Expand All @@ -47,44 +50,58 @@ contains
!!
!! \param self Handle of the shared memory window on return.
!! \param mycomm MPI communicator.
!! \param length Number of elements of type ${TYPE}$ in the shared memory window.
!! \param shared_data Pointer to the shared data array of length 'length' on return.
!! \param global_length Number of elements of type ${TYPE}$ in the entire shared memory window.
!! \param global_pointer Pointer to the shared data array of length 'global_length' on return.
!! \param local_length Number of elements of type ${TYPE}$ occupied by the current rank.
!! \param local_pointer Pointer to the local chunk of the data array of length 'local_length' on return.
!! \param error Optional error code on return.
!!
!! \see MPI documentation (\c MPI_WIN_ALLOCATE_SHARED)
!!
subroutine mpifx_win_allocate_shared_${SUFFIX}$(self, mycomm, length, shared_data, error)
subroutine mpifx_win_allocate_shared_${SUFFIX}$(self, mycomm, global_length, global_pointer,&
& local_length, local_pointer, error)
class(mpifx_win), intent(out) :: self
class(mpifx_comm), intent(in) :: mycomm
integer, intent(in) :: length
${TYPE}$, pointer, intent(out) :: shared_data(:)
integer, intent(in) :: global_length
${TYPE}$, pointer, intent(out) :: global_pointer(:)
integer, intent(in), optional :: local_length
${TYPE}$, pointer, intent(out), optional :: local_pointer(:)
integer, intent(out), optional :: error

integer :: disp_unit, error0, error1
integer(MPI_ADDRESS_KIND) :: local_length
type(c_ptr) :: baseptr
integer(MPI_ADDRESS_KIND) :: global_mem_size, local_mem_size
type(c_ptr) :: global_baseptr, local_baseptr

disp_unit = storage_size(shared_data) / 8
disp_unit = storage_size(global_pointer) / 8

local_length = 0
if (mycomm%lead) then
local_length = int(length, kind=MPI_ADDRESS_KIND) * disp_unit
local_mem_size = 0
if (present(local_length)) then
local_mem_size = int(local_length, kind=MPI_ADDRESS_KIND) * disp_unit
else if (mycomm%lead) then
local_mem_size = int(global_length, kind=MPI_ADDRESS_KIND) * disp_unit
end if

call mpi_win_allocate_shared(local_length, disp_unit, MPI_INFO_NULL, mycomm%id, baseptr, self%id, error0)
call handle_errorflag(error0, "MPI_WIN_ALLOCATE_SHARED in mpifx_win_allocate_shared_${SUFFIX}$", error)
call mpi_win_allocate_shared(local_mem_size, disp_unit, MPI_INFO_NULL, mycomm%id, local_baseptr,&
& self%id, error0)
call handle_errorflag(error0, "MPI_WIN_ALLOCATE_SHARED in mpifx_win_allocate_shared_${SUFFIX}$",&
& error)

call mpi_win_shared_query(self%id, 0, local_length, disp_unit, baseptr, error1)
call handle_errorflag(error1, "MPI_WIN_SHARED_QUERY in mpifx_win_allocate_shared_${SUFFIX}$", error)
call mpi_win_shared_query(self%id, mycomm%leadrank, global_mem_size, disp_unit, global_baseptr,&
& error1)
call handle_errorflag(error1, "MPI_WIN_SHARED_QUERY in mpifx_win_allocate_shared_${SUFFIX}$",&
& error)

self%comm_id = mycomm%id
call c_f_pointer(baseptr, shared_data, [length])
call c_f_pointer(global_baseptr, global_pointer, [global_length])
if (present(local_pointer)) then
call c_f_pointer(local_baseptr, local_pointer, [local_length])
end if

end subroutine mpifx_win_allocate_shared_${SUFFIX}$

#:enddef mpifx_win_allocate_shared_template

!> Locks a shared memory segment.
!> Locks a shared memory segment for remote access. Starts a remote access epoch.
!!
!! \param self Handle of the shared memory window.
!! \param error Optional error code on return.
Expand All @@ -102,7 +119,7 @@ contains

end subroutine mpifx_win_lock

!> Unlocks a shared memory segment.
!> Unlocks a shared memory segment. Finishes a remote access epoch.
!!
!! \param self Handle of the shared memory window.
!! \param error Optional error code on return.
Expand All @@ -120,7 +137,8 @@ contains

end subroutine mpifx_win_unlock

!> Synchronizes shared memory across MPI ranks.
!> Synchronizes shared memory across MPI ranks after remote access.
!> Completes all memory stores in a remote access epoch.
!!
!! \param self Handle of the shared memory window.
!! \param error Optional error code on return.
Expand All @@ -141,6 +159,31 @@ contains

end subroutine mpifx_win_sync

!> Ensure consistency of stores between fence calls
!!
!! \param self Handle of the shared memory window.
!! \param assert Hint to the MPI library to assume certain condition (e.g., MPI_MODE_NOSTORE).
!! \param error Optional error code on return.
!!
!! \see MPI documentation (\c MPI_WIN_FENCE)
!!
subroutine mpifx_win_fence(self, assert, error)
class(mpifx_win), intent(inout) :: self
integer, intent(in), optional :: assert
integer, intent(out), optional :: error

integer :: error0, assert_

assert_ = 0
if (present(assert)) then
assert_ = assert
end if

call mpi_win_fence(assert_, self%id, error0)
call handle_errorflag(error0, "MPI_WIN_FENCE in mpifx_win_fence", error)

end subroutine mpifx_win_fence

!> Deallocates memory associated with a shared memory segment.
!!
!! \param self Handle of the shared memory window.
Expand Down
61 changes: 55 additions & 6 deletions test/test_win_shared_mem.f90
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,78 @@ program test_win_shared_mem

type(mpifx_comm) :: globalcomm, nodecomm
type(mpifx_win) :: win
integer, parameter :: length = 7
integer, pointer :: data_pointer(:)
integer, parameter :: sample_value = 42, size_rank_0 = 7, size_rank_other = 4
integer :: global_length, local_length, rank, ii
integer, pointer :: global_pointer(:), local_pointer(:)

call mpifx_init()
call globalcomm%init()

! Create a new communicator for all ranks on a node first
call globalcomm%split_type(MPI_COMM_TYPE_SHARED, globalcomm%rank, nodecomm)

call win%allocate_shared(nodecomm, length, data_pointer)
if (nodecomm%lead) then
local_length = size_rank_0
else
local_length = size_rank_other
end if
global_length = size_rank_0 + size_rank_other * (nodecomm%size - 1)

call win%allocate_shared(nodecomm, global_length, global_pointer)

call win%lock()

! Only rank 0 writes data into the array
if (nodecomm%lead) then
data_pointer(:) = 42
global_pointer(:) = sample_value
end if

call win%sync()
call win%unlock()

! All ranks on the node will read the same value
write(*, "(2(A,1X,I0,1X))") "ID:", nodecomm%rank, "VALUE:", data_pointer(1)
! All ranks on the node will read the same value in the global array view
if (any(global_pointer(1:global_length) /= sample_value)) then
write(*, "(3(A,1X,I0,1X))") "ERROR! ID:", nodecomm%rank, "VALUE:", global_pointer(1), "EXPECTED:", sample_value
call mpifx_abort(globalcomm)
end if

call win%free()

! Initialize again with specific local length
call win%allocate_shared(nodecomm, global_length, global_pointer, local_length, local_pointer)

call win%fence(MPI_MODE_NOSTORE + MPI_MODE_NOPRECEDE)

! Only rank 0 writes data into the array
if (nodecomm%lead) then
global_pointer(:) = sample_value
end if

call win%fence()

! All ranks on the node will read the same value in their local view
if (any(local_pointer(1:local_length) /= sample_value)) then
write(*, "(2(A,1X,I0,1X))") "ERROR! ID:", nodecomm%rank, "VALUE:", local_pointer(1), "EXPECTED:", sample_value
call mpifx_abort(globalcomm)
end if

! Now let all ranks write something into their local chunk
local_pointer(1:local_length) = nodecomm%rank

call win%fence()

! All ranks should now read the correct global values
if (any(global_pointer(1:size_rank_0) /= 0)) then
write(*, "(2(A,1X,I0,1X))") "ERROR! ID:", nodecomm%rank, "VALUE:", global_pointer(1), "EXPECTED:", 0
call mpifx_abort(globalcomm)
end if
do rank = 1, nodecomm%size - 1
ii = size_rank_0 + 1 + size_rank_other * (rank - 1)
if (any(global_pointer(ii:ii+size_rank_other-1) /= rank)) then
write(*, "(2(A,1X,I0,1X))") "ERROR! ID:", nodecomm%rank, "VALUE:", global_pointer(ii), "EXPECTED:", rank
call mpifx_abort(globalcomm)
end if
end do

call win%free()
call mpifx_finalize()
Expand Down

0 comments on commit 7fd9ac4

Please sign in to comment.