From 694d2374864c41e965622cf8235c4fe17e6d09e8 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Mon, 20 Nov 2023 14:11:09 +0100 Subject: [PATCH] Resize fix for new torch. (#13) --- tests/test_simulate.py | 26 ++++++++++++++++++++++++++ waveprop/util.py | 1 + 2 files changed, 27 insertions(+) create mode 100644 tests/test_simulate.py diff --git a/tests/test_simulate.py b/tests/test_simulate.py new file mode 100644 index 0000000..ccc6bba --- /dev/null +++ b/tests/test_simulate.py @@ -0,0 +1,26 @@ +from waveprop.simulation import FarFieldSimulator +import torch +from waveprop.devices import sensor_dict, SensorParam + + +def test_far_field_simulator(): + + sensor = "rpi_hq" + sensor_shape = sensor_dict[sensor][SensorParam.SHAPE] + + + sim = FarFieldSimulator( + object_height=30e-2, + scene2mask=30e-2, + mask2sensor=4e-3, + sensor=sensor, + output_dim=sensor_shape, + ) + + obj = torch.rand(1, 1, 256, 256) + image = sim.propagate(obj) + assert image.shape == (1, 1, *sensor_shape) + + +if __name__ == "__main__": + test_far_field_simulator() \ No newline at end of file diff --git a/waveprop/util.py b/waveprop/util.py index a22939f..c9c6e10 100644 --- a/waveprop/util.py +++ b/waveprop/util.py @@ -560,6 +560,7 @@ def prepare_object_plane( object_height_pix = int(np.round(object_height / scene_dim[1] * sensor_dim[1])) scaling = object_height_pix / input_dim[1] object_dim = tuple((np.round(input_dim * scaling)).astype(int)) + object_dim = (int(object_dim[0]), int(object_dim[1])) if torch.is_tensor(obj): object_plane = resize_torch(obj, size=object_dim, antialias=True)