Add backward pass for embedding operation
This patch implements the backward pass for the embedding2d operation.
ybubnov committed Aug 16, 2024
1 parent 6232349 commit adc4cc6
Showing 3 changed files with 216 additions and 27 deletions.
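For context, here is a minimal sketch of how the new backward pass pairs with the existing forward operation. The tensor shapes follow the checks introduced in src/embedding.cc below; the include path matches this repository's include/ directory, while the main() scaffolding and the random values are illustrative assumptions:

    #include <torch/torch.h>

    #include <torch_geopooling/embedding.h>

    int main()
    {
        auto options = torch::TensorOptions().dtype(torch::kFloat64);

        // 8 points in (x, y), all inside the exterior rectangle (0, 0, 10, 10).
        auto input = torch::rand({8, 2}, options) * 10.0;
        // Weight grid of shape (width, height, feature).
        auto weight = torch::rand({4, 4, 5}, options);

        std::vector<int64_t> padding = {1, 1};
        std::vector<double> exterior = {0.0, 0.0, 10.0, 10.0};

        // Forward: one (3, 3, 5) kernel per input point.
        auto output = torch_geopooling::embedding2d(input, weight, padding, exterior);

        // Backward: scatter an upstream gradient of shape (8, 3, 3, 5) back
        // onto the weight grid; the result has the same shape as weight.
        auto grad = torch::ones_like(output);
        auto grad_weight = torch_geopooling::embedding2d_backward(
            grad, input, weight, padding, exterior
        );
        return 0;
    }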
12 changes: 11 additions & 1 deletion include/torch_geopooling/embedding.h
@@ -26,7 +26,17 @@ torch::Tensor
embedding2d(
const torch::Tensor& input,
const torch::Tensor& weight,
    const c10::IntArrayRef& padding,
const c10::ArrayRef<double>& exterior
);


torch::Tensor
embedding2d_backward(
const torch::Tensor& grad,
const torch::Tensor& input,
const torch::Tensor& weight,
const c10::IntArrayRef& padding,
const c10::ArrayRef<double>& exterior
);

171 changes: 149 additions & 22 deletions src/embedding.cc
@@ -41,28 +41,101 @@ modulo(int64_t base, int64_t value)
}


struct embedding_options {
std::vector<int64_t> padding;
std::vector<double> exterior;

std::vector<int64_t>
kernel_size(int64_t feature_size) const
{
return {kernel_width(), kernel_height(), feature_size};
}

int64_t
kernel_width() const
{
return padding[0] * 2 + 1;
}

int64_t
kernel_height() const
{
return padding[1] * 2 + 1;
}

bool
is_padding_neg() const
{
        bool is_neg = false;
        for (const auto& p : padding) {
            is_neg |= (p < 0);
        }
        return is_neg;
}
};
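// A quick illustration of the helpers above (values worked out by hand, not
// part of the original patch): with padding = {1, 2}, kernel_width() is
// 1 * 2 + 1 = 3, kernel_height() is 2 * 2 + 1 = 5, and kernel_size(8) is
// {3, 5, 8}; is_padding_neg() reports whether either padding component is
// negative.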


static void
check_shape_forward(
const std::string& op,
const torch::Tensor& input,
const torch::Tensor& weight,
    const embedding_options& options
)
{
    TORCH_CHECK(
        options.exterior.size() == 4,
        op, ": exterior must be a tuple of four doubles comprising a rectangle (x, y, w, h)"
    );
    TORCH_CHECK(options.padding.size() == 2, op, ": padding must consist of 2 elements");

    TORCH_CHECK(input.dim() == 2, op, ": input must be 2D, got ", input.dim(), "D");
    TORCH_CHECK(
        input.dtype() == torch::kFloat64, op, ": operation only supports Float64 input, got ",
        input.dtype()
    );

    TORCH_CHECK(weight.dim() == 3, op, ": weight must be 3D, got ", weight.dim(), "D");
    TORCH_CHECK(
        weight.dtype() == torch::kFloat64, op, ": operation only supports Float64 weight, got ",
        weight.dtype()
    );
}


static void
check_shape_backward(
const std::string& op,
const torch::Tensor& grad,
const torch::Tensor& input,
const torch::Tensor& weight,
const embedding_options& options
)
{
check_shape_forward(op, input, weight, options);

    const std::vector<int64_t> grad_sizes = {
        input.size(0), options.kernel_width(), options.kernel_height(), weight.size(-1)
    };

    TORCH_CHECK(
        grad.sizes() == grad_sizes,
        op, ": gradient shape (", grad.sizes(), ") must match the per-point kernel shape (",
        c10::IntArrayRef(grad_sizes), ")"
    );
}


torch::Tensor
embedding2d(
const torch::Tensor& input,
const torch::Tensor& weight,
const c10::ArrayRef<int64_t>& padding,
const c10::ArrayRef<double>& exterior
)
{
auto options = embedding_options{.padding = padding.vec(), .exterior = exterior.vec()};
check_shape_forward("embedding2d", input, weight, options);

auto width_size = weight.size(0);
auto height_size = weight.size(1);
@@ -74,34 +74,147 @@ embedding2d(
auto quad_width = quad_exterior.width() / width_size;
auto quad_height = quad_exterior.height() / height_size;
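    // Each grid cell spans quad_width x quad_height within the exterior
    // rectangle: a point (x, y) falls into the cell
    // (floordiv(x - xmin, quad_width), floordiv(y - ymin, quad_height)).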

    auto weight_data = weight.accessor<double, 3>();
    auto input_data = input.accessor<double, 2>();

auto input_size = input.size(0);
    const auto kernel_size = options.kernel_size(feature_size);
std::vector<torch::Tensor> output(input_size);

at::parallel_for(0, input_size, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
for (const auto i : c10::irange(begin, end)) {
            const auto& point = input_data[i];

auto kernel = torch::empty(kernel_size, weight.options());
            auto kernel_data = kernel.accessor<double, 3>();

            const auto w = floordiv(point[0] - quad_exterior.xmin(), quad_width);
            const auto h = floordiv(point[1] - quad_exterior.ymin(), quad_height);

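            // Copy the (2 * padding[0] + 1) x (2 * padding[1] + 1)
            // neighbourhood around cell (w, h) into the kernel, wrapping
            // out-of-range indices around the grid.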
int64_t k0 = 0;
            for (auto j0 : c10::irange(w - padding[0], w + padding[0] + 1)) {
int64_t k1 = 0;
                for (auto j1 : c10::irange(h - padding[1], h + padding[1] + 1)) {
j0 = modulo(width_size, j0);
j1 = modulo(height_size, j1);

                    for (auto j2 : c10::irange(feature_size)) {
                        kernel_data[k0][k1][j2] = weight_data[j0][j1][j2];
}
k1++;
}
@@ -116,4 +187,60 @@ embedding2d(
}


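// Backward pass of embedding2d with respect to the weight tensor.
//
// The upstream gradient is expected to carry one kernel-sized slice per input
// point, i.e. to have shape (input.size(0), 2 * padding[0] + 1,
// 2 * padding[1] + 1, weight.size(-1)); the result has the same shape as
// weight.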
torch::Tensor
embedding2d_backward(
const torch::Tensor& grad,
const torch::Tensor& input,
const torch::Tensor& weight,
const c10::IntArrayRef& padding,
const c10::ArrayRef<double>& exterior
)
{
auto options = embedding_options{.padding = padding.vec(), .exterior = exterior.vec()};
check_shape_backward("embedding2d_backward", grad, input, weight, options);

auto width_size = weight.size(0);
auto height_size = weight.size(1);
auto feature_size = weight.size(2);

auto quad_exterior = quadrect(exterior.vec());
auto quad_width = quad_exterior.width() / width_size;
auto quad_height = quad_exterior.height() / height_size;

    const int64_t grid_size = weight.size(0) * weight.size(1);
const int64_t input_size = input.size(0);
const int64_t grain_size = at::internal::GRAIN_SIZE;

auto input_data = input.accessor<double, 2>();
auto grad_weight = at::zeros(weight.sizes(), grad.options());

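    // Race-avoidance scheme: every worker scans all input points, but commits
    // only the updates whose flattened (j0, j1) cell index falls inside its
    // [begin, end) slice of the weight grid, so no two workers ever write to
    // the same weight cell.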
    at::parallel_for(0, grid_size, grain_size, [&](int64_t begin, int64_t end) {
for (const auto i : c10::irange(input_size)) {
const auto& point = input_data[i];

const auto w = floordiv(point[0] - quad_exterior.xmin(), quad_width);
const auto h = floordiv(point[1] - quad_exterior.ymin(), quad_height);

int64_t k0 = 0;
for (auto j0 : c10::irange(w - padding[0], w + padding[0] + 1)) {
int64_t k1 = 0;
for (auto j1 : c10::irange(h - padding[1], h + padding[1] + 1)) {
j0 = modulo(width_size, j0);
j1 = modulo(height_size, j1);

                    const int64_t pos = j0 * height_size + j1;
if (pos >= begin && pos < end) {
grad_weight[j0][j1] += grad[i][k0][k1];
}
k1++;
}
k0++;
}
}
});

return grad_weight;
}


} // namespace torch_geopooling
60 changes: 56 additions & 4 deletions test/embedding_test.cc
@@ -60,11 +60,11 @@ BOOST_AUTO_TEST_CASE(embedding2d_eval)

// Weight represents the following field:
//
    //    [0]| [1]| [2]| [3]|
    // [0]|0.0 |2.5 |5.0 |7.5 |≤10.0
    // [1]|0.0 |2.5 |5.0 |7.5 |≤10.0
    // [2]|0.0 |2.5 |5.0 |7.5 |≤10.0
    // [3]|0.0 |2.5 |5.0 |7.5 |≤10.0
auto input = torch::tensor(
{
{1.0, 1.0}, // center at (0, 0)
@@ -83,4 +83,56 @@
}


BOOST_AUTO_TEST_CASE(embedding2d_backward_grad)
{
auto options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU);

    auto weight = torch::full({4, 4, 5}, 1.0, options);

auto input = torch::tensor(
{
{1.0, 1.0}, // center at (0, 0)
{6.0, 6.0}, // center at (2, 2)
},
options
);

    auto grad = torch::zeros({2, 3, 3, 5}, options);
    grad.index_put_({0, Ellipsis}, 1.0);
    grad.index_put_({1, Ellipsis}, 10.0);

auto padding = std::vector<int64_t>({1, 1});
auto exterior = std::vector<double>({0.0, 0.0, 10.0, 10.0});

auto grad_weight = embedding2d_backward(
grad, input, weight, /*padding=*/padding, /*exterior=*/exterior
);

BOOST_REQUIRE_EQUAL(grad_weight.sizes(), weight.sizes());
// Grad of weight represents the following field:
//
    // [0]| 1| 1| 0| 1|
    // [1]| 1| 11| 10| 11|
    // [2]| 0| 10| 10| 10|
    // [3]| 1| 11| 10| 11|
    //    [0]| [1]| [2]| [3]|
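    // For example, cell (0, 0) collects +1 from the 3x3 neighbourhood of the
    // point at (1.0, 1.0), whose row and column offsets {-1, 0, 1} wrap to
    // {3, 0, 1}; cell (1, 1) additionally collects +10 from the point at
    // (6.0, 6.0), whose neighbourhood covers rows and columns {1, 2, 3}.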

auto grad_expect = torch::tensor({
{1.0, 1.0, 0.0, 1.0},
{1.0, 11.0, 10.0, 11.0},
{0.0, 10.0, 10.0, 10.0},
{1.0, 11.0, 10.0, 11.0},
}, options);

for (auto i : c10::irange(weight.size(0))) {
for (auto j : c10::irange(weight.size(1))) {
for (auto k : c10::irange(weight.size(2))) {
BOOST_CHECK_EQUAL(grad_weight[i][j][k].item<double>(), grad_expect[i][j].item<double>());
}
}
}
}


BOOST_AUTO_TEST_SUITE_END()
