[CudaIpc 2/3]: Ipc handle exchange #3910

Open · wants to merge 4 commits into base: add_backend_type_to_p2p_comm
Conversation

samnordmann (Collaborator) commented Feb 17, 2025

On top of

Prerequisite to:

What

  • Set up the infrastructure needed for ipc handle exchange and caching
  • Add an Expr node hir::ShareMemHandles to represent this op. We cannot embed the op in the Send/Recv semantics because the handle exchange must be grouped across matching sends and recvs to avoid deadlocks.

How

Most of the implementation is in multidevice/ipc_handle.cpp

  • Define the class IpcHandle, representing the IPC handle that is exchanged. This class is supplemented with a semaphore, which is a local CUDA buffer allocated on the exporter's device.
  • Define IpcHandleCache, which handles exchanging and caching the IPC handles. Caching is done with respect to a combination of runtime and symbolic ingredients: (runtime peer, at::Tensor, Expr*). This caching allows an arbitrary number of p2p comms between pairs of ranks.


github-actions bot commented Feb 17, 2025

Review updated until commit c047576

Description

  • Added ShareMemHandles class for handling shared memory IPC.

  • Implemented IpcHandle and IpcHandleCache for CUDA IPC memory management.

  • Updated HostIrEvaluator to handle ShareMemHandles.

  • Included IpcHandle in CMakeLists.txt for compilation.


Changes walkthrough 📝

Relevant files (Enhancement, 9 files):

| File | Change | Diff |
| --- | --- | --- |
| executor.cpp | Added handler for ShareMemHandles | +5/-0 |
| host_ir.cpp | Introduced ShareMemHandles class | +28/-0 |
| ipc_handle.cpp | Implemented IPC handle management | +150/-0 |
| dispatch.h | Added ShareMemHandles to dispatch macros | +2/-1 |
| executor.h | Added ShareMemHandles handler declaration | +3/-0 |
| host_ir.h | Added ShareMemHandles class declaration | +25/-0 |
| communicator.h | Added TCP store access method | +4/-0 |
| ipc_handle.h | Defined IPC handle classes and cache | +163/-0 |
| CMakeLists.txt | Added ipc_handle.cpp to build | +1/-0 |

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 No relevant tests
⚡ Recommended focus areas for review

Error Handling

The code does not handle potential errors from CUDA API calls, such as cudaIpcGetMemHandle and cudaMalloc. It would be beneficial to add error handling to ensure that the program can gracefully handle failures.

NVFUSER_CUDA_RT_SAFE_CALL(
    cudaIpcGetMemHandle(&ipc_handle_, tensor.data_ptr()));
NVFUSER_CUDA_RT_SAFE_CALL(
    cudaMalloc((void**)&semaphore_, sizeof(IpcSemaphore)));
static_assert(
    sizeof(IpcSemaphore) == sizeof(int),
    "IpcSemaphore must be same size as int");
NVFUSER_CUDA_RT_SAFE_CALL(cudaMemset(
    (void*)semaphore_, (int)IpcSemaphore::kReady, sizeof(IpcSemaphore)));
NVFUSER_CUDA_RT_SAFE_CALL(
    cudaIpcGetMemHandle(&semaphore_ipc_handle_, semaphore_));
Memory Management

The code uses cudaMalloc and cudaFree for semaphore memory allocation and deallocation. It is crucial to ensure that all allocated memory is properly freed to avoid memory leaks.

      cudaMalloc((void**)&semaphore_, sizeof(IpcSemaphore)));
  static_assert(
      sizeof(IpcSemaphore) == sizeof(int),
      "IpcSemaphore must be same size as int");
  NVFUSER_CUDA_RT_SAFE_CALL(cudaMemset(
      (void*)semaphore_, (int)IpcSemaphore::kReady, sizeof(IpcSemaphore)));
  NVFUSER_CUDA_RT_SAFE_CALL(
      cudaIpcGetMemHandle(&semaphore_ipc_handle_, semaphore_));
}

IpcHandle::IpcHandle(std::vector<uint8_t> data) {
  const IpcHandle& imported_buffer = fromBytes<IpcHandle>(data);

  storage_offset_ = imported_buffer.storage_offset_;
  element_size_ = imported_buffer.element_size_;
  ipc_handle_ = imported_buffer.ipc_handle_;
  semaphore_ipc_handle_ = imported_buffer.semaphore_ipc_handle_;
  rank_ = imported_buffer.rank_;

  NVFUSER_CUDA_RT_SAFE_CALL(
      cudaIpcOpenMemHandle(&ptr_, ipc_handle_, cudaIpcMemLazyEnablePeerAccess));
  ptr_ = (void*)((uint8_t*)ptr_ + storage_offset_ * element_size_);

  NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcOpenMemHandle(
      (void**)&semaphore_,
      semaphore_ipc_handle_,
      cudaIpcMemLazyEnablePeerAccess));
}

IpcHandle::~IpcHandle() {
  if (rank_ == Communicator::getInstance().deviceId()) {
    NVFUSER_CUDA_RT_SAFE_CALL(cudaFree((void*)semaphore_));
  } else {
Performance Considerations

The code uses a barrier to synchronize all ranks after pushing their memory handles to the store. This can be a performance bottleneck. It would be beneficial to investigate more efficient synchronization mechanisms or selectively synchronize only the necessary ranks.

// barrier to ensure all ranks have pushed their memhandles to the store
// TODO: precisely select what ranks need to wait on that barrier.
communicator->barrier();

@samnordmann samnordmann changed the title Ipc handle infra [CudaIpc 2/3]: Ipc handle exchange Feb 17, 2025
samnordmann (Collaborator, Author):
!test

storage_offset_(tensor.storage_offset()),
element_size_(tensor.element_size()),
rank_(Communicator::getInstance().deviceId()) {
NVFUSER_CUDA_RT_SAFE_CALL(
Review comment (samnordmann):
assert that the tensor is not strided


std::unordered_map<KeyType, std::unique_ptr<P2pIpcHandle>, KeyHash, KeyEqual>
handles_;
std::unordered_set<std::string> keys_;
Review comment (samnordmann):
remove (unnecessary)

}

private:
using KeyType = std::tuple<int64_t, at::Tensor, P2PCommunication*>;
Review comment (samnordmann):
maybe we don't need P2PCommunication* here

Reply (samnordmann):
We actually need it in the following case: rank 0 sends a buffer to rank 1's buffer1 and, concurrently, rank 0 sends the same buffer to rank 1's buffer2.
