forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SortingKthValue.cu
248 lines (217 loc) · 7.45 KB
/
SortingKthValue.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
#include <ATen/ATen.h>
#include <ATen/native/SortingUtils.h>
#include <assert.h>
#include <c10/macros/Macros.h>
#include <stdlib.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <THC/THCDeviceUtils.cuh> // only for THCRoundUp?
#include <THC/THCNumerics.cuh>
#include <THC/THCScanUtils.cuh>
#include <THC/THCTensorMathReduce.cuh> // AddOp
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/extrema.h>
#include <thrust/inner_product.h>
#include <thrust/sequence.h>
#include <THC/THCThrustAllocator.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/SortingRadixSelect.cuh>
namespace at {
namespace native {
namespace {
// For every slice of `input`, asks radixSelect for the element of rank `k`
// and stores that value together with its position inside the slice.
//
// Launch layout: each thread block handles the slice whose id is returned by
// getLinearBlockId; blocks whose id is past `numInputSlices` exit right away.
// Static shared memory: WARP_SIZE ints, passed to radixSelect as scratch.
//
// Parameters:
//   input                  - tensor the values are selected from
//   inputSliceSize         - number of elements in one slice
//   k                      - rank forwarded to radixSelect (the host code in
//                            this file validates 1 <= k <= slice size)
//   numInputSlices         - number of slices to process
//   inputWithinSliceStride - element stride between neighbouring slice entries
//   kthValue               - output, receives the selected value per slice
//   indices                - output, receives the in-slice position per slice
//
// NOTE(review): the `false` template argument given to radixSelect presumably
// picks ascending order (k-th smallest) -- confirm against
// SortingRadixSelect.cuh.
template <typename scalar_t, typename index_t, int Dim>
__global__ void gatherKthValue(
    cuda::detail::TensorInfo<scalar_t, index_t> input,
    index_t inputSliceSize,
    index_t k,
    index_t numInputSlices,
    index_t inputWithinSliceStride,
    cuda::detail::TensorInfo<scalar_t, index_t> kthValue,
    cuda::detail::TensorInfo<int64_t, index_t> indices) {
  // Indices are limited to integer fp precision, so counts can fit in
  // int32, regardless of index_t
  __shared__ int smem[WARP_SIZE]; // one per each warp, up to warp limit

  const index_t slice = getLinearBlockId<index_t>();
  if (slice >= numInputSlices) {
    return;
  }

  // Translate the linear slice id into base pointers for the three tensors.
  scalar_t* sliceData = input.data +
      cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(slice, input);
  scalar_t* valueOut = kthValue.data +
      cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(slice, kthValue);
  int64_t* indexOut = indices.data +
      cuda::detail::IndexToOffset<int64_t, index_t, Dim>::get(slice, indices);

  // Let radixSelect determine the value of rank k within this slice.
  scalar_t selected = static_cast<scalar_t>(0);
  radixSelect<
      scalar_t,
      typename TopKTypeConfig<scalar_t>::RadixType,
      index_t,
      false>(
      sliceData,
      k,
      inputSliceSize,
      inputWithinSliceStride,
      smem,
      &selected);

  // Each thread walks its share of the slice (start threadIdx.x, step
  // blockDim.x) and stops at the first entry equal to the selected value.
  // Any thread that finds a match writes the outputs, so when the value
  // occurs more than once several threads may store different positions.
  for (index_t pos = threadIdx.x; pos < inputSliceSize; pos += blockDim.x) {
    scalar_t candidate = doLdg(&sliceData[pos * inputWithinSliceStride]);
    if (THCNumerics<scalar_t>::eq_with_nan(candidate, selected)) {
      *valueOut = selected;
      *indexOut = pos;
      return;
    }
  }
}
// Launcher object passed to run_launcher: stores the rank `k` and starts the
// gatherKthValue kernel for the tensors it is given.
struct KthValueLauncher {
  int64_t k; // rank to select; the host code in this file passes k >= 1

  KthValueLauncher(int64_t k) : k(k) {}

  // Launches gatherKthValue on the current CUDA stream and reports launch
  // failures.
  //
  // Parameters:
  //   values_info / indices_info - output tensors (value and position)
  //   self_info                  - input tensor
  //   collapse_self_dim          - index into self_info.strides of the
  //                                dimension the selection runs along
  //   collapse_values_dim,
  //   collapse_indices_dim       - not used by this launcher
  //   num_slices                 - number of slices; determines the grid
  //   slice_size                 - elements per slice; determines block size
  //
  // Raises an error if no grid can be built for `num_slices` or if the
  // kernel launch itself is rejected.
  template <typename scalar_t, typename index_t, int all_dims>
  inline void launch(
      cuda::detail::TensorInfo<scalar_t, index_t> values_info,
      int collapse_values_dim,
      cuda::detail::TensorInfo<int64_t, index_t> indices_info,
      int collapse_indices_dim,
      cuda::detail::TensorInfo<scalar_t, index_t> self_info,
      int collapse_self_dim,
      int64_t num_slices,
      int64_t slice_size) {
    dim3 grid;
    if (!getGridFromTiles(num_slices, grid)) {
      AT_ERROR("slices are too many");
    }

    // Slice size rounded up to a multiple of WARP_SIZE, capped at 1024
    // threads per block.
    dim3 block(
        std::min(THCRoundUp(slice_size, (int64_t)WARP_SIZE), (int64_t)1024));
    auto stream = at::cuda::getCurrentCUDAStream();
    gatherKthValue<scalar_t, index_t, all_dims><<<grid, block, 0, stream>>>(
        self_info,
        slice_size,
        k,
        num_slices,
        /* The actual dimension that the k-selection is running in */
        /* may have changed from collapseDims() */
        self_info.strides[collapse_self_dim],
        values_info,
        indices_info);
    // A kernel launch does not return a status. Query it here so that every
    // user of this launcher gets a rejected launch configuration reported.
    AT_CUDA_CHECK(cudaGetLastError());
  }
};
// Computes the k-th value of `self` along dimension `dim_` and stores the
// result in `values` together with its position in `indices`.
//
// Parameters:
//   values, indices - outputs; allocated or resized by
//                     _reduction_with_indices_allocate_or_resize_output
//   self            - input tensor, must contain at least one element
//   k               - rank to select, must satisfy 1 <= k <= self.size(dim)
//   dim_            - dimension to select along; wrapped by maybe_wrap_dim
//   keepdim         - when false, the selected dimension is squeezed out of
//                     both outputs
//
// Raises (via TORCH_CHECK) on an empty input, on k out of range and on inputs
// with more than MAX_TENSORINFO_DIMS dimensions; a pending CUDA error is
// raised by the final AT_CUDA_CHECK.
template <typename scalar_t>
void kthvalue_cuda_template(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t k,
    int64_t dim_,
    bool keepdim) {
  const int64_t wrapped_dim = maybe_wrap_dim(dim_, self.dim());
  const int64_t elems_per_slice = self.size(wrapped_dim);

  // FIXME: This seems bogus, I only do this because it was the old behaviour.
  //        The reductions are fine, as long as the axis being reduced along
  //        isn't of 0 elements (and the output has elements).
  TORCH_CHECK(
      self.numel() > 0,
      "cannot perform reduction function kthvalue",
      " on tensor with no elements because the operation does not have an identity");
  TORCH_CHECK(
      k >= 1 && k <= elems_per_slice, "selected number k out of range");

  _reduction_with_indices_allocate_or_resize_output(
      values, indices, self, wrapped_dim, keepdim);

  // A zero-dimensional tensor holding one element is its own k-th value.
  if (self.dim() == 0 && self.numel() == 1) {
    values.copy_(self);
    indices.zero_();
    return;
  }

  TORCH_CHECK(
      self.dim() <= MAX_TENSORINFO_DIMS,
      "cannot operate on more than ",
      MAX_TENSORINFO_DIMS,
      " dimensions");

  // Use 32-bit index arithmetic only if all three tensors allow it.
  const bool all_fit_32bit = cuda::detail::canUse32BitIndexMath(self) &&
      cuda::detail::canUse32BitIndexMath(values) &&
      cuda::detail::canUse32BitIndexMath(indices);
  if (all_fit_32bit) {
    run_launcher<scalar_t, uint32_t>(
        values, indices, self, wrapped_dim, KthValueLauncher(k));
  } else {
    run_launcher<scalar_t, uint64_t>(
        values, indices, self, wrapped_dim, KthValueLauncher(k));
  }

  if (!keepdim) {
    values.squeeze_(wrapped_dim);
    indices.squeeze_(wrapped_dim);
  }

  AT_CUDA_CHECK(cudaGetLastError());
}
// This does not reduce to median with dim because we don't want to copy twice.
// Median over all elements of `self`, returned as a zero-dimensional tensor.
//
// The input is cloned and viewed as one dimension, and the element of rank
// (n + 1) / 2 (1-based, integer division, n = number of elements) is selected
// with KthValueLauncher.
//
// Raises (via TORCH_CHECK) when `self` is empty or has more than
// MAX_TENSORINFO_DIMS dimensions, and (via AT_CUDA_CHECK) when the kernel
// launch leaves a CUDA error pending.
template <typename scalar_t>
Tensor median_cuda_template(const Tensor& self) {
  TORCH_CHECK(self.numel() > 0, "median cannot be called with empty tensor");
  // A zero-dimensional tensor holding one element is its own median.
  if (self.dim() == 0 && self.numel() == 1) {
    return self.clone();
  }
  auto self_copy = self.clone().view(-1);
  auto values = at::empty({1}, self.options());
  auto indices = at::empty({1}, self.options().dtype(kLong));
  TORCH_CHECK(
      self.dim() <= MAX_TENSORINFO_DIMS,
      "cannot operate on more than ",
      MAX_TENSORINFO_DIMS,
      " dimensions");

  // KthValue is 1-based
  const int64_t k = (self_copy.size(0) + 1) / 2;

  // Based on required index size, run the algorithm with the
  // appropriate index type
  if (cuda::detail::canUse32BitIndexMath(self) &&
      cuda::detail::canUse32BitIndexMath(values) &&
      cuda::detail::canUse32BitIndexMath(indices)) {
    run_launcher<scalar_t, uint32_t>(
        values, indices, self_copy, 0, KthValueLauncher(k));
  } else {
    run_launcher<scalar_t, uint64_t>(
        values, indices, self_copy, 0, KthValueLauncher(k));
  }

  // Kernel launches do not return a status; query it so a failed launch is
  // reported instead of handing back an unwritten result.
  AT_CUDA_CHECK(cudaGetLastError());
  return values.view({});
}
} // namespace
// Out-variant entry point for kthvalue on CUDA tensors: dispatches on the
// scalar type of `self` (the AT_DISPATCH_ALL_TYPES_AND set plus Half) to
// kthvalue_cuda_template and returns references to the two output tensors.
std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t k,
    int64_t dim,
    bool keepdim) {
  AT_DISPATCH_ALL_TYPES_AND(
      at::ScalarType::Half, self.scalar_type(), "kthvalue_cuda", [&]() {
        kthvalue_cuda_template<scalar_t>(
            values, indices, self, k, dim, keepdim);
      });
  return std::tuple<Tensor&, Tensor&>(values, indices);
}
// Entry point for the all-elements median on CUDA tensors: dispatches on the
// scalar type of `self` (the AT_DISPATCH_ALL_TYPES_AND set plus Half) to
// median_cuda_template and returns its result.
Tensor median_cuda(const Tensor& self) {
  Tensor result = AT_DISPATCH_ALL_TYPES_AND(
      at::ScalarType::Half, self.scalar_type(), "median", [&]() {
        return median_cuda_template<scalar_t>(self);
      });
  return result;
}
} // namespace native
} // namespace at