-
Notifications
You must be signed in to change notification settings - Fork 0
/
par-sum.cu
276 lines (225 loc) · 8.79 KB
/
par-sum.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
#include<iostream>
#include<chrono>
#include"helper.h"
// Number of floats to sum (4096 * 4096 = 16 Mi elements). Parenthesized so the
// macro behaves as a single value inside larger expressions (e.g. `x / size`).
// NOTE(review): a lower-case macro named `size` rewrites every later token
// `size`, including members of any header included after this line.
#define size (4096*4096)
// Baseline GPU sum: every thread atomically adds its own element of A into
// *result. All threads contend for the same address, so the updates are
// effectively serialized.
// Launch: 1-D grid with exactly one thread per element of A (there is no
// bounds check). *result accumulates, so the caller must zero it first.
__global__
void kernelV1(float* A, float* result){
    const int globalIdx = threadIdx.x + blockDim.x * blockIdx.x;
    const float element = A[globalIdx];
    atomicAdd(result, element);
}
// Shared-memory tree reduction: each thread stages one element of A in shared
// memory, then each round halves the number of active threads until sA[0]
// holds the block's sum, which thread 0 atomically adds into *result.
// Launch: 1-D grid with exactly one thread per element of A (no bounds check);
// blockDim.x should be a power of two (otherwise some elements are dropped);
// dynamic shared memory of blockDim.x floats. Caller zeroes *result.
__global__
void kernelV2(float* A, float* result){
    extern __shared__ float sA[];
    const int localIdx = threadIdx.x;
    const int globalIdx = blockIdx.x * blockDim.x + localIdx;

    sA[localIdx] = A[globalIdx];
    __syncthreads();

    // Fold the upper half of the live range onto the lower half each round.
    for (int active = blockDim.x / 2; active > 0; active /= 2) {
        if (localIdx < active) {
            sA[localIdx] += sA[localIdx + active];
        }
        __syncthreads();  // reached by every thread: `active` is block-uniform
    }

    // One atomic per block instead of one per element.
    if (localIdx == 0) {
        atomicAdd(result, sA[0]);
    }
}
// kernelV2 with half as many blocks: each thread first adds two elements of A
// that lie one grid-width (gridDim.x * blockDim.x) apart while staging them in
// shared memory, then the block tree-reduces those pair sums.
// Launch: A must hold exactly 2 * gridDim.x * blockDim.x elements (no bounds
// check); blockDim.x should be a power of two; dynamic shared memory of
// blockDim.x floats. Thread 0 atomically adds the block sum into *result,
// which the caller must zero first.
__global__
void kernelV3(float* A, float* result){
    extern __shared__ float sA[];
    const int localIdx = threadIdx.x;
    const int globalIdx = blockIdx.x * blockDim.x + localIdx;
    const unsigned int gridSpan = gridDim.x * blockDim.x;

    const float first = A[globalIdx];
    const float second = A[gridSpan + globalIdx];
    sA[localIdx] = first + second;
    __syncthreads();

    // Halve the live range each round until sA[0] holds the block total.
    for (int active = blockDim.x / 2; active > 0; active /= 2) {
        if (localIdx < active) {
            sA[localIdx] += sA[localIdx + active];
        }
        __syncthreads();
    }

    if (localIdx == 0) {
        atomicAdd(result, sA[0]);
    }
}
// kernelV3 generalized to eight loads per thread: each thread sums eight
// elements of A spaced one grid-width apart, then the block reduces 8-to-1 per
// round while more than 8 partials remain, and pairwise for the rest.
// Launch: A must hold exactly 8 * gridDim.x * blockDim.x elements (no bounds
// check); blockDim.x should be a power of two; dynamic shared memory of
// blockDim.x floats. Thread 0 atomically adds the block sum into *result,
// which the caller must zero first.
__global__
void kernelV4(float* A, float* result){
    extern __shared__ float sA[];
    const int localIdx = threadIdx.x;
    const int globalIdx = blockIdx.x * blockDim.x + localIdx;
    const unsigned int gridSpan = gridDim.x * blockDim.x;

    // Accumulate strictly left to right: A[g] + A[g + span] + ... + A[g + 7*span].
    float loaded = A[globalIdx];
#pragma unroll
    for (int k = 1; k < 8; ++k) {
        loaded += A[k * gridSpan + globalIdx];
    }
    sA[localIdx] = loaded;
    __syncthreads();

    int active = blockDim.x;
    // 8-way rounds: the live range shrinks by a factor of 8 each time.
    while (active > 8) {
        active >>= 3;
        if (localIdx < active) {
            // Sum the seven partner slots first, then fold into our own slot.
            float partners = sA[localIdx + active];
#pragma unroll
            for (int k = 2; k < 8; ++k) {
                partners += sA[localIdx + k * active];
            }
            sA[localIdx] += partners;
        }
        __syncthreads();
    }

    // Pairwise rounds for the (at most 8) partials the 8-way rounds leave behind.
    while (active > 1) {
        active >>= 1;
        if (localIdx < active) {
            sA[localIdx] += sA[localIdx + active];
        }
        __syncthreads();
    }

    if (localIdx == 0) {
        atomicAdd(result, sA[0]);
    }
}
// Final intra-warp stage of a shared-memory tree reduction: folds sA[0..31]
// into sA[0] in five steps (offsets 16, 8, 4, 2, 1) without __syncthreads().
// Intended to be called by all 32 threads of one fully populated warp with
// tx = 0..31. The __syncwarp() calls use the default full mask, so every
// lane of the warp must reach them. Lanes 16..31 read up to sA[47] in the
// first step, so sA must have at least 48 readable elements. On return, only
// sA[0] holds the total; other entries hold intermediate partial sums.
//
// Since Volta (SM70), warps are no longer guaranteed to execute in lockstep,
// so `volatile` alone does not order lane A's write against lane B's read.
// Each step therefore reads into a register, waits for the whole warp, writes
// back, and waits again before the next read.
__device__
void warpReduce(volatile float* sA, int tx){
    float v = sA[tx];  // only this lane ever writes sA[tx], so v mirrors it
    v += sA[tx + 16]; __syncwarp(); sA[tx] = v; __syncwarp();
    v += sA[tx + 8];  __syncwarp(); sA[tx] = v; __syncwarp();
    v += sA[tx + 4];  __syncwarp(); sA[tx] = v; __syncwarp();
    v += sA[tx + 2];  __syncwarp(); sA[tx] = v; __syncwarp();
    v += sA[tx + 1];  __syncwarp(); sA[tx] = v;
}
// kernelV4 with an unrolled final-warp stage: once exactly 32 partial sums
// remain in sA[0..31], the first warp folds them with warpReduce instead of
// further __syncthreads() rounds.
// Launch: 1-D grid; blockDim.x must be a power of two in [64, 1024]
// (warpReduce reads up to sA[47]); dynamic shared memory of blockDim.x floats;
// A must hold exactly 8 * gridDim.x * blockDim.x elements (no bounds checks).
// Thread 0 atomically adds the block sum into *result, which the caller must
// zero first.
__global__
void kernelV5(float* A, float* result){
    extern __shared__ float sA[];
    int tId = blockIdx.x * blockDim.x + threadIdx.x;
    int tx = threadIdx.x;
    int bx = blockDim.x;
    // perform an addition with multiple elements at offsets and load into shared memory
    sA[tx] = A[tId] + A[gridDim.x * blockDim.x + tId]
           + A[2 * gridDim.x * blockDim.x + tId] + A[3 * gridDim.x * blockDim.x + tId]
           + A[4 * gridDim.x * blockDim.x + tId] + A[5 * gridDim.x * blockDim.x + tId]
           + A[6 * gridDim.x * blockDim.x + tId] + A[7 * gridDim.x * blockDim.x + tId];
    __syncthreads();
    // Reduce until exactly 32 partials remain. An 8-way step is only taken
    // when it cannot overshoot below 32 (bx >= 256); otherwise halve. If fewer
    // than 32 live partials remained, warpReduce would re-add the stale values
    // still sitting in sA[bx..31] and overcount the sum.
    while(bx > 32){
        if(bx >= 256){
            bx >>= 3; // log base 8 step
            if(tx < bx) {
                sA[tx] += (sA[tx + bx] + sA[tx + 2*bx] + sA[tx + 3*bx] + sA[tx + 4*bx] +
                           sA[tx + 5*bx] + sA[tx + 6*bx] + sA[tx + 7*bx]);
            }
        } else {
            bx >>= 1; // log base 2 step
            if(tx < bx) {
                sA[tx] += sA[tx + bx];
            }
        }
        __syncthreads(); // reached by all threads: bx is block-uniform
    }
    // Final 32 partials: only the first warp participates. Letting every
    // thread call warpReduce made later warps read past the end of sA
    // (e.g. sA[255 + 16] with a 256-float allocation) and race with warp 0
    // on sA[32..47].
    if(tx < 32){
        warpReduce(sA, tx);
    }
    if(tx == 0){
        atomicAdd(result, sA[0]);
    }
}
// Component-wise sum of eight float4 values. Each of x, y, z and w is
// accumulated left to right (v0 + v1 + ... + v7).
// make_float4 takes its arguments in member order (x, y, z, w). Listing the
// w sum first, as a brace-initializer in w, x, y, z order would, shifts every
// component sum one slot over.
__device__
float4 float4Sum(const float4& v0, const float4& v1, const float4& v2, const float4& v3,
                 const float4& v4, const float4& v5, const float4& v6, const float4& v7){
    return make_float4(v0.x+v1.x+v2.x+v3.x+v4.x+v5.x+v6.x+v7.x,
                       v0.y+v1.y+v2.y+v3.y+v4.y+v5.y+v6.y+v7.y,
                       v0.z+v1.z+v2.z+v3.z+v4.z+v5.z+v6.z+v7.z,
                       v0.w+v1.w+v2.w+v3.w+v4.w+v5.w+v6.w+v7.w);
}
// kernelV4 with the 8-way shared-memory rounds done on float4 vectors. In each
// round, every active thread whose index is a multiple of 4 adds eight float4s
// (32 floats) and writes one float4 back; the last few partials are then
// reduced pairwise.
// Launch: 1-D grid; blockDim.x must be a power of two (<= 1024); dynamic shared
// memory of blockDim.x floats; A must hold exactly 8 * gridDim.x * blockDim.x
// elements (no bounds checks). Thread 0 atomically adds the block sum into
// *result, which the caller must zero first.
// NOTE(review): the float4 views of sA assume the dynamic shared memory base
// is 16-byte aligned; that is not declared explicitly here.
__global__
void kernelV6(const float* __restrict__ A, float* __restrict__ result){
    extern __shared__ float sA[];
    int tId = blockIdx.x * blockDim.x + threadIdx.x;
    int tx = threadIdx.x;
    int bx = blockDim.x;
    // perform an addition with multiple elements at offsets and load into shared memory
    sA[tx] = A[tId] + A[gridDim.x * blockDim.x + tId]
           + A[2 * gridDim.x * blockDim.x + tId] + A[3 * gridDim.x * blockDim.x + tId]
           + A[4 * gridDim.x * blockDim.x + tId] + A[5 * gridDim.x * blockDim.x + tId]
           + A[6 * gridDim.x * blockDim.x + tId] + A[7 * gridDim.x * blockDim.x + tId];
    __syncthreads();
    float4* sA4 = reinterpret_cast<float4*>(sA);
    // 8-way rounds on float4s. After `bx >>= 3`, bx floats (bx/4 float4s)
    // remain; thread tx (tx % 4 == 0) owns float4 tx/4 and adds the seven
    // float4s found at strides of bx/4. Requiring bx >= 32 before the shift
    // keeps bx/4 >= 1.
    while(bx >= 32){
        bx >>= 3;
        if(tx % 4 == 0 && tx < bx) {
            int offset = bx/4, tx_4 = tx / 4;
            float4 v0 = sA4[tx_4];
            float4 v1 = sA4[tx_4 + offset];
            float4 v2 = sA4[tx_4 + 2*offset];
            float4 v3 = sA4[tx_4 + 3*offset];
            float4 v4 = sA4[tx_4 + 4*offset];
            float4 v5 = sA4[tx_4 + 5*offset];
            float4 v6 = sA4[tx_4 + 6*offset];
            float4 v7 = sA4[tx_4 + 7*offset];
            // float4 index tx_4 is the same storage as sA[tx .. tx+3].
            sA4[tx_4] = float4Sum(v0, v1, v2, v3, v4, v5, v6, v7);
        }
        __syncthreads();
    }
    // Pairwise rounds for the remaining (fewer than 32) partials.
    while(bx > 1){
        bx >>= 1; // log base 2
        if(tx < bx) {
            sA[tx] += sA[tx + bx];
        }
        __syncthreads();
    }
    if(tx == 0){
        atomicAdd(result, sA[0]);
    }
}
// Host driver: fills N = size random floats (intializeMatrix) and sums them
// three ways: on the CPU (computeSum), on the GPU with one reduction kernel,
// and with cuBLAS (cuBLASSUM). It then prints the three results and compares
// GPU against cuBLAS. Exactly one kernel launch is meant to be active; the
// alternatives are kept commented out for experimentation.
int main(){
    int N = size;
    // host memory allocation
    float *hA = (float *) malloc(N * sizeof(float));
    if(hA == nullptr){
        std::cerr << "Host allocation of " << N << " floats failed" << std::endl;
        return 1;
    }
    float hSumFromGPU, hSumFromCUBLAS;
    // device memory allocation
    float *A = (float *)fixed_cudaMalloc(N * sizeof(float));
    float *result = (float *)fixed_cudaMalloc(sizeof(float));
    // Every kernel accumulates into *result with atomicAdd, so it must start at
    // zero, and cudaMalloc does not clear memory. Pass the device pointer
    // `result` itself: `&result` is the address of a host variable.
    // A needs no clearing because the copy below overwrites all N elements.
    gpuErrchk(cudaMemset(result, 0, sizeof(float)));
    // host memory initialization
    srand (static_cast <unsigned> (time(0))); // seed for random initialization
    intializeMatrix(hA, N);
    // copy host to device memory
    gpuErrchk(cudaMemcpy(A, hA, N*sizeof(float), cudaMemcpyHostToDevice));
    // CPU compute of sum
    float hSumFromCPU = computeSum(hA, N); // NOTE: hA gets modified due to parallel sum but unused henceforth
    // // kernel 1
    // kernelV1<<<size/256, 256>>>(A, result);
    // // kernel 2
    // int NUM_THREADS = 256;
    // kernelV2<<<CEIL_DIV(size, NUM_THREADS), NUM_THREADS, NUM_THREADS*sizeof(float)>>>(A, result);
    // // Kernel 3
    // int NUM_THREADS = 2;
    // kernelV3<<<CEIL_DIV(size, (NUM_THREADS * 2)), NUM_THREADS, NUM_THREADS*sizeof(float)>>>(A, result);
    // // Kernel 4 - kernel3 with fewer blocks
    // int NUM_THREADS = 256;
    // kernelV4<<<CEIL_DIV(size, (NUM_THREADS * 8)), NUM_THREADS, NUM_THREADS*sizeof(float)>>>(A, result);
    // // Kernel 5
    // int NUM_THREADS = 256;
    // kernelV5<<<CEIL_DIV(size, (NUM_THREADS * 8)), NUM_THREADS, NUM_THREADS*sizeof(float)>>>(A, result);
    // Kernel 6
    int NUM_THREADS = 128;
    // The kernels have no bounds checks, so N must split evenly into blocks
    // of NUM_THREADS threads that each consume 8 elements.
    if(N % (NUM_THREADS * 8) != 0){
        std::cerr << "N (" << N << ") must be a multiple of " << NUM_THREADS * 8 << std::endl;
        cudaFree(result);
        cudaFree(A);
        free(hA);
        return 1;
    }
    kernelV6<<<CEIL_DIV(N, (NUM_THREADS*8)), NUM_THREADS, NUM_THREADS*sizeof(float)>>>(A, result);
    // A <<<...>>> launch returns nothing; fetch launch-configuration errors explicitly.
    gpuErrchk(cudaGetLastError());
    // copy device to host memory (blocking; its return code also reports
    // errors raised while the kernel was executing)
    gpuErrchk(cudaMemcpy(&hSumFromGPU, result, sizeof(float), cudaMemcpyDeviceToHost));
    // cuBLAS sum function
    // NOTE(review): if cuBLASSUM wraps cublasSasum it returns the sum of |A[i]|,
    // which equals the plain sum only when no element is negative; confirm
    // the value range produced by intializeMatrix.
    cuBLASSUM(A, N, &hSumFromCUBLAS);
    float epsilon = 1e-6f;
    cout << "Result from CPU : " << hSumFromCPU << endl;
    cout << "Result from GPU : " << hSumFromGPU << endl;
    cout << "Result from cuBLAS : " << hSumFromCUBLAS << endl;
    cout << "Absolute difference : " << fabs(hSumFromGPU - hSumFromCUBLAS) << endl;
    cout << "Approximately equal? : " << (boolalpha) << approximatelyEqual(hSumFromGPU, hSumFromCUBLAS, epsilon)
         << " (epsilon = " << epsilon << ")" << endl;
    // Free the memory (including the device accumulator)
    cudaFree(result);
    cudaFree(A);
    free(hA);
    return 0;
}