forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUpSampleMoreKernel.cpp
198 lines (172 loc) · 8.62 KB
/
UpSampleMoreKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <vector>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/native/UpSample.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <c10/util/irange.h>
namespace at {
namespace native {
namespace {
// Per-axis scale hints handed to the shared backward kernel: one optional
// entry per spatial axis, in the order the wrappers in this file list them
// ({w} for 1d, {h, w} for 2d, {d, h, w} for 3d).
using scale_t = std::vector<c10::optional<double>>;
// Shared CPU backward pass for linear (1d), bilinear (2d) and trilinear (3d)
// upsampling.
//
// Each element of `grad_output_` is scattered onto the 2 / 4 / 8 elements of
// `grad_input_` selected by compute_source_index_and_lambda, weighted by the
// product of the per-axis lambdas. Contributions are accumulated with `+=`,
// so values already stored in `grad_input_` are added to, not overwritten.
// NOTE(review): presumably the caller zeroes grad_input beforehand - confirm.
//
// grad_input_   : destination gradient with 3, 4 or 5 dims; its first two
//                 dims are folded into a single "channel" axis.
// grad_output_  : incoming gradient; must have the same dtype as grad_input_
//                 (enforced by TORCH_CHECK).
// align_corners : forwarded unchanged to the scale / source-index helpers.
// scales        : indexable per-axis scale hints, outermost spatial axis
//                 first (scales[0] is depth for 5 dims, height for 4 dims,
//                 width for 3 dims).
template <typename scalar_t, typename scale_type>
void cpu_upsample_linear_backward(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    bool align_corners,
    const scale_type& scales) {
  TORCH_CHECK(grad_input_.dtype() == grad_output_.dtype(), "expected dtype ", grad_output_.dtype(),
              " for `grad_input` but got dtype ", grad_input_.dtype());

  // Operate on dense buffers. When grad_input_ is not contiguous this is a
  // separate tensor whose contents are copied back at the very end.
  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();

  const scalar_t* const grad_out_ptr = grad_output.data_ptr<scalar_t>();
  scalar_t* const grad_in_ptr = grad_input.data_ptr<scalar_t>();

  const auto in_shape = grad_input.sizes().vec();
  const auto out_shape = grad_output.sizes().vec();
  const auto ndim = in_shape.size();

  // Batch and channel dims are handled as one flat axis.
  const int64_t channels = in_shape[0] * in_shape[1];

  // Spatial axes that do not exist for this rank collapse to extent 1.
  const bool has_depth = (ndim == 5);
  const bool has_height = (ndim >= 4);
  const int64_t input_depth = has_depth ? in_shape[2] : 1;
  const int64_t output_depth = has_depth ? out_shape[2] : 1;
  const int64_t input_height = has_height ? in_shape[ndim - 2] : 1;
  const int64_t output_height = has_height ? out_shape[ndim - 2] : 1;
  const int64_t input_width = in_shape[ndim - 1];
  const int64_t output_width = out_shape[ndim - 1];

  // Number of grad_output elements belonging to one folded channel.
  const int64_t output_slice_size = output_depth * output_height * output_width;

  // ---- 1d: every output element feeds two input elements ------------------
  const auto backward_1d = [&](int64_t begin, int64_t end) {
    const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
        input_width, output_width, align_corners, scales[0]);

    // Filled in by compute_source_index_and_lambda on every iteration.
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t iw0, iw1;
    scalar_t w0lambda, w1lambda;

    for (const auto c : c10::irange(begin, end)) {
      scalar_t* const in_row = grad_in_ptr + c * input_width;
      const scalar_t* const out_row = grad_out_ptr + c * output_slice_size;

      for (const auto ow : c10::irange(output_width)) {
        compute_source_index_and_lambda(
            iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);

        const scalar_t g = out_row[ow];
        in_row[iw0] += w0lambda * g; /* i0 */
        in_row[iw1] += w1lambda * g; /* i1 */
      }
    }
  };

  // ---- 2d: every output element feeds a 2x2 input neighbourhood -----------
  const auto backward_2d = [&](int64_t begin, int64_t end) {
    const scalar_t height_scale = area_pixel_compute_scale<scalar_t>(
        input_height, output_height, align_corners, scales[0]);
    const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
        input_width, output_width, align_corners, scales[1]);

    const int64_t input_plane = input_height * input_width;

    // Filled in by compute_source_index_and_lambda on every iteration.
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t ih0, ih1, iw0, iw1;
    scalar_t h0lambda, h1lambda, w0lambda, w1lambda;

    for (const auto c : c10::irange(begin, end)) {
      scalar_t* const in_plane = grad_in_ptr + c * input_plane;
      const scalar_t* const out_plane = grad_out_ptr + c * output_slice_size;

      for (const auto oh : c10::irange(output_height)) {
        compute_source_index_and_lambda(
            ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);

        scalar_t* const in_h0 = in_plane + ih0 * input_width;
        scalar_t* const in_h1 = in_plane + ih1 * input_width;
        const scalar_t* const out_row = out_plane + oh * output_width;

        for (const auto ow : c10::irange(output_width)) {
          compute_source_index_and_lambda(
              iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);

          // The update order matters when two targets coincide, so it is
          // kept as i00, i01, i10, i11.
          const scalar_t g = out_row[ow];
          in_h0[iw0] += h0lambda * w0lambda * g; /* i00 */
          in_h0[iw1] += h0lambda * w1lambda * g; /* i01 */
          in_h1[iw0] += h1lambda * w0lambda * g; /* i10 */
          in_h1[iw1] += h1lambda * w1lambda * g; /* i11 */
        }
      }
    }
  };

  // ---- 3d: every output element feeds a 2x2x2 input neighbourhood ---------
  const auto backward_3d = [&](int64_t begin, int64_t end) {
    const scalar_t depth_scale = area_pixel_compute_scale<scalar_t>(
        input_depth, output_depth, align_corners, scales[0]);
    const scalar_t height_scale = area_pixel_compute_scale<scalar_t>(
        input_height, output_height, align_corners, scales[1]);
    const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
        input_width, output_width, align_corners, scales[2]);

    const int64_t input_plane = input_height * input_width;
    const int64_t output_plane = output_height * output_width;

    // Filled in by compute_source_index_and_lambda on every iteration.
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t id0, id1, ih0, ih1, iw0, iw1;
    scalar_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda;

    for (const auto c : c10::irange(begin, end)) {
      scalar_t* const in_volume = grad_in_ptr + c * input_depth * input_plane;
      const scalar_t* const out_volume = grad_out_ptr + c * output_slice_size;

      for (const auto od : c10::irange(output_depth)) {
        compute_source_index_and_lambda(
            id0, id1, d0lambda, d1lambda, depth_scale, od, input_depth, output_depth, align_corners);

        scalar_t* const in_d0 = in_volume + id0 * input_plane;
        scalar_t* const in_d1 = in_volume + id1 * input_plane;
        const scalar_t* const out_plane = out_volume + od * output_plane;

        for (const auto oh : c10::irange(output_height)) {
          compute_source_index_and_lambda(
              ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);

          const int64_t row0 = ih0 * input_width;
          const int64_t row1 = ih1 * input_width;
          scalar_t* const in_d0h0 = in_d0 + row0;
          scalar_t* const in_d0h1 = in_d0 + row1;
          scalar_t* const in_d1h0 = in_d1 + row0;
          scalar_t* const in_d1h1 = in_d1 + row1;
          const scalar_t* const out_row = out_plane + oh * output_width;

          for (const auto ow : c10::irange(output_width)) {
            compute_source_index_and_lambda(
                iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);

            // Same update order as the index comments: i000 ... i111.
            const scalar_t g = out_row[ow];
            in_d0h0[iw0] += d0lambda * h0lambda * w0lambda * g; /* i000 */
            in_d0h0[iw1] += d0lambda * h0lambda * w1lambda * g; /* i001 */
            in_d0h1[iw0] += d0lambda * h1lambda * w0lambda * g; /* i010 */
            in_d0h1[iw1] += d0lambda * h1lambda * w1lambda * g; /* i011 */
            in_d1h0[iw0] += d1lambda * h0lambda * w0lambda * g; /* i100 */
            in_d1h0[iw1] += d1lambda * h0lambda * w1lambda * g; /* i101 */
            in_d1h1[iw0] += d1lambda * h1lambda * w0lambda * g; /* i110 */
            in_d1h1[iw1] += d1lambda * h1lambda * w1lambda * g; /* i111 */
          }
        }
      }
    }
  };

  // Split the folded channel axis across threads. The grain size is
  // GRAIN_SIZE divided by the per-channel output size and by the number of
  // input elements each output element touches (2 / 4 / 8).
  switch (ndim) {
    case 3:
      // upsample linear 1d
      at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 2, backward_1d);
      break;
    case 4:
      // upsample bilinear 2d
      at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, backward_2d);
      break;
    default:
      // upsample trilinear 3d
      TORCH_INTERNAL_ASSERT(ndim == 5);
      at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 8, backward_3d);
      break;
  }

  // A non-contiguous destination was accumulated into a dense temporary;
  // write the result back into the caller's tensor.
  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}
void upsample_linear1d_backward_kernel_impl(
const Tensor& grad_input,
const Tensor& grad_output,
bool align_corners,
c10::optional<double> scales_w) {
AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "upsample_linear1d_backward", [&] {
cpu_upsample_linear_backward<scalar_t, scale_t>(grad_input, grad_output, align_corners, {scales_w});
});
}
void upsample_bilinear2d_backward_kernel_impl(
const Tensor& grad_input,
const Tensor& grad_output,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w) {
AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "upsample_bilinear2d_backward", [&] {
cpu_upsample_linear_backward<scalar_t, scale_t>(grad_input, grad_output, align_corners, {scales_h, scales_w});
});
}
void upsample_trilinear3d_backward_kernel_impl(
const Tensor& grad_input,
const Tensor& grad_output,
bool align_corners,
c10::optional<double> scales_d,
c10::optional<double> scales_h,
c10::optional<double> scales_w) {
AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "upsample_trilinear3d_backward", [&] {
cpu_upsample_linear_backward<scalar_t, scale_t>(grad_input, grad_output, align_corners, {scales_d, scales_h, scales_w});
});
}
} // anonymous namespace
// Register each backward implementation from the anonymous namespace above
// under its corresponding *_kernel name via REGISTER_DISPATCH.
REGISTER_DISPATCH(upsample_linear1d_backward_kernel, &upsample_linear1d_backward_kernel_impl);
REGISTER_DISPATCH(upsample_bilinear2d_backward_kernel, &upsample_bilinear2d_backward_kernel_impl);
REGISTER_DISPATCH(upsample_trilinear3d_backward_kernel, &upsample_trilinear3d_backward_kernel_impl);
} // namespace native
} // namespace at