From 5b46b40f58ba2d0146f10cab007d50fc008e7ce0 Mon Sep 17 00:00:00 2001 From: Yasha Bubnov Date: Fri, 19 Apr 2024 20:11:11 +0200 Subject: [PATCH] Add a simple functional unit test of quad_pool2d. --- torch_geopooling/functional_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 torch_geopooling/functional_test.py diff --git a/torch_geopooling/functional_test.py b/torch_geopooling/functional_test.py new file mode 100644 index 0000000..23ce122 --- /dev/null +++ b/torch_geopooling/functional_test.py @@ -0,0 +1,16 @@ +import torch +from torch_geopooling.functional import quad_pool2d + + +def test_quad_pool2d() -> None: + tiles = torch.tensor([[0, 0, 0]], dtype=torch.int32) + input = torch.rand((100, 2), dtype=torch.float64) * 10.0 + weight = torch.randn([64], dtype=torch.float64) + bias = torch.randn([64], dtype=torch.float64) + + result = quad_pool2d(tiles, input, weight, bias, (0.0, 0.0, 10.0, 10.0), True) + assert result.tiles.size(0) > 0 + assert result.tiles.size(1) == 3 + + assert result.weight.size(0) > 0 + assert result.bias.size(0) > 0