Skip to content

Commit

Permalink
Resize fix for new torch. (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam authored Nov 20, 2023
1 parent 4d7ab74 commit 694d237
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from waveprop.simulation import FarFieldSimulator
import torch
from waveprop.devices import sensor_dict, SensorParam


def test_far_field_simulator():

sensor = "rpi_hq"
sensor_shape = sensor_dict[sensor][SensorParam.SHAPE]


sim = FarFieldSimulator(
object_height=30e-2,
scene2mask=30e-2,
mask2sensor=4e-3,
sensor=sensor,
output_dim=sensor_shape,
)

obj = torch.rand(1, 1, 256, 256)
image = sim.propagate(obj)
assert image.shape == (1, 1, *sensor_shape)


if __name__ == "__main__":
test_far_field_simulator()
1 change: 1 addition & 0 deletions waveprop/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ def prepare_object_plane(
object_height_pix = int(np.round(object_height / scene_dim[1] * sensor_dim[1]))
scaling = object_height_pix / input_dim[1]
object_dim = tuple((np.round(input_dim * scaling)).astype(int))
object_dim = (int(object_dim[0]), int(object_dim[1]))

if torch.is_tensor(obj):
object_plane = resize_torch(obj, size=object_dim, antialias=True)
Expand Down

0 comments on commit 694d237

Please sign in to comment.