forked from alibaba/MNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbatch_to_space.cl
32 lines (29 loc) · 1.99 KB
/
batch_to_space.cl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// Turn on the cl_khr_fp16 extension for the code in this file.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Sampler used for image reads: unnormalized texel coordinates, nearest-texel
// filtering, and clamp addressing for coordinates that fall outside the image.
__constant sampler_t SAMPLER =
    CLK_FILTER_NEAREST | CLK_ADDRESS_CLAMP | CLK_NORMALIZED_COORDS_FALSE;
// Revert space_to_batch.
//
// Each work-item reads one texel from uOutput and stores it at the matching
// position in uInput. The argument names mirror the forward (space_to_batch)
// kernel, so in this kernel uInput is the image being WRITTEN and uOutput is
// the image being READ.
//
// Work-item mapping: global id 0 -> w, global id 1 -> h, global id 2 -> c4 * b
// of the image being read.
//
// Parameters:
//   uInput      - destination image (write-only).
//   uOutput     - source image (read-only).
//   inImageSize - extents describing uInput; used here as x = width, y = height,
//                 z = channel-block count, w = batch count.
//   outImgSize  - extents describing uOutput; used here as x = width, y = height,
//                 z = channel-block count.
//   padding     - padding.x is applied along the height, padding.y along the width.
//   blockShape  - blockShape.x is the block size along the height, blockShape.y
//                 along the width.
//
// Both images are addressed as (w + c4 * width, h + batch * height), which is
// the addressing the read from uOutput uses.
// NOTE(review): the global work size in dimension 2 is not bounds-checked here;
// presumably the host launches exactly outImgSize.z * batch items - confirm.
__kernel void batch_to_space(__write_only image2d_t uInput, __read_only image2d_t uOutput,
                             __private const int4 inImageSize, __private const int4 outImgSize,
                             __private const int2 padding, __private const int2 blockShape) {
    int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2));

    // Work-items outside the source plane have nothing to do.
    if (pos.x >= outImgSize.x || pos.y >= outImgSize.y) {
        return;
    }

    // pos.x -> w, pos.y -> h, pos.z -> c4 * b;
    int outBatchIndex   = pos.z / outImgSize.z;
    int outChannelIndex = pos.z % outImgSize.z;

    // The source batch index packs (block offset, destination batch).
    int inBatchIndex = outBatchIndex % inImageSize.w;
    int blockIndex   = outBatchIndex / inImageSize.w;
    int sw           = blockIndex % blockShape.y;
    int sh           = blockIndex / blockShape.y;

    // Range of source rows/columns that map inside the unpadded destination.
    int validHeightStart = max(0, ((padding.x - sh + blockShape.x - 1) / blockShape.x));
    int validHeightEnd =
        min(outImgSize.y, ((inImageSize.y + padding.x - sh + blockShape.x - 1) / blockShape.x));
    int validWidthStart = max(0, ((padding.y - sw + blockShape.y - 1) / blockShape.y));
    int validWidthEnd =
        min(outImgSize.x, ((inImageSize.x + padding.y - sw + blockShape.y - 1) / blockShape.y));

    // Texels that correspond to padding have no destination. Writing an image at
    // out-of-range coordinates (the previous -1 sentinel) is undefined in OpenCL C,
    // so skip them instead of issuing the write.
    if (pos.x < validWidthStart || pos.x >= validWidthEnd ||
        pos.y < validHeightStart || pos.y >= validHeightEnd) {
        return;
    }

    int inPosX = pos.x * blockShape.y + sw - padding.y;
    int inPosY = pos.y * blockShape.x + sh - padding.x;

    // The batch offset is carried by the y coordinate only; x is offset by the
    // channel block alone, matching the addressing of the read below. Folding the
    // batch into x as well would move the write to the wrong column for any
    // batch index greater than zero.
    int inputX = inPosX + outChannelIndex * inImageSize.x;
    int inputY = inPosY + inBatchIndex * inImageSize.y;

    FLOAT4 res = RI_F(uOutput, SAMPLER,
                      (int2)(pos.x + outChannelIndex * outImgSize.x, pos.y + outBatchIndex * outImgSize.y));
    WI_F(uInput, (int2)(inputX, inputY), res);
}