Merge pull request #58 from InfiniTensor/dev
perf(kernel): use __ldg to improve performance
YdrMaster authored Dec 14, 2023
2 parents de3b474 + 59939c6 commit d5e1b02
Showing 20 changed files with 273 additions and 283 deletions.
10 changes: 0 additions & 10 deletions src/04kernel/cuda/include/kernel/cuda/bench.cuh

This file was deleted.

34 changes: 0 additions & 34 deletions src/04kernel/cuda/src/bench.cu

This file was deleted.

2 changes: 1 addition & 1 deletion src/04kernel/cuda/src/concat.cu
@@ -37,7 +37,7 @@ namespace refactor::kernel::cuda {
concatKernel<<<
params.gridSize,
params.blockSize,
- inputCount *(sizeof(unsigned int) + sizeof(void *)),
+ inputCount * (sizeof(unsigned int) + sizeof(void *)),
reinterpret_cast<cudaStream_t>(params.stream)>>>(
params.n,
reinterpret_cast<uint8_t const **>(inputs),
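
A minimal sketch of why this launch reserves inputCount * (sizeof(unsigned int) + sizeof(void *)) bytes of dynamic shared memory: one pointer and one segment length are staged per input. The kernel body below is only an illustration under that assumption, not the repository's concatKernel (which this hunk does not show); placing the 8-byte pointers first keeps them naturally aligned.

__global__ void carveSharedSketch(uint8_t const **inputs,
                                  unsigned int const *segments,
                                  unsigned int inputCount) {
    extern __shared__ uint8_t buf[];
    // Carve the dynamic buffer into a pointer table followed by a length table.
    auto ptrs = reinterpret_cast<uint8_t const **>(buf);
    auto lens = reinterpret_cast<unsigned int *>(ptrs + inputCount);
    for (auto i = threadIdx.x; i < inputCount; i += blockDim.x) {
        ptrs[i] = inputs[i];
        lens[i] = segments[i];
    }
    __syncthreads();
    // ... every thread can now read ptrs[]/lens[] from shared memory ...
}
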
18 changes: 8 additions & 10 deletions src/04kernel/cuda/src/expand.cu
@@ -6,23 +6,21 @@ namespace refactor::kernel::cuda {

__global__ static void expandKernel(
unsigned long long n,
- uint8_t const *data, expand::DimStride const *strides, uint8_t *output,
+ uint8_t const *__restrict__ data,
+ expand::DimStride const *__restrict__ strides,
+ uint8_t *__restrict__ output,
unsigned int rank,
unsigned int eleSize) {
- extern __shared__ expand::DimStride shared[];
- for (auto i = threadIdx.x; i < rank; i += blockDim.x) {
- shared[i] = strides[i];
- }
- __syncthreads();
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
step = blockDim.x * gridDim.x;
tid < n;
tid += step) {
long rem = tid, i = 0;
for (auto j = 0; j < rank; ++j) {
- auto s = shared[j];
- i += rem / s.o * s.i;
- rem %= s.o;
+ auto o_ = __ldg(&(strides[j].o));
+ auto i_ = __ldg(&(strides[j].i));
+ i += rem / o_ * i_;
+ rem %= o_;
}
optimizedMemcpy(output + tid * eleSize, data + i * eleSize, eleSize);
}
@@ -37,7 +35,7 @@ namespace refactor::kernel::cuda {
expandKernel<<<
params.gridSize,
params.blockSize,
- rank * sizeof(expand::DimStride),
+ 0,
reinterpret_cast<cudaStream_t>(params.stream)>>>(
params.n,
reinterpret_cast<uint8_t const *>(data),
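
The change above drops the shared-memory staging of the stride table (and the __syncthreads() barrier that went with it) in favour of per-thread loads through the read-only data cache, so the launch no longer requests dynamic shared memory. A self-contained sketch of the same pattern, with a hypothetical Stride struct standing in for expand::DimStride; __ldg requires compute capability 3.5 or newer.

struct Stride { unsigned int o, i; };

__global__ void decodeWithLdg(unsigned long long n,
                              Stride const *__restrict__ strides,
                              unsigned long long *__restrict__ out,
                              unsigned int rank) {
    for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
              step = blockDim.x * gridDim.x;
         tid < n;
         tid += step) {
        unsigned long long rem = tid, idx = 0;
        for (unsigned int j = 0; j < rank; ++j) {
            auto o = __ldg(&strides[j].o);  // scalar load via the read-only cache
            auto i = __ldg(&strides[j].i);
            idx += rem / o * i;
            rem %= o;
        }
        out[tid] = idx;  // no shared memory, no barrier, no dynamic smem at launch
    }
}
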
20 changes: 8 additions & 12 deletions src/04kernel/cuda/src/gather.cu
@@ -5,34 +5,30 @@
namespace refactor::kernel::cuda {

template<class index_t>
- __global__ void gatherKernel(
+ __global__ static void gatherKernel(
unsigned long long n,
- uint8_t const *data,
- index_t const *indices,
- uint8_t *output,
+ uint8_t const *__restrict__ data,
+ index_t const *__restrict__ indices,
+ uint8_t *__restrict__ output,
unsigned int batch,
unsigned int unit,
unsigned int midSizeI,
unsigned int midSizeO) {
- extern __shared__ uint32_t shared[];
- for (auto i = threadIdx.x; i < midSizeO; i += blockDim.x) {
- shared[i] = indices[i];
- }
- __syncthreads();
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
step = blockDim.x * gridDim.x;
tid < n;
tid += step) {
auto i = tid / batch,
j = tid % batch;
+ auto index = __ldg(indices + i % midSizeO);
optimizedMemcpy(unit * tid + output,
- unit * (batch * (i / midSizeO * midSizeI + shared[i % midSizeO]) + j) + data,
+ unit * (batch * (i / midSizeO * midSizeI + index) + j) + data,
unit);
}
}

template<class index_t>
- void launchGather(
+ void static launchGather(
KernelLaunchParameters const &params,
void const *data, void const *indices, void *output,
unsigned int batch,
@@ -42,7 +38,7 @@ namespace refactor::kernel::cuda {
gatherKernel<<<
params.gridSize,
params.blockSize,
- midSizeO * sizeof(uint32_t),
+ 0,
reinterpret_cast<cudaStream_t>(params.stream)>>>(
params.n,
reinterpret_cast<uint8_t const *>(data),
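
gatherKernel is templated on index_t, and __ldg is overloaded for the common integer widths, so the same read-only-cache load works whether the index tensor is 32- or 64-bit. A small sketch under that assumption (hypothetical kernel, not from the repository):

template<class index_t>  // e.g. int32_t or int64_t
__global__ void readIndicesSketch(index_t const *__restrict__ indices,
                                  long long *__restrict__ out,
                                  unsigned long long n) {
    auto tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < n) {
        // __ldg resolves to the overload matching the instantiated index type.
        out[tid] = static_cast<long long>(__ldg(indices + tid));
    }
}
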
10 changes: 5 additions & 5 deletions src/04kernel/cuda/src/scatter_nd.cu
@@ -6,10 +6,10 @@ namespace refactor::kernel::cuda {

__global__ void scatterNDKernel(
size_t n,
- uint8_t *out,
- uint8_t const *in,
- int64_t const *indices,
- unsigned int const *strides,
+ uint8_t *__restrict__ out,
+ uint8_t const *__restrict__ in,
+ int64_t const *__restrict__ indices,
+ unsigned int const *__restrict__ strides,
size_t rank,
size_t blockSize) {
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
@@ -19,7 +19,7 @@ namespace refactor::kernel::cuda {
unsigned int j = 0;
auto i = indices + tid * rank;
for (auto k = 0; k < rank; ++k) {
- j += i[k] * strides[k];
+ j += i[k] * __ldg(strides + k);
}
optimizedMemcpy(out + j * blockSize,
in + tid * blockSize,
19 changes: 9 additions & 10 deletions src/04kernel/cuda/src/slice.cu
@@ -6,23 +6,22 @@ namespace refactor::kernel::cuda {

__global__ static void sliceKernel(
unsigned long long n,
- uint8_t const *src, DimInfo const *dims, uint8_t *dst,
+ uint8_t const *__restrict__ src,
+ DimInfo const *__restrict__ dims,
+ uint8_t *__restrict__ dst,
unsigned int rank,
unsigned int blockSize) {
- extern __shared__ DimInfo dimInfo[];
- for (auto i = threadIdx.x; i < rank; i += blockDim.x) {
- dimInfo[i] = dims[i];
- }
- __syncthreads();
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
step = blockDim.x * gridDim.x;
tid < n;
tid += step) {
long rem = tid, j = 0;
for (auto i = 0; i < rank; ++i) {
- auto const &dim = dimInfo[i];
- j += rem / dim.strideO * dim.strideI + dim.skip;
- rem %= dim.strideO;
+ auto strideO = __ldg(&(dims[i].strideO));
+ auto strideI = __ldg(&(dims[i].strideI));
+ auto skip = __ldg(&(dims[i].skip));
+ j += rem / strideO * strideI + skip;
+ rem %= strideO;
}
optimizedMemcpy(dst + tid * blockSize, src + j * blockSize, blockSize);
}
@@ -36,7 +35,7 @@ namespace refactor::kernel::cuda {
sliceKernel<<<
params.gridSize,
params.blockSize,
- rank * sizeof(DimInfo),
+ 0,
reinterpret_cast<cudaStream_t>(params.stream)>>>(
params.n,
reinterpret_cast<uint8_t const *>(src),
2 changes: 1 addition & 1 deletion src/04kernel/cuda/src/split.cu
@@ -37,7 +37,7 @@ namespace refactor::kernel::cuda {
splitKernel<<<
params.gridSize,
params.blockSize,
- outputCount *(sizeof(unsigned int) + sizeof(void *)),
+ outputCount * (sizeof(unsigned int) + sizeof(void *)),
reinterpret_cast<cudaStream_t>(params.stream)>>>(
params.n,
reinterpret_cast<uint8_t const *>(data),
18 changes: 8 additions & 10 deletions src/04kernel/cuda/src/transpose.cu
@@ -6,23 +6,21 @@ namespace refactor::kernel::cuda {

__global__ static void transposeKernel(
unsigned long long n,
- uint8_t const *data, transpose::DimStride const *strides, uint8_t *output,
+ uint8_t const *__restrict__ data,
+ transpose::DimStride const *__restrict__ strides,
+ uint8_t *__restrict__ output,
unsigned int rank,
unsigned int eleSize) {
- extern __shared__ transpose::DimStride shared[];
- for (auto i = threadIdx.x; i < rank; i += blockDim.x) {
- shared[i] = strides[i];
- }
- __syncthreads();
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
step = blockDim.x * gridDim.x;
tid < n;
tid += step) {
auto j = 0u, rem = tid;
for (auto k = 0u; k < rank; ++k) {
- auto d = shared[k];
- j += rem / d.o * d.i;
- rem %= d.o;
+ auto o_ = __ldg(&(strides[k].o));
+ auto i_ = __ldg(&(strides[k].i));
+ j += rem / o_ * i_;
+ rem %= o_;
}

optimizedMemcpy(output + tid * eleSize, data + j * eleSize, eleSize);
@@ -37,7 +35,7 @@ namespace refactor::kernel::cuda {
transposeKernel<<<
params.gridSize,
params.blockSize,
- rank * sizeof(transpose::DimStride),
+ 0,
reinterpret_cast<cudaStream_t>(params.stream)>>>(
params.n,
reinterpret_cast<uint8_t const *>(data),
29 changes: 12 additions & 17 deletions src/04kernel/cuda/src/where.cu
@@ -6,30 +6,25 @@ namespace refactor::kernel::cuda {

__global__ static void whereKernel(
unsigned long long n,
- unsigned int const *strides,
- bool const *c,
- uint8_t const *x,
- uint8_t const *y,
- uint8_t *output,
+ unsigned int const *__restrict__ strides,
+ bool const *__restrict__ c,
+ uint8_t const *__restrict__ x,
+ uint8_t const *__restrict__ y,
+ uint8_t *__restrict__ output,
unsigned int rank,
unsigned int eleSize) {
- extern __shared__ unsigned int shared[];
- for (auto i = threadIdx.x; i < rank * 4; i += blockDim.x) {
- shared[i] = strides[i];
- }
- __syncthreads();
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
step = blockDim.x * gridDim.x;
tid < n;
tid += step) {
auto ic = 0u, ix = 0u, iy = 0u, rem = tid;
for (auto j = 0u; j < rank; ++j) {
- auto dim = shared + 4 * j;
- auto quot = rem / dim[3];
- rem %= dim[3];
- ic += quot * dim[0];
- ix += quot * dim[1];
- iy += quot * dim[2];
+ auto dim = strides + 4 * j;
+ auto quot = rem / __ldg(dim + 3);
+ rem %= __ldg(dim + 3);
+ ic += quot * __ldg(dim + 0);
+ ix += quot * __ldg(dim + 1);
+ iy += quot * __ldg(dim + 2);
}

optimizedMemcpy(output + tid * eleSize,
@@ -52,7 +47,7 @@ namespace refactor::kernel::cuda {
whereKernel<<<
params.gridSize,
params.blockSize,
- rank * sizeof(unsigned int) * 4,
+ 0,
reinterpret_cast<cudaStream_t>(params.stream)>>>(
params.n,
strides,
37 changes: 0 additions & 37 deletions src/04kernel/cuda/test/bench.cu

This file was deleted.

26 changes: 24 additions & 2 deletions src/04kernel/src/generator/nvrtc_repo.cc
@@ -10,6 +10,14 @@
nvrtcGetErrorString(status))); \
}

+ #define CUDA_ASSERT(CALL) \
+ if (auto result = CALL; result != CUDA_SUCCESS) { \
+ const char *msg; \
+ cuGetErrorName(result, &msg); \
+ RUNTIME_ERROR(fmt::format("cuda driver failed on \"" #CALL "\" with {} ({})", \
+ msg, (int) result)); \
+ }

namespace refactor::kernel::nvrtc {

Handler::Handler(std::string_view name,
@@ -85,8 +93,22 @@ namespace refactor::kernel::nvrtc {
return it->second;
}

- CUfunction Handler::kernel() const {
- return _kernel;
+ void Handler::launch(unsigned int gridDimX,
+ unsigned int gridDimY,
+ unsigned int gridDimZ,
+ unsigned int blockDimX,
+ unsigned int blockDimY,
+ unsigned int blockDimZ,
+ unsigned int sharedMemBytes,
+ void **kernelParams) const {
+ CUDA_ASSERT(cuLaunchKernel(
+ _kernel,
+ gridDimX, gridDimY, gridDimZ,
+ blockDimX, blockDimY, blockDimZ,
+ sharedMemBytes,
+ nullptr,
+ kernelParams,
+ nullptr));
}

std::string_view memCopyType(size_t size) {
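
With Handler::kernel() removed, call sites presumably go through the new Handler::launch, which forwards to cuLaunchKernel and turns driver errors into RUNTIME_ERROR via the new CUDA_ASSERT macro. A caller-side sketch follows; the wrapper function name, kernel parameter list, and launch configuration are assumptions for illustration, not code from this commit.

void launchSketch(refactor::kernel::nvrtc::Handler const &handler,
                  void const *data, void *output, unsigned long long n) {
    // cuLaunchKernel expects an array of pointers to the kernel arguments.
    void *args[]{&data, &output, &n};
    handler.launch(/*gridDimX=*/256, /*gridDimY=*/1, /*gridDimZ=*/1,
                   /*blockDimX=*/1024, /*blockDimY=*/1, /*blockDimZ=*/1,
                   /*sharedMemBytes=*/0,  // the kernels above no longer need dynamic shared memory
                   args);
    // A failed cuLaunchKernel surfaces as RUNTIME_ERROR inside launch(),
    // so there is no CUresult for the caller to check here.
}
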
