[SymmetricMemory] improve multicast initialization/fallback logic (pytorch#136577)

Fixes pytorch#136494

Currently, CUDASymmetricMemory::rendezvous() initializes a multicast address if multicast support is present. However, if we believe multicast support is present but cuMulticastCreate still fails for some reason, we do not fall back gracefully.

- In addition to the CUDART and driver version checks, query CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED to determine multicast support for a rank/device.
- Before initializing multicast for a block, ensure that all ranks/devices have multicast support.
- If cuMulticastCreate still fails on rank 0 (unlikely, but possible), log the corresponding driver error message as a warning and gracefully skip multicast initialization for the block.
- Introduce an environment variable (TORCH_SYMM_MEM_DISABLE_MULTICAST) that lets users explicitly disable multicast as a workaround (usage sketched below).
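
A minimal usage sketch of the user-facing pieces. The import paths are assumptions based on the existing test module and are not part of this diff:

```python
import os

import torch
from torch._C._autograd import DeviceType  # assumed import path
from torch._C._distributed_c10d import _SymmetricMemory  # assumed import path

if torch.cuda.is_available():
    # Multicast support is now reported per device: the second argument is
    # the device index.
    print(_SymmetricMemory.has_multicast_support(DeviceType.CUDA, 0))

    # Workaround: explicitly disable multicast. The environment variable is
    # read at query time, before any driver calls are made.
    os.environ["TORCH_SYMM_MEM_DISABLE_MULTICAST"] = "1"
    assert not _SymmetricMemory.has_multicast_support(DeviceType.CUDA, 0)
```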

Pull Request resolved: pytorch#136577
Approved by: https://github.com/Chillee, https://github.com/eqy

(cherry picked from commit d55eef5)
yifuwang committed Sep 27, 2024
1 parent bc421d4 commit 14ff5fa
Showing 6 changed files with 141 additions and 51 deletions.
1 change: 1 addition & 0 deletions c10/cuda/driver_api.h
@@ -19,6 +19,7 @@
} while (0)

#define C10_LIBCUDA_DRIVER_API(_) \
_(cuDeviceGetAttribute) \
_(cuMemAddressReserve) \
_(cuMemRelease) \
_(cuMemMap) \
2 changes: 1 addition & 1 deletion test/distributed/test_symmetric_memory.py
@@ -50,7 +50,7 @@ def requires_cuda_p2p_access():
def requires_multicast_support():
has_multicast_support = (
torch.cuda.is_available()
and _SymmetricMemory.has_multicast_support(DeviceType.CUDA)
and _SymmetricMemory.has_multicast_support(DeviceType.CUDA, 0)
)
return skip_but_pass_in_sandcastle_if(
not has_multicast_support,
175 changes: 130 additions & 45 deletions torch/csrc/distributed/c10d/CUDASymmetricMemory.cu
@@ -20,9 +20,25 @@

namespace {

bool has_multicast_support() {
bool device_has_multicast_support(int device_idx) {
#if defined(CUDART_SUPPORTS_MULTICAST)
return c10::cuda::DriverAPI::get()->cuMulticastCreate_ != nullptr;
if (c10::utils::check_env("TORCH_SYMM_MEM_DISABLE_MULTICAST") == true) {
return false;
}
// Multicast support requirements:
// - CUDA Runtime version >= 12030: Checked at compile time using
// CUDART_VERSION.
// - Driver version >= 535: Checked at runtime by verifying the existence of
// cuMulticastCreate_.
// - Device support: Determined by querying
// CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED at runtime.
auto driver_api = c10::cuda::DriverAPI::get();
int multicast_supported;
C10_CUDA_DRIVER_CHECK(driver_api->cuDeviceGetAttribute_(
&multicast_supported,
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED,
device_idx));
return driver_api->cuMulticastCreate_ != nullptr && multicast_supported;
#else
return false;
#endif
@@ -70,7 +86,16 @@ class IpcChannel {
cmsg->cmsg_len = CMSG_LEN(sizeof(int));
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));

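// An fd of -1 means there is no fd to share; omit the SCM_RIGHTS payload so
// the receiver can detect this via msg_controllen == 0.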
if (fd != -1) {
std::copy(
reinterpret_cast<const char*>(&fd),
reinterpret_cast<const char*>(&fd) + sizeof(fd),
reinterpret_cast<char*>(CMSG_DATA(cmsg)));
} else {
msg.msg_controllen = 0;
}

TORCH_CHECK(
sendmsg(socket_, &msg, 0) > 0, "Failed to send fd: ", strerror(errno));
@@ -94,6 +119,10 @@ class IpcChannel {
"Failed to receive fd: ",
strerror(errno));

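// The sender omitted the control message (no fd to share); report -1 to the
// caller.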
if (msg.msg_controllen == 0) {
return -1;
}

auto cmsg = CMSG_FIRSTHDR(&msg);
TORCH_CHECK(cmsg != NULL);
TORCH_CHECK(cmsg->cmsg_len == CMSG_LEN(sizeof(int)));
@@ -319,7 +348,7 @@ size_t CUDASymmetricMemory::get_signal_pad_size() {
}

bool CUDASymmetricMemory::has_multicast_support() {
return ::has_multicast_support();
return mc_addr_ != nullptr;
}

void* CUDASymmetricMemory::get_multicast_ptr() {
@@ -555,10 +584,11 @@ struct RendezvousRequest {
size_t block_size;
size_t buffer_size;
size_t signal_pad_offset;
bool has_multicast_support;
};

void validate_rendezvous_requests(
const std::vector<RendezvousRequest> reqs,
const std::vector<RendezvousRequest>& reqs,
int world_size) {
TORCH_CHECK(reqs.size() == (size_t)world_size);

@@ -582,6 +612,92 @@
}
}

static bool check_group_multicast_support(
const std::vector<RendezvousRequest>& reqs) {
std::vector<size_t> ranks_with_multicast_support;
for (size_t r = 0; r < reqs.size(); ++r) {
if (reqs[r].has_multicast_support) {
ranks_with_multicast_support.push_back(r);
}
}
if (ranks_with_multicast_support.size() == reqs.size()) {
return true;
} else {
// We don't expect this to happen, but we want to let the user know if it
// does.
if (ranks_with_multicast_support.size() != 0) {
LOG(WARNING)
<< "Only a subset of ranks in the group has multicast support: "
<< ranks_with_multicast_support << " (world_size=" << reqs.size()
<< "). Skipping multicast initialization because this is unexpected.";
}
return false;
}
}

static void init_multicast_for_block(
HandleType& mc_handle,
void*& mc_addr,
const c10::intrusive_ptr<Block>& block,
IpcChannel& ipc_channel,
const std::vector<int>& pids,
const c10::intrusive_ptr<c10d::Store>& store,
int rank,
int world_size) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) && \
defined(CUDART_SUPPORTS_MULTICAST)
auto driver_api = c10::cuda::DriverAPI::get();
if (rank == 0) {
CUmulticastObjectProp mc_prop{};
mc_prop.numDevices = world_size;
mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
mc_prop.size = block->block_size;

auto err = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop);
if (err != CUDA_SUCCESS) {
const char* err_str;
CUresult get_error_str_err = driver_api->cuGetErrorString_(err, &err_str);
if (get_error_str_err != CUDA_SUCCESS) {
err_str = "unknown cuda driver error";
}
LOG(WARNING)
<< "SymmetricMemory: cuMulticastCreate failed with: \"" << err_str
<< "\". Gracefully skipping multicast initialization. "
<< "However, this is unexpected. Please report the issue on GitHub.";
// Allow peers to gracefully skip multicast initialization by sending -1
ipc_channel.broadcast_fds(rank, 0, pids, -1);
return;
}

int mc_fd;
C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
&mc_fd, mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
ipc_channel.broadcast_fds(rank, 0, pids, mc_fd);
// Ref count is incremented as soon as SCM_RIGHTS send happens
close(mc_fd);
} else {
int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1);
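// A received fd of -1 means rank 0 failed to create the multicast object;
// skip multicast initialization on this rank as well.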
if (mc_fd == -1) {
return;
}
C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
&mc_handle,
(void*)(uintptr_t)mc_fd,
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
close(mc_fd);
}

// All ranks add their physical allocation to the multicast object
C10_CUDA_DRIVER_CHECK(
driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
mc_handle, 0, block->handle, 0, block->block_size, 0));

map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
store_barrier(store, rank, world_size);
#endif
}

c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
void* ptr) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
@@ -610,7 +726,8 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
.pid = getpid(),
.block_size = block->block_size,
.buffer_size = block->buffer_size,
.signal_pad_offset = block->signal_pad_offset};
.signal_pad_offset = block->signal_pad_offset,
.has_multicast_support = device_has_multicast_support(block->device_idx)};
auto reqs = store_all_gather(store, rank, world_size, local_req);
validate_rendezvous_requests(reqs, world_size);

@@ -642,45 +759,13 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
store_barrier(store, rank, world_size);
close(block_fd);

CUmemGenericAllocationHandle mc_handle{};
HandleType mc_handle{};
void* mc_addr = nullptr;
#if defined(CUDART_SUPPORTS_MULTICAST)
// We have to further check if the driver supports multicast
if (has_multicast_support()) {
// Rank 0 creates a multicast object and share it with peers
if (rank == 0) {
CUmulticastObjectProp mc_prop{};
mc_prop.numDevices = world_size;
mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
mc_prop.size = block->block_size;

CUresult res = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop);
TORCH_CHECK(res == CUDA_SUCCESS);

int mc_fd;
C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
&mc_fd, mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
ipc_channel.broadcast_fds(rank, 0, pids, mc_fd);
// Ref count is incremented as soon as SCM_RIGHTS send happens
close(mc_fd);
} else {
int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1);
C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
&mc_handle,
(void*)(uintptr_t)mc_fd,
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
close(mc_fd);
}
// All rank adds their physical allocation to the multicast object
C10_CUDA_DRIVER_CHECK(
driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
mc_handle, 0, block->handle, 0, block->block_size, 0));

map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
store_barrier(store, rank, world_size);
bool group_has_multicast_support = check_group_multicast_support(reqs);
if (group_has_multicast_support) {
init_multicast_for_block(
mc_handle, mc_addr, block, ipc_channel, pids, store, rank, world_size);
}
#endif

// Initializing CUDASymmetricMemory with an allocation transfers its
// ownership to the CUDASymmetricMemory object. So that outstanding
@@ -713,8 +798,8 @@ bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) {
return block->symm_mem != nullptr;
}

bool CUDASymmetricMemoryAllocator::has_multicast_support() {
return ::has_multicast_support();
bool CUDASymmetricMemoryAllocator::has_multicast_support(int device_idx) {
return device_has_multicast_support(device_idx);
}

c10::intrusive_ptr<Block> CUDASymmetricMemoryAllocator::find_block(void* ptr) {
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp
@@ -102,7 +102,7 @@ class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator {
size_t get_alloc_size(void* ptr) override;
c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) override;
bool is_rendezvous_completed(void* ptr) override;
bool has_multicast_support() override;
bool has_multicast_support(int device_idx) override;

private:
c10::intrusive_ptr<Block> find_block(void* ptr);
6 changes: 4 additions & 2 deletions torch/csrc/distributed/c10d/SymmetricMemory.cpp
@@ -189,9 +189,11 @@ c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
return allocator->rendezvous(tensor.data_ptr());
}

TORCH_API bool has_multicast_support(c10::DeviceType device_type) {
TORCH_API bool has_multicast_support(
c10::DeviceType device_type,
int device_idx) {
auto allocator = get_allocator(device_type);
return allocator->has_multicast_support();
return allocator->has_multicast_support(device_idx);
}
} // namespace symmetric_memory
} // namespace c10d
6 changes: 4 additions & 2 deletions torch/csrc/distributed/c10d/SymmetricMemory.hpp
@@ -81,7 +81,7 @@ class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {
virtual size_t get_alloc_size(void* ptr) = 0;
virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) = 0;
virtual bool is_rendezvous_completed(void* ptr) = 0;
virtual bool has_multicast_support() = 0;
virtual bool has_multicast_support(int device_idx) = 0;
};

C10_EXPORT bool is_finalizing();
@@ -154,6 +154,8 @@ TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
const at::Tensor& tensor);

TORCH_API bool has_multicast_support(c10::DeviceType device_type);
TORCH_API bool has_multicast_support(
c10::DeviceType device_type,
int device_idx);
} // namespace symmetric_memory
} // namespace c10d
