forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
HistogramKernel.cpp
291 lines (246 loc) · 12 KB
/
HistogramKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Histogram.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/sum.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like_ops.h>
#endif
#include <algorithm>
#include <numeric>
#include <functional>
namespace at::native {
namespace {
// Baseline grain size handed to at::parallel_for when splitting the N input rows.
// histogramdd_cpu_contiguous divides it by the number of dimensions D (with a
// floor of 1) before use.
constexpr int64_t HISTOGRAM_GRAIN_SIZE = 200;
/* The main algorithm. Expects that the input tensor has shape (N, D).
* Expects that bin_edges contains D one-dimensional tensors, each specifying
* an increasing sequence of bin edges.
*
* Interprets the input as N different D-dimensional coordinates and maps them
* into the D-dimensional bins defined by bin_edges, accumulating a D-dimensional
* histogram in the hist tensor.
*
* Accepts a template argument of type BIN_SELECTION_ALGORITHM specifying how
* the scalars in each dimension should be mapped into the dimension's bins:
*
* - LINEAR_INTERPOLATION: each bin edge sequence must form a linear progression.
* Scalars are mapped to bins by computing
* (element - leftmost_edge)/(rightmost_edge - leftmost_edge) * bin_ct
* and truncating the result to an integer.
*
* This is the fastest option, but its results may not be perfectly consistent
* with the boundaries specified in bin_edges due to precision issues.
*
* Used by torch.histc, which doesn't need consistency with bin_edges as it does not
* return bin_edges. Additionally, this implementation is identical to the legacy histc
* implementation, which was replaced when histogram was implemented.
*
* - LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH: Also expects that each bin edge sequence
* forms a linear progression. For each scalar, if 'pos' is the bin selected by the
* LINEAR_INTERPOLATION approach, this approach inspects the boundaries in bin_edges to
* place the scalar into pos - 1, pos, or pos + 1. The "local search" over neighboring
* bins allows for correction of misclassifications due to precision issues (a scalar
* very close to a bin_edge may be misclassified by LINEAR_INTERPOLATION).
*
* Should produce the same output as the general case BINARY_SEARCH, but run about
* 3x faster asymptotically.
*
* Used by torch.histogram for cases in which bin_edges is constructed using
* torch.linspace. The behavior of LINEAR_INTERPOLATION may not perfectly align
* with linspace bin_edges due to precision issues. torch.histogram returns both
* the hist and bin_edges tensors as output, so the "local search" is needed to
* keep its output internally consistent.
*
* - BINARY_SEARCH: Handles torch.histogram's general case by searching over the
* elements of bin_edges. Implemented using std::upper_bound.
*
* See discussion at https://github.com/pytorch/pytorch/pull/58780#discussion_r648604866
* for further details on relative performance of the bin selection algorithms.
*/
// Strategy used by histogramdd_cpu_contiguous to map a scalar to a bin index.
enum BIN_SELECTION_ALGORITHM {
// Bin index computed arithmetically from the first and last edge only.
LINEAR_INTERPOLATION,
// Arithmetic guess as above, then corrected by a std::upper_bound search
// restricted to the edges neighbouring the guessed position.
LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH,
// std::upper_bound over the whole edge sequence of the dimension.
BINARY_SEARCH,
};
/* Accumulates a D-dimensional histogram of `input` into `hist`.
 *
 * Template parameters:
 *   input_t   - scalar type through which input, weight, bin_edges and the
 *               histogram storage are all accessed.
 *   algorithm - BIN_SELECTION_ALGORITHM used to map a scalar to a bin index.
 *
 * Arguments:
 *   hist      - output tensor; dimension `dim` must hold
 *               bin_edges[dim].numel() - 1 bins. Overwritten with the sum of the
 *               per-thread histograms. Left untouched when D == 0 or when hist
 *               has no elements.
 *   bin_edges - D contiguous one-dimensional tensors of bin edges.
 *   input     - tensor of shape (N, D); each row is one D-dimensional coordinate.
 *   weight    - optional one-dimensional tensor of N per-row weights; a counted
 *               row contributes 1 when no weight is given.
 *
 * A row is ignored when any of its coordinates is NaN or lies outside
 * [first edge, last edge] of the corresponding dimension.
 */
template<typename input_t, BIN_SELECTION_ALGORITHM algorithm>
void histogramdd_cpu_contiguous(Tensor& hist, const TensorList& bin_edges,
        const Tensor& input, const c10::optional<Tensor>& weight) {
    // Replaces a runtime "unreachable" assert: an unsupported algorithm value
    // is rejected when the template is instantiated.
    static_assert(algorithm == BINARY_SEARCH
            || algorithm == LINEAR_INTERPOLATION
            || algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH,
            "unsupported BIN_SELECTION_ALGORITHM");

    TORCH_INTERNAL_ASSERT(input.dim() == 2);

    const int64_t N = input.size(0);
    if (weight.has_value()) {
        TORCH_INTERNAL_ASSERT(weight.value().dim() == 1 && weight.value().numel() == N);
    }

    const int64_t D = input.size(1);
    TORCH_INTERNAL_ASSERT(int64_t(bin_edges.size()) == D);
    for (const auto dim : c10::irange(D)) {
        TORCH_INTERNAL_ASSERT(bin_edges[dim].is_contiguous());
        TORCH_INTERNAL_ASSERT(hist.size(dim) + 1 == bin_edges[dim].numel());
    }

    if (D == 0) {
        // hist is an empty tensor in this case; nothing to do here
        return;
    }

    if (hist.numel() == 0) {
        /* Some dimension has a single bin edge, i.e. zero bins. There is no
         * storage to accumulate into, and an element equal to that lone edge
         * would pass the range check below and produce an index outside the
         * (empty) per-thread buffer. Nothing to do here either.
         */
        return;
    }

    TensorAccessor<input_t, 2> accessor_in = input.accessor<input_t, 2>();

    /* Constructs a c10::optional<TensorAccessor> containing an accessor iff
     * the optional weight tensor has a value.
     */
    const auto accessor_wt = weight.has_value()
            ? c10::optional<TensorAccessor<input_t, 1>>(weight.value().accessor<input_t, 1>())
            : c10::optional<TensorAccessor<input_t, 1>>();

    // Per-dimension raw edge pointers, edge counts and outermost edges.
    std::vector<input_t*> bin_seq(D);
    std::vector<int64_t> num_bin_edges(D);
    std::vector<input_t> leftmost_edge(D), rightmost_edge(D);

    for (const auto dim : c10::irange(D)) {
        bin_seq[dim] = bin_edges[dim].data_ptr<input_t>();
        num_bin_edges[dim] = bin_edges[dim].numel();
        leftmost_edge[dim] = bin_seq[dim][0];
        rightmost_edge[dim] = bin_seq[dim][num_bin_edges[dim] - 1];
    }

    // Each row costs work proportional to D, so shrink the grain accordingly.
    const int64_t GRAIN_SIZE = std::max(int64_t(1), HISTOGRAM_GRAIN_SIZE / D);

    /* Parallelizes processing of input using at::parallel_for.
     * Each thread accumulates a local result into their own slice of
     * thread_histograms which get summed together at the end.
     */
    const auto num_threads = at::get_num_threads();
    const auto hist_sizes = hist.sizes();
    DimVector thread_hist_sizes(hist_sizes.size() + 1);
    thread_hist_sizes[0] = num_threads;
    std::copy(hist_sizes.begin(), hist_sizes.end(),
              thread_hist_sizes.begin() + 1);
    Tensor thread_histograms = at::zeros(thread_hist_sizes, hist.dtype());
    TORCH_INTERNAL_ASSERT(thread_histograms.is_contiguous());

    at::parallel_for(0, N, GRAIN_SIZE, [&](int64_t start, int64_t end) {
        const auto tid = at::get_thread_num();
        auto hist_strides = thread_histograms.strides();
        input_t *hist_local_data = thread_histograms.data_ptr<input_t>();

        // View only this thread's local results
        hist_local_data += hist_strides[0] * tid;
        hist_strides = hist_strides.slice(1);

        for (const auto i : c10::irange(start, end)) {
            bool skip_elt = false;
            int64_t hist_index = 0;

            for (const auto dim : c10::irange(D)) {
                const input_t elt = accessor_in[i][dim];

                // Skips elements which fall outside the specified bins and NaN elements
                // (the negated form is what makes NaN fail the test).
                if (!(elt >= leftmost_edge[dim] && elt <= rightmost_edge[dim])) {
                    skip_elt = true;
                    break;
                }

                int64_t pos = -1;

                if constexpr (algorithm == BINARY_SEARCH) {
                    // Handles the general case via binary search on the bin edges.
                    pos = std::upper_bound(bin_seq[dim], bin_seq[dim] + num_bin_edges[dim], elt)
                            - bin_seq[dim] - 1;
                } else {
                    /* When bin_edges is known to be a linear progression, maps elt to
                     * the appropriate bin via simple division.
                     */
                    pos = static_cast<int64_t>((elt - leftmost_edge[dim])
                            * (num_bin_edges[dim] - 1)
                            / (rightmost_edge[dim] - leftmost_edge[dim]));

                    /* Ensures consistency with bin_edges by checking the bins to the left and right
                     * of the selected position. Necessary for cases in which an element very close
                     * to a bin edge may be misclassified by simple division.
                     */
                    if constexpr (algorithm == LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH) {
                        int64_t pos_min = std::max(static_cast<int64_t>(0), pos - 1);
                        int64_t pos_max = std::min(pos + 2, num_bin_edges[dim]);
                        pos = std::upper_bound(bin_seq[dim] + pos_min, bin_seq[dim] + pos_max, elt)
                                - bin_seq[dim] - 1;
                    }
                }

                // Unlike other bins, the rightmost bin includes its right boundary
                if (pos == (num_bin_edges[dim] - 1)) {
                    pos -= 1;
                }

                hist_index += hist_strides[dim] * pos;
            }

            if (!skip_elt) {
                // In the unweighted case, the default weight is 1
                input_t wt = accessor_wt.has_value() ? accessor_wt.value()[i] : static_cast<input_t>(1);

                hist_local_data[hist_index] += wt;
            }
        }
    });

    // Reduce the per-thread partial histograms into the caller's output.
    at::sum_out(hist, thread_histograms, /*dim=*/{0});
}
/* Wrapper that prepares arguments for histogramdd_cpu_contiguous and finishes its output.
 *
 * Zeroes hist, flattens every leading dimension of self so that it becomes an
 * (M, N) matrix (N being the size of the last dimension), flattens weight to M
 * entries when present, makes every bin edge tensor contiguous and runs the
 * main algorithm for the scalar type of self. When density is true the result
 * is then divided by its total and by each bin's length in every dimension.
 */
template<BIN_SELECTION_ALGORITHM bin_algorithm>
void histogramdd_out_cpu_template(const Tensor& self, const c10::optional<Tensor>& weight, bool density,
        Tensor& hist, const TensorList& bin_edges) {
    hist.fill_(0);

    // N: length of one coordinate; M: number of coordinates (product of all other sizes).
    const int64_t N = self.size(-1);
    int64_t M = 1;
    for (const auto d : c10::irange(self.dim() - 1)) {
        M *= self.size(d);
    }

    const Tensor reshaped_input = self.reshape({M, N});

    c10::optional<Tensor> reshaped_weight;
    if (weight.has_value()) {
        reshaped_weight = weight.value().reshape({M});
    }

    // The main algorithm reads the edges through raw pointers, so it needs contiguous storage.
    std::vector<Tensor> contiguous_edges;
    contiguous_edges.reserve(bin_edges.size());
    for (const auto& edges : bin_edges) {
        contiguous_edges.push_back(edges.contiguous());
    }

    AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, self.scalar_type(), "histogram_cpu", [&]() {
        histogramdd_cpu_contiguous<scalar_t, bin_algorithm>(
                hist, contiguous_edges, reshaped_input, reshaped_weight);
    });

    if (!density) {
        return;
    }

    // Normalize by the total count/weight accumulated over all bins...
    const auto total = hist.sum().item();
    hist.div_(total);

    // ...and then by the bin volume, one dimension at a time.
    for (const auto dim : c10::irange(N)) {
        const auto widths = bin_edges[dim].diff();

        // All-ones shape except along `dim`, so widths broadcasts against hist.
        std::vector<int64_t> broadcast_shape(N, 1);
        broadcast_shape[dim] = widths.numel();

        hist.div_(widths.reshape(broadcast_shape));
    }
}
/* General-purpose histogram kernel. Every input element is assigned to a bin by
 * binary search over the bin edges of its dimension, so the edges need not be
 * evenly spaced.
 *
 * See histogramdd_out_cpu_template for the surrounding pre- and post-processing.
 */
static void histogramdd_kernel_impl(
        const Tensor& self,
        const c10::optional<Tensor>& weight,
        bool density,
        Tensor& hist,
        const TensorList& bin_edges) {
    constexpr BIN_SELECTION_ALGORITHM kAlgorithm = BINARY_SEARCH;
    histogramdd_out_cpu_template<kAlgorithm>(self, weight, density, hist, bin_edges);
}
/* Faster histogram kernel for bin edges that are known to form a linear
 * progression. local_search selects between the two linear algorithms.
 *
 * See histogramdd_out_cpu_template for the surrounding pre- and post-processing.
 */
static void histogramdd_linear_kernel_impl(
        const Tensor& self,
        const c10::optional<Tensor>& weight,
        bool density,
        Tensor& hist,
        const TensorList& bin_edges,
        bool local_search) {
    if (!local_search) {
        // histc codepath: the caller never sees bin_edges, so the plain
        // arithmetic mapping is sufficient.
        histogramdd_out_cpu_template<LINEAR_INTERPOLATION>(
                self, weight, density, hist, bin_edges);
        return;
    }

    // histogramdd codepath: hist and bin_edges are both eventually returned as
    // output, so the local search keeps them consistent with each other.
    histogramdd_out_cpu_template<LINEAR_INTERPOLATION_WITH_LOCAL_SEARCH>(
            self, weight, density, hist, bin_edges);
}
} // namespace
// Associates the two kernel entry points defined above with their dispatch stubs.
// NOTE(review): histogramdd_stub and histogramdd_linear_stub are not declared in
// this file — presumably they come from ATen/native/Histogram.h; confirm.
REGISTER_DISPATCH(histogramdd_stub, &histogramdd_kernel_impl);
REGISTER_DISPATCH(histogramdd_linear_stub, &histogramdd_linear_kernel_impl);
} // namespace at::native