Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

non-hipGraph MSCCL++ tests for allReduce and allGather #1503

Merged
merged 14 commits into from
Feb 4, 2025
Merged
39 changes: 39 additions & 0 deletions test/AllGatherTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* See LICENSE.txt for license information
************************************************************************/
#include "TestBed.hpp"
#include "CallCollectiveForked.hpp"

namespace RcclUnitTesting
{
Expand Down Expand Up @@ -120,4 +121,42 @@ namespace RcclUnitTesting
inPlaceList, managedMemList, useHipGraphList);
testBed.Finalize();
}

// Runs an 8-rank AllGather through callCollectiveForked with device-memory
// buffers (use_managed_mem left at its default) and checks the gathered data.
TEST(AllGather, UserBufferRegistration)
{
  const int nranks = 8;
  const size_t count = 2048;
  std::vector<int> sendBuff(count, 0);
  // AllGather delivers nranks * count elements to every rank, so the receive
  // buffer has to be as large as the expected result, not as large as sendBuff.
  std::vector<int> recvBuff(nranks * count, 0);
  std::vector<int> expected(nranks * count, 0);

  for (size_t i = 0; i < count; ++i){
    sendBuff[i] = static_cast<int>(i);
  }

  // Every rank is handed the same sendBuff, so the gathered result is
  // nranks back-to-back copies of it.
  for (int r = 0; r < nranks; ++r)
    for (size_t i = 0; i < count; ++i)
      expected[r*count + i] = sendBuff[i];

  callCollectiveForked(nranks, ncclCollAllGather, sendBuff, recvBuff, expected);
}

// Same scenario as the device-memory AllGather test, but asks
// callCollectiveForked to run with use_managed_mem = true.
TEST(AllGather, ManagedMemUserBufferRegistration)
{
  const int nranks = 8;
  const size_t count = 2048;
  std::vector<int> sendBuff(count, 0);
  // AllGather delivers nranks * count elements to every rank, so the receive
  // buffer has to be as large as the expected result, not as large as sendBuff.
  std::vector<int> recvBuff(nranks * count, 0);
  std::vector<int> expected(nranks * count, 0);
  const bool use_managed_mem = true;

  for (size_t i = 0; i < count; ++i){
    sendBuff[i] = static_cast<int>(i);
  }

  // Every rank is handed the same sendBuff, so the gathered result is
  // nranks back-to-back copies of it.
  for (int r = 0; r < nranks; ++r)
    for (size_t i = 0; i < count; ++i)
      expected[r*count + i] = sendBuff[i];

  callCollectiveForked(nranks, ncclCollAllGather, sendBuff, recvBuff, expected, use_managed_mem);
}
}
31 changes: 31 additions & 0 deletions test/AllReduceTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* See LICENSE.txt for license information
************************************************************************/
#include "TestBed.hpp"
#include "CallCollectiveForked.hpp"

namespace RcclUnitTesting
{
Expand Down Expand Up @@ -242,4 +243,34 @@ namespace RcclUnitTesting
}
testBed.Finalize();
}

// Runs an 8-rank AllReduce through callCollectiveForked with the default
// (non-managed) buffers. Every rank sends 0..count-1, so each reduced element
// is expected to be its index multiplied by the number of ranks.
TEST(AllReduce, UserBufferRegistration)
{
  const int nranks = 8;
  size_t count = 2048;

  std::vector<int> sendBuff(count, 0);
  std::vector<int> recvBuff(count, 0);
  std::vector<int> expected(count, 0);

  for (size_t idx = 0; idx < count; ++idx){
    const int value = static_cast<int>(idx);
    sendBuff[idx] = value;
    expected[idx] = value * nranks;
  }

  callCollectiveForked(nranks, ncclCollAllReduce, sendBuff, recvBuff, expected);
}

// Same scenario as the non-managed AllReduce test, but passes
// use_managed_mem = true to callCollectiveForked. Every rank sends
// 0..count-1, so each reduced element is its index multiplied by nranks.
TEST(AllReduce, ManagedMemUserBufferRegistration)
{
  const bool use_managed_mem = true;
  const int nranks = 8;
  size_t count = 2048;

  std::vector<int> sendBuff(count, 0);
  std::vector<int> recvBuff(count, 0);
  std::vector<int> expected(count, 0);

  for (size_t idx = 0; idx < count; ++idx){
    const int value = static_cast<int>(idx);
    sendBuff[idx] = value;
    expected[idx] = value * nranks;
  }

  callCollectiveForked(nranks, ncclCollAllReduce, sendBuff, recvBuff, expected, use_managed_mem);
}
}
2 changes: 2 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,14 @@ if(BUILD_TESTS)
SendRecvTests.cpp
StandaloneTests.cpp
common/main.cpp
common/CallCollectiveForked.cpp
common/CollectiveArgs.cpp
common/EnvVars.cpp
common/PrepDataFuncs.cpp
common/PtrUnion.cpp
common/TestBed.cpp
common/TestBedChild.cpp
common/StandaloneUtils.cpp
)

add_executable(rccl-UnitTests ${TEST_SOURCE_FILES})
Expand Down
164 changes: 164 additions & 0 deletions test/common/CallCollectiveForked.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
#include "CallCollectiveForked.hpp"
#include "CollectiveArgs.hpp"
#include <rccl/rccl.h>
#include <gtest/gtest.h>

#include <signal.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>

#include <cstdio>
#include <cstdlib>
#include <vector>

// Evaluates the HIP call `cmd` exactly once. If the result is not hipSuccess,
// prints the HIP error string together with the line and file of the call
// site and terminates the calling process with exit(-1).
// The do { ... } while (0) wrapper makes the macro usable as a single statement.
#define HIPCALL(cmd) \
do { \
hipError_t error = (cmd); \
if (error != hipSuccess) \
{ \
printf("Encountered HIP error (%s) at line %d in file %s\n", \
hipGetErrorString(error), __LINE__, __FILE__); \
exit(-1); \
} \
} while (0)

// Evaluates the NCCL/RCCL call `cmd` exactly once. If the result is not
// ncclSuccess, prints the file, line and NCCL error string and terminates the
// calling process with exit(-1), mirroring HIPCALL. Aborting matters here:
// continuing after e.g. a failed ncclCommInitRank would hand an uninitialized
// communicator to every later call.
// `cmd` is parenthesized so that any expression can be passed safely.
#define NCCLCHECK(cmd) do { \
  ncclResult_t res = (cmd); \
  if (res != ncclSuccess) { \
    printf("NCCL failure %s:%d '%s'\n", \
           __FILE__,__LINE__,ncclGetErrorString(res)); \
    exit(-1); \
  } \
} while(0)

namespace RcclUnitTesting
{

/**
 * Runs a single AllReduce (ncclSum over ncclInt) or AllGather on one rank,
 * using send/receive buffers that are registered with ncclCommRegister.
 *
 * @param id              NCCL unique id shared by all participating ranks.
 * @param collID          ncclCollAllReduce or ncclCollAllGather; any other
 *                        value is reported and terminates the process with a
 *                        non-zero exit code before any GPU/NCCL work starts.
 * @param rank            Rank of the caller; also used as the HIP device index.
 * @param nranks          Total number of ranks in the communicator.
 * @param send            Host data copied into the device send buffer.
 * @param recv            Host vector that seeds the device receive buffer and
 *                        receives the result. It is grown (zero-filled) when it
 *                        is smaller than the collective's output:
 *                        send.size() for AllReduce, nranks * send.size() for
 *                        AllGather.
 * @param use_managed_mem Allocate the buffers with hipMallocManaged instead of
 *                        hipMalloc.
 *
 * Any HIP failure terminates the process through HIPCALL; NCCL failures are
 * handled by NCCLCHECK.
 */
void callCollective(ncclUniqueId id, int collID, int rank, int nranks, const std::vector<int>& send, std::vector<int>& recv, bool use_managed_mem){
  // Work out the element counts first so that an unsupported collective is
  // rejected before a communicator is created.
  const size_t sendSize = send.size();
  size_t recvSize = 0;

  switch(collID){
    case ncclCollAllReduce:
      recvSize = sendSize;
      break;
    case ncclCollAllGather:
      recvSize = static_cast<size_t>(nranks) * sendSize;
      break;
    default:
      ERROR("This collective is not implemented for callCollective routine");
      exit(-1); // non-zero so the failure is visible to whoever reaps this process
  }

  // The host receive vector is read and written with recvSize elements below;
  // make sure it is large enough so neither copy runs past its end.
  if(recv.size() < recvSize){
    recv.resize(recvSize, 0);
  }

  HIPCALL(hipSetDevice(rank));
  hipStream_t stream;
  HIPCALL(hipStreamCreate(&stream));

  ncclComm_t comm;
  NCCLCHECK(ncclCommInitRank(&comm, nranks, id, rank));

  int *sendbuff = nullptr;
  int *recvbuff = nullptr;
  void *sendRegHandle = nullptr;
  void *recvRegHandle = nullptr;

  if(!use_managed_mem){
    HIPCALL(hipMalloc((void **)&sendbuff, sendSize * sizeof(int)));
    HIPCALL(hipMalloc((void **)&recvbuff, recvSize * sizeof(int)));
  }
  else{
    HIPCALL(hipMallocManaged((void **)&sendbuff, sendSize * sizeof(int)));
    HIPCALL(hipMallocManaged((void **)&recvbuff, recvSize * sizeof(int)));
  }

  // Register both buffers with the communicator before they are used.
  NCCLCHECK(ncclCommRegister(comm, sendbuff, sendSize * sizeof(int), &sendRegHandle));
  NCCLCHECK(ncclCommRegister(comm, recvbuff, recvSize * sizeof(int), &recvRegHandle));

  HIPCALL(hipMemcpy(sendbuff, send.data(), sizeof(int) * sendSize, hipMemcpyHostToDevice));
  HIPCALL(hipMemcpy(recvbuff, recv.data(), sizeof(int) * recvSize, hipMemcpyHostToDevice));

  switch(collID){
    case ncclCollAllReduce:
      NCCLCHECK(ncclAllReduce(sendbuff, recvbuff, sendSize, ncclInt, ncclSum, comm, stream));
      break;
    case ncclCollAllGather:
      NCCLCHECK(ncclAllGather(sendbuff, recvbuff, sendSize, ncclInt, comm, stream));
      break;
    default:
      break; // unreachable: collID was validated above
  }

  // Wait for the collective to finish before reading the result back.
  HIPCALL(hipStreamSynchronize(stream));
  HIPCALL(hipMemcpy(recv.data(), recvbuff, sizeof(int) * recvSize, hipMemcpyDeviceToHost));

  NCCLCHECK(ncclCommDeregister(comm, sendRegHandle));
  NCCLCHECK(ncclCommDeregister(comm, recvRegHandle));

  HIPCALL(hipFree(sendbuff));
  HIPCALL(hipFree(recvbuff));
  NCCLCHECK(ncclCommDestroy(comm));
  // Release the stream only after the communicator that used it is gone.
  HIPCALL(hipStreamDestroy(stream));
}

void callCollectiveForked(int nranks, int collID, const std::vector<int>& sendBuff, std::vector<int>& recvBuff, const std::vector<int>& expected, bool use_managed_mem){
std::vector<pid_t> children(nranks, 0);
std::vector<std::vector<int>> childPipes(nranks, std::vector<int>(2,0));
ncclUniqueId id;

for(int r = 0; r < nranks; ++r){
if(pipe(childPipes[r].data()) == -1)
ERROR("child %i pipe Failed\n", r);
}

auto createNCCLid = [&](int rank){
ncclGetUniqueId(&id);
close(childPipes[rank][0]);
write(childPipes[rank][1], &id, sizeof(ncclUniqueId));
close(childPipes[rank][1]);
};

auto getNCCLidFromParent = [&](int rank){
close(childPipes[rank][1]); //close write to child0
read(childPipes[rank][0], &id, sizeof(ncclUniqueId));
close(childPipes[rank][0]);
};

auto getAndDistributeNCCLid = [&](int nranks){
close(childPipes[0][1]); //close write to child0
read(childPipes[0][0], &id, sizeof(ncclUniqueId)); //read from child0
for(int r = 1; r < nranks; ++r){
write(childPipes[r][1], &id, sizeof(ncclUniqueId));
close(childPipes[r][1]);
}
};

for(int r = 0; r < nranks; ++r){
children[r] = fork();
if(children[r] == 0){
int ngpus = 0;
HIPCALL(hipGetDeviceCount(&ngpus));
if(ngpus != nranks){
exit(0);
}
//child processes
if(r == 0)
createNCCLid(r);
else
getNCCLidFromParent(r);

callCollective(id, collID, r, nranks, sendBuff, recvBuff, use_managed_mem);
for(int i = 0; i < recvBuff.size(); ++i){
ASSERT_EQ(recvBuff[i], expected[i]);
}
exit(0);
}
}

getAndDistributeNCCLid(nranks);

for(int r = 0; r < nranks; ++r)
wait(NULL); // Wait for all children
}

}
11 changes: 11 additions & 0 deletions test/common/CallCollectiveForked.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#ifndef CALLCOLLECTIVEFORKED_H
#define CALLCOLLECTIVEFORKED_H

#include <vector>

namespace RcclUnitTesting
{
/**
 * Forks nranks child processes (one per rank), runs the collective selected
 * by collID in each of them and compares the received data with `expected`.
 *
 * @param nranks          Number of ranks / child processes. A child does
 *                        nothing when the number of visible GPUs differs
 *                        from this value.
 * @param collID          Collective to run; ncclCollAllReduce and
 *                        ncclCollAllGather are the supported values.
 * @param sendBuff        Host data each rank sends.
 * @param recvBuff        Host data used to seed each rank's receive buffer.
 *                        The collective runs in the forked children, so the
 *                        results are produced in their copies of this vector.
 * @param expected        Values each rank is expected to receive.
 * @param use_managed_mem When true the GPU buffers are allocated with
 *                        hipMallocManaged instead of hipMalloc.
 */
void callCollectiveForked(int nranks, int collID, const std::vector<int>& sendBuff, std::vector<int>& recvBuff, const std::vector<int>& expected, bool use_managed_mem = false);
}

#endif
76 changes: 76 additions & 0 deletions test/common/StandaloneUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*************************************************************************
* Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "CollectiveArgs.hpp"
#include "StandaloneUtils.hpp"
#include <iostream>
#include <regex>


namespace RcclUnitTesting
{

/**
 * Runs `cmd` through popen() and returns everything the command wrote to its
 * standard output.
 *
 * @param cmd Command line handed to popen().
 * @return The captured output; an empty string when the command could not be
 *         started (an error is printed to std::cerr in that case).
 */
std::string executeCommand(const char* cmd) {
  std::string result;
  FILE* pipe = popen(cmd, "r");

  if (!pipe) {
    std::cerr << "Error executing command: " << cmd << std::endl;
    return result;
  }

  // Loop on fgets itself rather than on feof(): fgets returns NULL both at
  // end-of-file and on a read error, whereas feof() stays false after an
  // error and would keep the loop spinning forever.
  char buffer[128];
  while (fgets(buffer, sizeof(buffer), pipe) != nullptr) {
    result += buffer;
  }

  pclose(pipe);
  return result;
}

/**
 * Splits `str` at every occurrence of `delimiter`.
 *
 * Empty fields between two delimiters (and before a leading delimiter) are
 * kept; a delimiter at the very end of the input does not add a trailing
 * empty field, and an empty input yields an empty vector.
 *
 * @param str       Text to split.
 * @param delimiter Separator character; it is not included in the pieces.
 * @return The pieces in their original order.
 */
std::vector<std::string> splitString(const std::string& str, char delimiter) {
  std::vector<std::string> pieces;

  std::string::size_type begin = 0;
  while (begin < str.size()) {
    const std::string::size_type end = str.find(delimiter, begin);
    if (end == std::string::npos) {
      // No further delimiter: the rest of the input is the last piece.
      pieces.push_back(str.substr(begin));
      break;
    }
    pieces.push_back(str.substr(begin, end - begin));
    begin = end + 1;
  }

  return pieces;
}


/**
 * Scans metadata text line by line and collects the architecture name and
 * per-kernel information.
 *
 * Each line is tested against three patterns, first match wins:
 *   - the amdhsa target line, whose captured suffix becomes archName;
 *   - a ".name:" line, which sets the name of the kernel being assembled;
 *   - a ".private_segment_fixed_size:" line, which sets its size.
 * As soon as the kernel being assembled has a non-empty name and a non-zero
 * private segment size, it is appended to the result and reset.
 *
 * NOTE(review): a kernel whose size is 0 is never appended and its name is
 * replaced by the next ".name:" line - confirm this is intended.
 *
 * @param list Metadata lines to parse.
 * @return The collected architecture information.
 */
ArchInfo parseMetadata(const std::vector<std::string>& list) {
  const std::regex targetPattern("amdhsa.target:\\s+(?:'?)amdgcn-amd-amdhsa--(\\w+)(?:'?)");
  const std::regex namePattern("\\.name:\\s+(\\w+)");
  const std::regex segmentSizePattern("\\.private_segment_fixed_size:\\s+(\\d+)");

  ArchInfo parsed;
  KernelInfo pending;

  for (const std::string& entry : list) {
    std::smatch captures;

    if (std::regex_search(entry, captures, targetPattern)) {
      parsed.archName = captures[1];
    } else if (std::regex_search(entry, captures, namePattern)) {
      pending.name = captures[1];
    } else if (std::regex_search(entry, captures, segmentSizePattern)) {
      pending.privateSegmentFixedSize = std::stoi(captures[1]);
    }

    // Flush the kernel once both of its fields have been filled in.
    const bool complete = !pending.name.empty() && pending.privateSegmentFixedSize != 0;
    if (complete) {
      parsed.kernels.push_back(pending);
      pending = {}; // start collecting the next kernel
    }
  }

  return parsed;
}

}
Loading