Skip to content

Commit

Permalink
test[next]: Add unit test for embedded inverse_image and fix bugs (#…
Browse files Browse the repository at this point in the history
…1432)

Add unit tests for `ConnectivityField.inverse_image()`.
  • Loading branch information
egparedes authored Jan 31, 2024
1 parent 6262708 commit e4dc1ee
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,10 +403,11 @@ def inverse_image(
last_data_index = dim_nnz_indices[-1]
assert isinstance(last_data_index, core_defs.INTEGRAL_TYPES)
indices, counts = xp.unique(dim_nnz_indices, return_counts=True)
dim_range = self._domain[i]

if len(xp.unique(counts)) == 1 and (
len(indices) == last_data_index - first_data_index + 1
):
dim_range = self._domain[i]
idx_offset = dim_range[1].start
start = idx_offset + first_data_index
assert common.is_int_index(start)
Expand All @@ -428,6 +429,8 @@ def inverse_image(
f"Restriction generates non-contiguous dimensions '{non_contiguous_dims}'."
)

self._cache[cache_key] = new_dims

return new_dims

def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.IntegralScalar:
Expand Down
110 changes: 110 additions & 0 deletions tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,3 +632,113 @@ def test_setitem_wrong_domain():

with pytest.raises(ValueError, match=r"Incompatible 'Domain'.*"):
field[(1, slice(None))] = value_incompatible


def test_connectivity_field_inverse_image():
V = Dimension("V")
E = Dimension("E")

V_START, V_STOP = 2, 7
E_START, E_STOP = 0, 10

e2v_conn = common._connectivity(
np.roll(np.arange(E_START, E_STOP), 1),
domain=common.domain([common.named_range((E, (E_START, E_STOP)))]),
codomain=V,
)

# Test range
image_range = UnitRange(V_START, V_STOP)
result = e2v_conn.inverse_image(image_range)

assert len(result) == 1
assert result[0] == (E, UnitRange(V_START + 1, V_STOP + 1))

# Test cache
cached_result = e2v_conn.inverse_image(image_range)
assert result is cached_result # If the cache is not used, the result would be a new object

# Test codomain
with pytest.raises(ValueError, match="does not match the codomain dimension"):
e2v_conn.inverse_image((E, UnitRange(1, 2)))


def test_connectivity_field_inverse_image_2d_domain():
V = Dimension("V")
E = Dimension("E")
E2V = Dimension("E2V")

V_START, V_STOP = 0, 3
E_START, E_STOP = 0, 3
E2V_START, E2V_STOP = 0, 3

e2v_conn = common._connectivity(
np.asarray([[0, 0, 2], [1, 1, 2], [2, 2, 2]]),
domain=common.domain(
[
common.named_range((E, (E_START, E_STOP))),
common.named_range((E2V, (E2V_START, E2V_STOP))),
]
),
codomain=V,
)

# e2c_conn:
# ---E2V----
# |[[0 0 2]
# E [1 1 2]
# | [2 2 2]]

# Test contiguous and non-contiguous ranges.
# For the 'e2c_conn' defined above, the only valid range including 2
# is [0, 3). Otherwise, the inverse image would be non-contiguous.
image_range = UnitRange(V_START, V_STOP)
result = e2v_conn.inverse_image(image_range)

assert len(result) == 2
assert result[0] == (E, UnitRange(E_START, E_STOP))
assert result[1] == (E2V, UnitRange(E2V_START, E2V_STOP))

result = e2v_conn.inverse_image(UnitRange(0, 2))
assert len(result) == 2
assert result[0] == (E, UnitRange(0, 2))
assert result[1] == (E2V, UnitRange(0, 2))

result = e2v_conn.inverse_image(UnitRange(0, 1))
assert len(result) == 2
assert result[0] == (E, UnitRange(0, 1))
assert result[1] == (E2V, UnitRange(0, 2))

result = e2v_conn.inverse_image(UnitRange(1, 2))
assert len(result) == 2
assert result[0] == (E, UnitRange(1, 2))
assert result[1] == (E2V, UnitRange(0, 2))

with pytest.raises(ValueError, match="generates non-contiguous dimensions"):
result = e2v_conn.inverse_image(UnitRange(1, 3))

with pytest.raises(ValueError, match="generates non-contiguous dimensions"):
result = e2v_conn.inverse_image(UnitRange(2, 3))


def test_connectivity_field_inverse_image_non_contiguous():
V = Dimension("V")
E = Dimension("E")

V_START, V_STOP = 2, 7
E_START, E_STOP = 0, 10

e2v_conn = common._connectivity(
np.asarray([0, 1, 2, 3, 4, 9, 7, 5, 8, 6]),
domain=common.domain([common.named_range((E, (E_START, E_STOP)))]),
codomain=V,
)

result = e2v_conn.inverse_image(UnitRange(V_START, 5))
assert result[0] == (E, UnitRange(V_START, 5))

with pytest.raises(ValueError, match="generates non-contiguous dimensions"):
e2v_conn.inverse_image(UnitRange(V_START, 6))

with pytest.raises(ValueError, match="generates non-contiguous dimensions"):
e2v_conn.inverse_image(UnitRange(V_START, V_STOP))

0 comments on commit e4dc1ee

Please sign in to comment.