diff --git a/ext-src/read-allred.patch b/ext-src/read-allred.patch
index a51f42b77..fb6583ff2 100644
--- a/ext-src/read-allred.patch
+++ b/ext-src/read-allred.patch
@@ -1,16 +1,16 @@
 diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
-index 1b85136..ee90c2f 100644
+index 1b85136..36d85e4 100644
 --- a/apps/nccl/src/allreduce.hpp
 +++ b/apps/nccl/src/allreduce.hpp
-@@ -386,24 +386,353 @@ __global__ void __launch_bounds__(512, 1)
+@@ -386,12 +386,323 @@ __global__ void __launch_bounds__(512, 1)
    }
  }
 
 +template <typename T>
 +__global__ void __launch_bounds__(512, 1)
 +    allreduce8Read(T* buff, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
-+                   mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset,
-+                   int rank, int nRanksPerNode, int worldSize, size_t nelems) {
++                   mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset, int rank,
++                   int nRanksPerNode, int worldSize, size_t nelems) {
 +  const int nPeer = nRanksPerNode - 1;
 +  const size_t chanOffset = nPeer * blockIdx.x;
 +  // assume (nelems * sizeof(T)) is divisible by (16 * worldSize)
@@ -22,7 +22,7 @@ index 1b85136..ee90c2f 100644
 +  int4* buff4 = reinterpret_cast<int4*>(buff);
 +  int4* resultBuff4 = reinterpret_cast<int4*>(resultBuff);
 +
-+  // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4` 
++  // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4`
 +  constexpr size_t unitNInt4 = 512;
 +  const size_t maxNInt4PerBlock =
 +      (((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4;
@@ -48,18 +48,18 @@ index 1b85136..ee90c2f 100644
 +  }
 +  __syncwarp();
 +
-+  // we can use double buffering to hide synchronization overhead
-+  for (size_t itr = 0; itr < nItrs; itr++) {
-+    if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
-+      channels[threadIdx.x].signal();
-+      channels[threadIdx.x].wait();
-+    }
-+    __syncthreads();
++  // Wait for other GPUs before reading input from channels
++  if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
++    channels[threadIdx.x].signal();
++    channels[threadIdx.x].wait();
++  }
++  __syncthreads();
 +
++  for (size_t itr = 0; itr < nItrs; itr++) {
 +    for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
 +      int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
 +      for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
-+        int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);;
++        int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);
 +        data = add_vectors<T>(val, data);
 +      }
 +      resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
@@ -69,27 +69,14 @@ index 1b85136..ee90c2f 100644
 +                                   data);
 +      }
 +    }
-+    if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
-+      outChannels[threadIdx.x].signal();
-+      outChannels[threadIdx.x].wait();
-+    }
-+    __syncthreads();
-+
 +    offsetOfThisBlock += nInt4PerChunk;
 +  }
 +
 +  if (restNInt4 > 0) {
-+    if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
-+      channels[threadIdx.x].signal();
-+      channels[threadIdx.x].wait();
-+
-+    }
-+    __syncthreads();
-+
 +    for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
 +      int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
 +      for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
-+        int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);;
++        int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);
 +        data = add_vectors<T>(val, data);
 +      }
 +      resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
@@ -98,14 +85,15 @@ index 1b85136..ee90c2f 100644
 +                                   data);
 +      }
 +    }
-+
-+    if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
-+      outChannels[threadIdx.x].signal();
-+      outChannels[threadIdx.x].wait();
-+    }
-+    __syncthreads();
 +  }
 +
++  // Synchronize threads before signaling that all results have been written to outChannels
++  __syncthreads();
++  if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
++    outChannels[threadIdx.x].signal();
++    outChannels[threadIdx.x].wait();
++  }
++  __syncthreads();
 +}
 +
 +template <typename T>