diff --git a/include/torch_geopooling/embedding.h b/include/torch_geopooling/embedding.h
index cd115eb..0328a4e 100644
--- a/include/torch_geopooling/embedding.h
+++ b/include/torch_geopooling/embedding.h
@@ -26,7 +26,17 @@
 torch::Tensor
 embedding2d(
     const torch::Tensor& input,
     const torch::Tensor& weight,
-    const c10::ArrayRef<int64_t>& padding,
+    const c10::IntArrayRef& padding,
+    const c10::ArrayRef<double>& exterior
+);
+
+
+torch::Tensor
+embedding2d_backward(
+    const torch::Tensor& grad,
+    const torch::Tensor& input,
+    const torch::Tensor& weight,
+    const c10::IntArrayRef& padding,
     const c10::ArrayRef<double>& exterior
 );
diff --git a/src/embedding.cc b/src/embedding.cc
index 0b6e180..742ac3a 100644
--- a/src/embedding.cc
+++ b/src/embedding.cc
@@ -41,28 +41,101 @@ modulo(int64_t base, int64_t value)
 }
 
 
-torch::Tensor
-embedding2d(
+struct embedding_options {
+    std::vector<int64_t> padding;
+    std::vector<double> exterior;
+
+    std::vector<int64_t>
+    kernel_size(int64_t feature_size) const
+    {
+        return {kernel_width(), kernel_height(), feature_size};
+    }
+
+    int64_t
+    kernel_width() const
+    {
+        return padding[0] * 2 + 1;
+    }
+
+    int64_t
+    kernel_height() const
+    {
+        return padding[1] * 2 + 1;
+    }
+
+    bool
+    is_padding_neg() const
+    {
+        // True when any padding component is negative.
+        // NOTE(review): the local was previously named `is_non_neg`, which
+        // inverted its meaning; it accumulates `p < 0`.
+        bool has_neg = false;
+        for (const auto& p : padding) {
+            has_neg |= (p < 0);
+        }
+        return has_neg;
+    }
+};
+
+
+static void
+check_shape_forward(
+    const std::string& op,
     const torch::Tensor& input,
     const torch::Tensor& weight,
-    const c10::ArrayRef<int64_t>& padding,
-    const c10::ArrayRef<double>& exterior
+    const embedding_options& options
 )
 {
-    const std::string op = "embedding2d";
+    TORCH_CHECK(
+        options.exterior.size() == 4,
+        op, ": exterior must be a tuple of four doubles comprising a rectangle (x, y, w, h)"
+    );
+    TORCH_CHECK(options.padding.size() == 2, op, ": padding should be comprised of 2 elements");
+
+    TORCH_CHECK(input.dim() == 2, op, ": input must be 2D, got ", input.dim(), "D");
+    TORCH_CHECK(
+        input.dtype() == torch::kFloat64, op, ": operation only supports Float64 input, got ",
+        input.dtype()
+    );
 
-    TORCH_CHECK(exterior.size() == 4, op, ": exterior must be a tuple of four doubles");
-    TORCH_CHECK(padding.size() == 2, op, ": padding should be comprised of 2 elements");
     TORCH_CHECK(weight.dim() == 3, op, ": weight must be 3D, got ", weight.dim(), "D");
     TORCH_CHECK(
         weight.dtype() == torch::kFloat64, op, ": operation only supports Float64 weight, got ",
         weight.dtype()
     );
-    TORCH_CHECK(input.dim() == 2, op, ": input must be 2D, got ", input.dim(), "D");
+}
+
+
+static void
+check_shape_backward(
+    const std::string& op,
+    const torch::Tensor& grad,
+    const torch::Tensor& input,
+    const torch::Tensor& weight,
+    const embedding_options& options
+)
+{
+    check_shape_forward(op, input, weight, options);
+
+    // The incoming gradient must have one kernel-shaped slice per input point.
+    const auto grad_sizes = c10::IntArrayRef(
+        {input.size(0), options.kernel_width(), options.kernel_height(), weight.size(-1)}
+    );
+
     TORCH_CHECK(
-        input.dtype() == torch::kFloat64, op, ": operation only supports Float64 input, got ",
-        input.dtype()
+        grad.sizes() == grad_sizes,
+        op,
+        // NOTE(review): the message previously said "should be the same as
+        // input", but the check compares against the expected kernel shape.
+        ": gradient shape (", grad.sizes(), ") should be (", grad_sizes, ")"
     );
+}
+
+
+torch::Tensor
+embedding2d(
+    const torch::Tensor& input,
+    const torch::Tensor& weight,
+    const c10::IntArrayRef& padding,
+    const c10::ArrayRef<double>& exterior
+)
+{
+    auto options = embedding_options{.padding = padding.vec(), .exterior = exterior.vec()};
+    check_shape_forward("embedding2d", input, weight, options);
 
     auto width_size = weight.size(0);
     auto height_size = weight.size(1);
@@ -74,34 +147,32 @@ embedding2d(
     auto quad_width = quad_exterior.width() / width_size;
     auto quad_height = quad_exterior.height() / height_size;
 
-    auto weight_ptr = weight.accessor<double, 3>();
-    auto input_ptr = input.accessor<double, 2>();
+    auto weight_data = weight.accessor<double, 3>();
+    auto input_data = input.accessor<double, 2>();
 
     auto input_size = input.size(0);
-    const auto kernel_size = std::vector<int64_t>({padding[0] * 2 + 1, padding[1] * 2 + 1, feature_size});
+    const auto kernel_size = options.kernel_size(feature_size);
 
     std::vector<torch::Tensor> output(input_size);
 
     at::parallel_for(0, input_size, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
         for (const auto i : c10::irange(begin, end)) {
-            const auto& point = input_ptr[i];
+            const auto& point = input_data[i];
 
             auto kernel = torch::empty(kernel_size, weight.options());
-            auto kernel_ptr = kernel.accessor<double, 3>();
+            auto kernel_data = kernel.accessor<double, 3>();
 
-            const auto x = floordiv(point[0] - quad_exterior.xmin(), quad_width);
-            const auto y = floordiv(point[1] - quad_exterior.ymin(), quad_height);
+            const auto w = floordiv(point[0] - quad_exterior.xmin(), quad_width);
+            const auto h = floordiv(point[1] - quad_exterior.ymin(), quad_height);
 
             int64_t k0 = 0;
-            for (auto j0 : c10::irange(x - padding[0], x + padding[0] + 1)) {
+            for (auto j0 : c10::irange(w - padding[0], w + padding[0] + 1)) {
                 int64_t k1 = 0;
-                for (auto j1 : c10::irange(y - padding[1], y + padding[1] + 1)) {
+                for (auto j1 : c10::irange(h - padding[1], h + padding[1] + 1)) {
                     j0 = modulo(width_size, j0);
                     j1 = modulo(height_size, j1);
 
-                    int64_t k2 = 0;
                     for (auto j2 : c10::irange(feature_size)) {
-                        kernel_ptr[k0][k1][k2] = weight_ptr[j0][j1][j2];
-                        k2++;
+                        kernel_data[k0][k1][j2] = weight_data[j0][j1][j2];
                     }
                     k1++;
                 }
@@ -116,4 +187,60 @@ embedding2d(
 }
 
 
+torch::Tensor
+embedding2d_backward(
+    const torch::Tensor& grad,
+    const torch::Tensor& input,
+    const torch::Tensor& weight,
+    const c10::IntArrayRef& padding,
+    const c10::ArrayRef<double>& exterior
+)
+{
+    auto options = embedding_options{.padding = padding.vec(), .exterior = exterior.vec()};
+    check_shape_backward("embedding2d_backward", grad, input, weight, options);
+
+    auto width_size = weight.size(0);
+    auto height_size = weight.size(1);
+    auto feature_size = weight.size(2);
+
+    auto quad_exterior = quadrect(exterior.vec());
+    auto quad_width = quad_exterior.width() / width_size;
+    auto quad_height = quad_exterior.height() / height_size;
+
+    const int64_t weight_numel = weight.numel();
+    const int64_t input_size = input.size(0);
+    const int64_t grain_size = at::internal::GRAIN_SIZE;
+
+    auto input_data = input.accessor<double, 2>();
+    auto grad_weight = at::zeros(weight.sizes(), grad.options());
+
+    // Parallelize over weight cells rather than input points: every worker
+    // scans all points but only accumulates into cells whose flat offset
+    // falls inside its [begin, end) partition, so no two workers ever write
+    // the same cell.
+    at::parallel_for(0, weight_numel, grain_size, [&](int64_t begin, int64_t end) {
+        for (const auto i : c10::irange(input_size)) {
+            const auto& point = input_data[i];
+
+            const auto w = floordiv(point[0] - quad_exterior.xmin(), quad_width);
+            const auto h = floordiv(point[1] - quad_exterior.ymin(), quad_height);
+
+            int64_t k0 = 0;
+            for (auto j0 : c10::irange(w - padding[0], w + padding[0] + 1)) {
+                int64_t k1 = 0;
+                for (auto j1 : c10::irange(h - padding[1], h + padding[1] + 1)) {
+                    j0 = modulo(width_size, j0);
+                    j1 = modulo(height_size, j1);
+
+                    // Row-major flat offset of cell (j0, j1), scaled by
+                    // feature_size so it tiles [0, weight_numel) exactly.
+                    // FIX(review): was `j0 * width_size + j1` — the wrong
+                    // stride for non-square weights, where pos could exceed
+                    // weight_numel and those cells' gradients were dropped.
+                    int64_t pos = (j0 * height_size + j1) * feature_size;
+                    if (pos >= begin && pos < end) {
+                        grad_weight[j0][j1] += grad[i][k0][k1];
+                    }
+                    k1++;
+                }
+                k0++;
+            }
+        }
+    });
+
+    return grad_weight;
+}
+
+
 } // namespace torch_geopooling
diff --git a/test/embedding_test.cc b/test/embedding_test.cc
index d918a17..fa9e302 100644
--- a/test/embedding_test.cc
+++ b/test/embedding_test.cc
@@ -60,11 +60,11 @@ BOOST_AUTO_TEST_CASE(embedding2d_eval)
 
     // Weight represents the following field:
     //
-    // [3]|0.0 |2.5 |5.0 |7.5 |10.0
-    // [2]|0.0 |2.5 |5.0 |7.5 |10.0
-    // [1]|0.0 |2.5 |5.0 |7.5 |10.0
-    // [0]|0.0 |2.5 |5.0 |7.5 |10.0
-    //    [0]| [1]| [2]| [3]|
+    //    [0]| [1]| [2]| [3]|
+    // [0]|0.0 |2.5 |5.0 |7.5 |≤10.0
+    // [1]|0.0 |2.5 |5.0 |7.5 |≤10.0
+    // [2]|0.0 |2.5 |5.0 |7.5 |≤10.0
+    // [3]|0.0 |2.5 |5.0 |7.5 |≤10.0
     auto input = torch::tensor(
         {
             {1.0, 1.0}, // center at (0, 0)
@@ -83,4 +83,56 @@ BOOST_AUTO_TEST_CASE(embedding2d_eval)
 }
 
 
+BOOST_AUTO_TEST_CASE(embedding2d_backward_grad)
+{
+    auto options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU);
+
+    auto weight = torch::full({4, 4, 5}, 1.0, options);
+
+    auto input = torch::tensor(
+        {
+            {1.0, 1.0}, // center at (0, 0)
+            {6.0, 6.0}, // center at (2, 2)
+        },
+        options
+    );
+
+    // FIX(review): grad must be Float64 like everything else in this test;
+    // without `options` it defaulted to Float32.
+    auto grad = torch::zeros({2, 3, 3, 5}, options);
+    grad.index_put_({0, Ellipsis, Ellipsis, Ellipsis}, 1.0);
+    grad.index_put_({1, Ellipsis, Ellipsis, Ellipsis}, 10.0);
+
+    auto padding = std::vector<int64_t>({1, 1});
+    auto exterior = std::vector<double>({0.0, 0.0, 10.0, 10.0});
+
+    auto grad_weight = embedding2d_backward(
+        grad, input, weight, /*padding=*/padding, /*exterior=*/exterior
+    );
+
+    BOOST_REQUIRE_EQUAL(grad_weight.sizes(), weight.sizes());
+    // Grad of weight represents the following field; the modulo wrap-around
+    // at the exterior boundary spills the (0, 0) kernel into row/column 3:
+    //
+    // [0]|  1|  1|  0|  1|
+    // [1]|  1| 11| 10| 11|
+    // [2]|  0| 10| 10| 10|
+    // [3]|  1| 11| 10| 11|
+    //    [0]| [1]| [2]| [3]|
+
+    auto grad_expect = torch::tensor({
+        {1.0, 1.0, 0.0, 1.0},
+        {1.0, 11.0, 10.0, 11.0},
+        {0.0, 10.0, 10.0, 10.0},
+        {1.0, 11.0, 10.0, 11.0},
+    }, options);
+
+    for (auto i : c10::irange(weight.size(0))) {
+        for (auto j : c10::irange(weight.size(1))) {
+            for (auto k : c10::irange(weight.size(2))) {
+                BOOST_CHECK_EQUAL(grad_weight[i][j][k].item<double>(), grad_expect[i][j].item<double>());
+            }
+        }
+    }
+}
+
+
 BOOST_AUTO_TEST_SUITE_END()