forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
THCTensorScatterGather.cu
199 lines (177 loc) · 7.2 KB
/
THCTensorScatterGather.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
#include <THC/THCTensorMath.h>
#include <THC/THCGeneral.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCApply.cuh>
// Translates a linear element id into element offsets for the tensors taking
// part in a scatter/gather.
//
// The id is decomposed into coordinates using `index.sizes`, innermost
// dimension first (i.e. the last dimension varies fastest). The other tensors
// are assumed to share index's sizes (with the exception of `t2` in dimension
// `dim`).
//  - `indexOffset` and `t1Offset` receive the offset of that coordinate in
//    every dimension.
//  - `t2Offset` receives the offset in every dimension EXCEPT `dim`; the caller
//    adds the contribution of dimension `dim` itself.
// The computed offsets are ADDED to the values already held by the
// out-pointers, so callers start them at 0.
// This primary template uses the compile-time rank `Dims`.
template <typename IndexType, typename Real, int Dims>
struct IndexToScatterGatherOffsets {
  // Three-tensor form: index, a fully addressed tensor t1, and t2 (dim skipped).
  static __device__ void compute(
      IndexType linearId, const int dim,
      const TensorInfo<int64_t, IndexType>& index, IndexType* indexOffset,
      const TensorInfo<Real, IndexType>& t1, IndexType* t1Offset,
      const TensorInfo<Real, IndexType>& t2, IndexType* t2Offset) {
    IndexType idxAcc = 0;
    IndexType t1Acc = 0;
    IndexType t2Acc = 0;
    // Walk dimensions Dims-1 .. 0, peeling one coordinate off per step.
    for (int d = Dims; d-- > 0;) {
      const IndexType extent = index.sizes[d];
      const IndexType coord = linearId % extent;
      linearId /= extent;
      idxAcc += coord * index.strides[d];
      t1Acc += coord * t1.strides[d];
      if (d != dim) {
        t2Acc += coord * t2.strides[d];
      }
    }
    *indexOffset += idxAcc;
    *t1Offset += t1Acc;
    *t2Offset += t2Acc;
  }

  // Two-tensor form: index and t2 (dim skipped) only.
  static __device__ void compute(
      IndexType linearId, const int dim,
      const TensorInfo<int64_t, IndexType>& index, IndexType* indexOffset,
      const TensorInfo<Real, IndexType>& t2, IndexType* t2Offset) {
    IndexType idxAcc = 0;
    IndexType t2Acc = 0;
    for (int d = Dims; d-- > 0;) {
      const IndexType extent = index.sizes[d];
      const IndexType coord = linearId % extent;
      linearId /= extent;
      idxAcc += coord * index.strides[d];
      if (d != dim) {
        t2Acc += coord * t2.strides[d];
      }
    }
    *indexOffset += idxAcc;
    *t2Offset += t2Acc;
  }
};
// Runtime-rank variant of IndexToScatterGatherOffsets (Dims == -1): identical
// contract, but the number of dimensions is read from `index.dims`.
// Coordinates are peeled off innermost-dimension first using `index.sizes`;
// `t2` skips dimension `dim`. Results are ADDED to the out-pointers.
template <typename IndexType, typename Real>
struct IndexToScatterGatherOffsets<IndexType, Real, -1> {
  // Three-tensor form: index, a fully addressed tensor t1, and t2 (dim skipped).
  static __device__ void compute(
      IndexType linearId, const int dim,
      const TensorInfo<int64_t, IndexType>& index, IndexType* indexOffset,
      const TensorInfo<Real, IndexType>& t1, IndexType* t1Offset,
      const TensorInfo<Real, IndexType>& t2, IndexType* t2Offset) {
    IndexType idxAcc = 0;
    IndexType t1Acc = 0;
    IndexType t2Acc = 0;
    int d = index.dims;
    while (d-- > 0) {
      const IndexType extent = index.sizes[d];
      const IndexType coord = linearId % extent;
      linearId /= extent;
      idxAcc += coord * index.strides[d];
      t1Acc += coord * t1.strides[d];
      if (d != dim) {
        t2Acc += coord * t2.strides[d];
      }
    }
    *indexOffset += idxAcc;
    *t1Offset += t1Acc;
    *t2Offset += t2Acc;
  }

  // Two-tensor form: index and t2 (dim skipped) only.
  static __device__ void compute(
      IndexType linearId, const int dim,
      const TensorInfo<int64_t, IndexType>& index, IndexType* indexOffset,
      const TensorInfo<Real, IndexType>& t2, IndexType* t2Offset) {
    IndexType idxAcc = 0;
    IndexType t2Acc = 0;
    int d = index.dims;
    while (d-- > 0) {
      const IndexType extent = index.sizes[d];
      const IndexType coord = linearId % extent;
      linearId /= extent;
      idxAcc += coord * index.strides[d];
      if (d != dim) {
        t2Acc += coord * t2.strides[d];
      }
    }
    *indexOffset += idxAcc;
    *t2Offset += t2Acc;
  }
};
// Gather: for every element position p of `index` (linearized in row-major
// order over index's sizes, `totalElements` positions in total):
//   tensor[p] = src[p with coordinate `dim` replaced by index[p]]
// `tensor` is addressed with p's coordinates in every dimension; `src` uses
// them in every dimension except `dim`.
//
// Launch layout: 1-D grid of 1-D blocks. Positions are distributed with a
// grid-stride loop, so any grid size >= 1 covers all `totalElements`.
// On HIP builds the kernel carries C10_LAUNCH_BOUNDS_1(512) (presumably a
// 512-threads-per-block launch bound — confirm against the c10 macro).
// Precondition (checked only by device-side `assert`, i.e. not under NDEBUG):
//   0 <= index[p] < src.sizes[dim].
template <typename IndexType, typename Real, int Dims>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(512)
#endif
__global__ void THCudaTensor_gatherKernel(
    TensorInfo<Real, IndexType> tensor,
    TensorInfo<Real, IndexType> src,
    TensorInfo<int64_t, IndexType> index,
    const int dim,
    const IndexType totalElements) {
  // Compute start/stride in IndexType: `blockIdx.x * blockDim.x` on its own is
  // 32-bit unsigned arithmetic and can wrap for large grids when IndexType is
  // 64-bit. For a 32-bit IndexType this is exactly the original arithmetic.
  const IndexType start =
      static_cast<IndexType>(blockIdx.x) * blockDim.x + threadIdx.x;
  const IndexType stride = static_cast<IndexType>(gridDim.x) * blockDim.x;
  for (IndexType linearId = start; linearId < totalElements;
       linearId += stride) {
    IndexType tensorOffset = 0;
    IndexType srcOffset = 0;
    IndexType indexOffset = 0;
    // tensor is addressed in all dims; srcOffset skips `dim` for now.
    IndexToScatterGatherOffsets<IndexType, Real, Dims>::compute(linearId, dim,
                                                                index, &indexOffset,
                                                                tensor, &tensorOffset,
                                                                src, &srcOffset);
    const int64_t indexValue = index.data[indexOffset];
    assert(indexValue >= 0 && indexValue < src.sizes[dim]);
    // Select the gathered slice of src along `dim`.
    srcOffset += indexValue * src.strides[dim];
    tensor.data[tensorOffset] = src.data[srcOffset];
  }
}
// Scatter: for every element position p of `index` (linearized in row-major
// order over index's sizes, `totalElements` positions in total):
//   tensor[p with coordinate `dim` replaced by index[p]] = src[p]
// `src` is addressed with p's coordinates in every dimension; `tensor` uses
// them in every dimension except `dim`.
// If several positions name the same target element, they are written with
// plain (non-atomic) stores, so which value survives is not specified here.
//
// Launch layout: 1-D grid of 1-D blocks. Positions are distributed with a
// grid-stride loop, so any grid size >= 1 covers all `totalElements`.
// On HIP builds the kernel carries C10_LAUNCH_BOUNDS_1(512) (presumably a
// 512-threads-per-block launch bound — confirm against the c10 macro).
// Precondition (checked only by device-side `assert`, i.e. not under NDEBUG):
//   0 <= index[p] < tensor.sizes[dim].
template <typename IndexType, typename Real, int Dims>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(512)
#endif
__global__ void THCudaTensor_scatterKernel(
    TensorInfo<Real, IndexType> tensor,
    TensorInfo<Real, IndexType> src,
    TensorInfo<int64_t, IndexType> index,
    const int dim,
    const IndexType totalElements) {
  // Compute start/stride in IndexType: `blockIdx.x * blockDim.x` on its own is
  // 32-bit unsigned arithmetic and can wrap for large grids when IndexType is
  // 64-bit. For a 32-bit IndexType this is exactly the original arithmetic.
  const IndexType start =
      static_cast<IndexType>(blockIdx.x) * blockDim.x + threadIdx.x;
  const IndexType stride = static_cast<IndexType>(gridDim.x) * blockDim.x;
  for (IndexType linearId = start; linearId < totalElements;
       linearId += stride) {
    IndexType tensorOffset = 0;
    IndexType srcOffset = 0;
    IndexType indexOffset = 0;
    // src is addressed in all dims; tensorOffset skips `dim` for now.
    IndexToScatterGatherOffsets<IndexType, Real, Dims>::compute(linearId, dim,
                                                                index, &indexOffset,
                                                                src, &srcOffset,
                                                                tensor, &tensorOffset);
    const int64_t indexValue = index.data[indexOffset];
    assert(indexValue >= 0 && indexValue < tensor.sizes[dim]);
    // Select the destination slice of tensor along `dim`.
    tensorOffset += indexValue * tensor.strides[dim];
    tensor.data[tensorOffset] = src.data[srcOffset];
  }
}
// Scatter-add: for every element position p of `index` (linearized in
// row-major order over index's sizes, `totalElements` positions in total):
//   tensor[p with coordinate `dim` replaced by index[p]] += src[p]
// The accumulation uses atomicAdd, so positions that name the same target
// element all contribute. Requires an atomicAdd overload for Real (presumably
// supplied by THCAtomics.cuh — confirm for each instantiated type).
//
// Launch layout: 1-D grid of 1-D blocks. Positions are distributed with a
// grid-stride loop, so any grid size >= 1 covers all `totalElements`.
// On HIP builds the kernel carries C10_LAUNCH_BOUNDS_1(512) (presumably a
// 512-threads-per-block launch bound — confirm against the c10 macro).
// Precondition (checked only by device-side `assert`, i.e. not under NDEBUG):
//   0 <= index[p] < tensor.sizes[dim].
template <typename IndexType, typename Real, int Dims>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(512)
#endif
__global__ void THCudaTensor_scatterAddKernel(
    TensorInfo<Real, IndexType> tensor,
    TensorInfo<Real, IndexType> src,
    TensorInfo<int64_t, IndexType> index,
    const int dim,
    const IndexType totalElements) {
  // Compute start/stride in IndexType: `blockIdx.x * blockDim.x` on its own is
  // 32-bit unsigned arithmetic and can wrap for large grids when IndexType is
  // 64-bit. For a 32-bit IndexType this is exactly the original arithmetic.
  const IndexType start =
      static_cast<IndexType>(blockIdx.x) * blockDim.x + threadIdx.x;
  const IndexType stride = static_cast<IndexType>(gridDim.x) * blockDim.x;
  for (IndexType linearId = start; linearId < totalElements;
       linearId += stride) {
    IndexType tensorOffset = 0;
    IndexType srcOffset = 0;
    IndexType indexOffset = 0;
    // src is addressed in all dims; tensorOffset skips `dim` for now.
    IndexToScatterGatherOffsets<IndexType, Real, Dims>::compute(linearId, dim,
                                                                index, &indexOffset,
                                                                src, &srcOffset,
                                                                tensor, &tensorOffset);
    const int64_t indexValue = index.data[indexOffset];
    assert(indexValue >= 0 && indexValue < tensor.sizes[dim]);
    // Select the destination slice of tensor along `dim`.
    tensorOffset += indexValue * tensor.strides[dim];
    // Atomic so that duplicate targets accumulate instead of racing.
    atomicAdd(&tensor.data[tensorOffset], src.data[srcOffset]);
  }
}
// Scatter-fill: for every element position p of `index` (linearized in
// row-major order over index's sizes, `totalElements` positions in total):
//   tensor[p with coordinate `dim` replaced by index[p]] = value
// Duplicate targets all store the same `value`.
//
// Launch layout: 1-D grid of 1-D blocks. Positions are distributed with a
// grid-stride loop, so any grid size >= 1 covers all `totalElements`.
// On HIP builds the kernel carries C10_LAUNCH_BOUNDS_1(512) (presumably a
// 512-threads-per-block launch bound — confirm against the c10 macro).
// Precondition (checked only by device-side `assert`, i.e. not under NDEBUG):
//   0 <= index[p] < tensor.sizes[dim].
template <typename IndexType, typename Real, int Dims>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(512)
#endif
__global__ void THCudaTensor_scatterFillKernel(
    TensorInfo<Real, IndexType> tensor,
    TensorInfo<int64_t, IndexType> index,
    Real value,
    const int dim,
    const IndexType totalElements) {
  // Compute start/stride in IndexType: `blockIdx.x * blockDim.x` on its own is
  // 32-bit unsigned arithmetic and can wrap for large grids when IndexType is
  // 64-bit. For a 32-bit IndexType this is exactly the original arithmetic.
  const IndexType start =
      static_cast<IndexType>(blockIdx.x) * blockDim.x + threadIdx.x;
  const IndexType stride = static_cast<IndexType>(gridDim.x) * blockDim.x;
  for (IndexType linearId = start; linearId < totalElements;
       linearId += stride) {
    IndexType tensorOffset = 0;
    IndexType indexOffset = 0;
    // tensorOffset skips `dim` for now; it is added from index below.
    IndexToScatterGatherOffsets<IndexType, Real, Dims>::compute(linearId, dim,
                                                                index, &indexOffset,
                                                                tensor, &tensorOffset);
    const int64_t indexValue = index.data[indexOffset];
    assert(indexValue >= 0 && indexValue < tensor.sizes[dim]);
    // Select the destination slice of tensor along `dim`.
    tensorOffset += indexValue * tensor.strides[dim];
    tensor.data[tensorOffset] = value;
  }
}
#include <THC/generic/THCTensorScatterGather.cu>
#include <THC/THCGenerateAllTypes.h>
#include <THC/generic/THCTensorScatterGather.cu>
#include <THC/THCGenerateBoolType.h>