Skip to content

Commit

Permalink
reintroduced tests for test_transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
biphasic committed Feb 7, 2025
1 parent 8558111 commit d521774
Showing 1 changed file with 164 additions and 0 deletions.
164 changes: 164 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,167 @@ def test_transform_spatial_jitter(variance, clip_outliers):

else:
assert len(events)


@pytest.mark.parametrize(
"std, clip_negative, sort_timestamps",
[(10, True, True), (50, False, False), (0, True, False)],
)
def test_transform_time_jitter(std, clip_negative, sort_timestamps):
orig_events, sensor_size = create_random_input()

transform = transforms.TimeJitter(
std=std, clip_negative=clip_negative, sort_timestamps=sort_timestamps
)

events = transform(orig_events)

if clip_negative:
assert (events["t"] >= 0).all()
else:
assert len(events) == len(orig_events)
if sort_timestamps:
np.testing.assert_array_equal(events["t"], np.sort(events["t"]))
if not sort_timestamps and not clip_negative:
np.testing.assert_array_equal(events["x"], orig_events["x"])
np.testing.assert_array_equal(events["y"], orig_events["y"])
np.testing.assert_array_equal(events["p"], orig_events["p"])
assert (
events["t"] - orig_events["t"]
== (events["t"] - orig_events["t"]).astype(int)
).all()
assert events is not orig_events


@pytest.mark.parametrize("p", [1, 0])
def test_transform_time_reversal(p):
orig_events, sensor_size = create_random_input()

original_t = orig_events["t"][0]
max_t = np.max(orig_events["t"])

transform = transforms.RandomTimeReversal(p=p)
events = transform(orig_events)

if p == 1:
assert np.array_equal(orig_events["t"], max_t - events["t"][::-1])
assert np.array_equal(
orig_events["p"],
np.invert(events["p"][::-1].astype(bool)),
)
elif p == 0:
assert np.array_equal(orig_events, events)
assert events is not orig_events


@pytest.mark.parametrize("p", [1, 0])
def test_random_reversal_raster(p):
orig_events, sensor_size = create_random_input()
to_raster = transforms.ToFrame(sensor_size=sensor_size, n_event_bins=100)
# raster in shape [t, p, h, w]
orig_raster = to_raster(orig_events)

transform = transforms.RandomTimeReversal(p=p)
raster = transform(orig_raster)

if p == 1:
assert np.array_equal(raster, orig_raster[::-1, ::-1, ...])
elif p == 0:
assert np.array_equal(raster, orig_raster)
assert raster is not orig_raster


@pytest.mark.parametrize("coefficient, offset", [(3.1, 100), (0.3, 0), (2.7, 10)])
def test_transform_time_skew(coefficient, offset):
orig_events, sensor_size = create_random_input()

transform = transforms.TimeSkew(coefficient=coefficient, offset=offset)

events = transform(orig_events)

assert len(events) == len(orig_events)
assert np.min(events["t"]) >= offset
assert (events["t"] == (events["t"]).astype(int)).all()
assert all((orig_events["t"] * coefficient + offset).astype(int) == events["t"])
assert events is not orig_events


@pytest.mark.parametrize("n", [100, 0, (10, 100)])
def test_transform_uniform_noise(n):
orig_events, sensor_size = create_random_input()

if type(n) == tuple:
sampled_n = transforms.UniformNoise.get_params(n=n)
assert n[0] <= sampled_n <= n[1]
assert float(sampled_n).is_integer()
events = transforms.functional.uniform_noise_numpy(
events=orig_events, sensor_size=sensor_size, n=sampled_n
)
assert len(events) == len(orig_events) + sampled_n

else:
transform = transforms.UniformNoise(sensor_size=sensor_size, n=n)
events = transform(orig_events)
assert len(events) == len(orig_events) + n

assert np.isin(orig_events, events).all()
assert np.isclose(
np.sum((events["t"] - np.sort(events["t"])) ** 2), 0
), "Event noise should maintain temporal order."
assert events is not orig_events


@pytest.mark.parametrize("n", [100, 0, (10, 100)])
def test_transform_uniform_noise_empty(n):
orig_events, sensor_size = create_random_input(n_events=0)
assert len(orig_events) == 0

transform = transforms.UniformNoise(sensor_size=sensor_size, n=n)
events = transform(orig_events)
assert len(events) == 0 # check returns an empty array, independent of n.


def test_transform_time_alignment():
orig_events, sensor_size = create_random_input()

transform = transforms.TimeAlignment()

events = transform(orig_events)

assert np.min(events["t"]) == 0
assert events is not orig_events


def test_toframe_empty():
orig_events, sensor_size = create_random_input(n_events=0)
assert len(orig_events) == 0

with pytest.raises(
ValueError
): # check that empty array raises error if no slicing method is specified
transform = transforms.ToFrame(sensor_size=sensor_size)
frame = transform(orig_events)

n_event_bins = 100
transform = transforms.ToFrame(sensor_size=sensor_size, n_event_bins=n_event_bins)
frame = transform(orig_events)
assert frame.shape == (n_event_bins, sensor_size[2], sensor_size[0], sensor_size[1])
assert frame.sum() == 0

n_time_bins = 100
transform = transforms.ToFrame(sensor_size=sensor_size, n_time_bins=n_time_bins)
frame = transform(orig_events)
assert frame.shape == (n_time_bins, sensor_size[2], sensor_size[0], sensor_size[1])
assert frame.sum() == 0

event_count = 1e3
transform = transforms.ToFrame(sensor_size=sensor_size, event_count=event_count)
frame = transform(orig_events)
assert frame.shape == (1, sensor_size[2], sensor_size[0], sensor_size[1])
assert frame.sum() == 0

time_window = 1e3
transform = transforms.ToFrame(sensor_size=sensor_size, time_window=time_window)
frame = transform(orig_events)
assert frame.shape == (1, sensor_size[2], sensor_size[0], sensor_size[1])
assert frame.sum() == 0

0 comments on commit d521774

Please sign in to comment.