diff --git a/src/embedding.cc b/src/embedding.cc index 742ac3a..e28e395 100644 --- a/src/embedding.cc +++ b/src/embedding.cc @@ -84,8 +84,8 @@ check_shape_forward( ) { TORCH_CHECK( - options.exterior.size() == 4, - op, ": exterior must be a tuple of four doubles comprising a rectangle (x, y, w, h)" + options.exterior.size() == 4, op, + ": exterior must be a tuple of four doubles comprising a rectangle (x, y, w, h)" ); TORCH_CHECK(options.padding.size() == 2, op, ": padding should be comprised of 2 elements"); @@ -119,9 +119,8 @@ check_shape_backward( ); TORCH_CHECK( - grad.sizes() == grad_sizes, - op, - ": gradient shape (", grad.sizes(), ") should be the same as input (", input.sizes(), ")" + grad.sizes() == grad_sizes, op, ": gradient shape (", grad.sizes(), + ") should be the same as input (", input.sizes(), ")" ); } diff --git a/test/embedding_test.cc b/test/embedding_test.cc index fa9e302..33d7430 100644 --- a/test/embedding_test.cc +++ b/test/embedding_test.cc @@ -105,9 +105,8 @@ BOOST_AUTO_TEST_CASE(embedding2d_backward_grad) auto padding = std::vector({1, 1}); auto exterior = std::vector({0.0, 0.0, 10.0, 10.0}); - auto grad_weight = embedding2d_backward( - grad, input, weight, /*padding=*/padding, /*exterior=*/exterior - ); + auto grad_weight + = embedding2d_backward(grad, input, weight, /*padding=*/padding, /*exterior=*/exterior); BOOST_REQUIRE_EQUAL(grad_weight.sizes(), weight.sizes()); // Grad of weight represents the following field: @@ -118,17 +117,22 @@ BOOST_AUTO_TEST_CASE(embedding2d_backward_grad) // [3]| 1| 11| 10| 11| // [0]| [1]| [2]| [3]| - auto grad_expect = torch::tensor({ - {1.0, 1.0, 0.0, 1.0}, - {1.0, 11.0, 10.0, 11.0}, - {0.0, 10.0, 10.0, 10.0}, - {1.0, 11.0, 10.0, 11.0}, - }, options); + auto grad_expect = torch::tensor( + { + {1.0, 1.0, 0.0, 1.0}, + {1.0, 11.0, 10.0, 11.0}, + {0.0, 10.0, 10.0, 10.0}, + {1.0, 11.0, 10.0, 11.0}, + }, + options + ); for (auto i : c10::irange(weight.size(0))) { for (auto j : c10::irange(weight.size(1))) { for (auto k : c10::irange(weight.size(2))) { - BOOST_CHECK_EQUAL(grad_weight[i][j][k].item(), grad_expect[i][j].item()); + BOOST_CHECK_EQUAL( + grad_weight[i][j][k].item(), grad_expect[i][j].item() + ); } } }