Skip to content

Commit

Permalink
Merge pull request #454 from ecmwf-ifs/nams-resolve-vector-notation-d…
Browse files Browse the repository at this point in the history
…ims-specified

resolve_vector_notation: fix/improvement
  • Loading branch information
reuterbal authored Dec 2, 2024
2 parents ba7e230 + 36b1775 commit e9760e8
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 17 deletions.
8 changes: 5 additions & 3 deletions loki/transformations/array_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,11 @@ def resolve_vector_notation(routine):
ivar_basename = f'i_{stmt.lhs.basename}'
for i, dim, s in zip(count(), v.dimensions, as_tuple(v.shape)):
if isinstance(dim, sym.RangeIndex):
# use the shape for e.g., `ARR(:)`, but use the dimension for e.g., `ARR(2:5)`
_s = dim if dim.lower is not None else s
# create tuple to test whether an appropriate loop is already available
test_range = (sym.IntLiteral(1), s, 1) if not isinstance(s, sym.RangeIndex)\
else (s.lower, s.upper, 1)
test_range = (sym.IntLiteral(1), _s, 1) if not isinstance(_s, sym.RangeIndex)\
else (_s.lower, _s.upper, 1)
# actually test for it
if test_range in loop_map:
# Use index variable of available matching loop
Expand All @@ -208,7 +210,7 @@ def resolve_vector_notation(routine):
vtype = SymbolAttributes(BasicType.INTEGER)
ivar = sym.Variable(name=f'{ivar_basename}_{i}', type=vtype, scope=routine)
shape_index_map[(i, s)] = ivar
index_range_map[ivar] = s
index_range_map[ivar] = _s

if ivar not in vdims:
vdims.append(ivar)
Expand Down
52 changes: 38 additions & 14 deletions loki/transformations/tests/test_array_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,16 +1064,17 @@ def test_transform_promote_resolve_vector_notation(tmp_path, frontend):


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_resolve_vector_notation_common_loops(tmp_path, frontend):
@pytest.mark.parametrize('kidia_loop', (True, False))
def test_transform_resolve_vector_notation_common_loops(tmp_path, frontend, kidia_loop):
"""
Apply and test resolve vector notation utility with already
available/appropriate loops.
"""
fcode = """
subroutine transform_resolve_vector_notation_common_loops(scalar, vector, matrix, n, m, l)
fcode = f"""
subroutine transform_resolve_vector_notation_common_loops(scalar, vector, vector_2, matrix, n, m, l, kidia, kfdia)
implicit none
integer, intent(in) :: n, m, l
integer, intent(inout) :: scalar, vector(n), matrix(l, n)
integer, intent(in) :: n, m, l, kidia, kfdia
integer, intent(inout) :: scalar, vector(n), vector_2(n), matrix(l, n)
integer :: tmp_scalar, tmp_vector(n, m), tmp_matrix(l, m, n), tmp_dummy(n, 0:4)
integer :: jl, jk, jm
Expand All @@ -1083,7 +1084,7 @@ def test_transform_resolve_vector_notation_common_loops(tmp_path, frontend):
tmp_matrix(:, :, :) = 0
matrix(:, :) = 0
do jl=1,n
do jl={'kidia,kfdia' if kidia_loop else '1,n'}
do jm=1,m
tmp_vector(jl, jm) = scalar + jl
end do
Expand All @@ -1110,6 +1111,9 @@ def test_transform_resolve_vector_notation_common_loops(tmp_path, frontend):
end do
end do
vector_2(:) = 1
vector_2(kidia:kfdia) = 2
end subroutine transform_resolve_vector_notation_common_loops
""".strip()
routine = Subroutine.from_source(fcode, frontend=frontend)
Expand All @@ -1120,21 +1124,22 @@ def test_transform_resolve_vector_notation_common_loops(tmp_path, frontend):
n = 3
m = 2
l = 3
kidia = 1
kfdia = n
scalar = np.zeros(shape=(1,), order='F', dtype=np.int32)
vector = np.zeros(shape=(n,), order='F', dtype=np.int32)
vector_2 = np.zeros(shape=(n,), order='F', dtype=np.int32)
matrix = np.zeros(shape=(n, n), order='F', dtype=np.int32)
function(scalar, vector, matrix, n, m, l)
function(scalar, vector, vector_2, matrix, n, m, l, kidia, kfdia)

assert all(scalar == 3)
assert np.all(vector == np.arange(1, n + 1)*2)
assert np.all(matrix == np.sum(np.mgrid[1:4,2:8:2], axis=0))

resolve_vector_notation(routine)

loops = FindNodes(Loop).visit(routine.body)
arrays = [var for var in FindVariables(unique=False).visit(routine.body) if isinstance(var, sym.Array)]

assert len(loops) == 19
assert len(loops) == 21
assert loops[0].variable == 'i_tmp_dummy_1' and loops[0].bounds.children == (0, 4, None)
assert loops[1].variable == 'jl' and loops[1].bounds.children == (1, 'n', 1)
assert loops[2].variable == 'jl' and loops[2].bounds.children == (1, 'n', 1)
Expand All @@ -1145,7 +1150,11 @@ def test_transform_resolve_vector_notation_common_loops(tmp_path, frontend):
assert loops[7].variable == 'jk' and loops[7].bounds.children == (1, 'l', 1)
assert loops[8].variable == 'jl' and loops[8].bounds.children == (1, 'n', 1)
assert loops[9].variable == 'jk' and loops[9].bounds.children == (1, 'l', 1)
assert loops[10].variable == 'jl' and loops[10].bounds.children == (1, 'n', None)
assert loops[10].variable == 'jl'
if kidia_loop:
assert loops[10].bounds.children == ('kidia', 'kfdia', None)
else:
assert loops[10].bounds.children == (1, 'n', None)
assert loops[11].variable == 'jm' and loops[11].bounds.children == (1, 'm', None)
assert loops[12].variable == 'jm' and loops[12].bounds.children == (1, 'm', None)
assert loops[13].variable == 'jl' and loops[13].bounds.children == (1, 'n', None)
Expand All @@ -1154,13 +1163,26 @@ def test_transform_resolve_vector_notation_common_loops(tmp_path, frontend):
assert loops[16].variable == 'jl' and loops[16].bounds.children == (1, 'n', 1)
assert loops[17].variable == 'jm' and loops[17].bounds.children == (1, 'm', None)
assert loops[18].variable == 'jl' and loops[18].bounds.children == (1, 'n', None)
assert loops[19].variable == 'jl' and loops[19].bounds.children == (1, 'n', 1)
if kidia_loop:
assert loops[20].variable == 'jl'
assert loops[20].bounds.children == ('kidia', 'kfdia', None)
else:
assert loops[20].variable == 'i_vector_2_0'
assert loops[20].bounds.children == ('kidia', 'kfdia', None)

assert len(arrays) == 15
assert len(arrays) == 17
assert arrays[0].name.lower() == 'tmp_dummy' and arrays[0].dimensions == ('jl', 'i_tmp_dummy_1')
assert arrays[1].name.lower() == 'tmp_vector' and arrays[1].dimensions == ('jl', 1)
assert arrays[2].name.lower() == 'tmp_dummy' and arrays[2].dimensions == ('jl', 1)
assert arrays[3].name.lower() == 'tmp_vector' and arrays[3].dimensions == ('jl', 'jm')
assert arrays[4].name.lower() == 'tmp_matrix' and arrays[4].dimensions == ('jk', 'jm', 'jl')
assert arrays[15].name.lower() == 'vector_2' and arrays[15].dimensions == ('jl',)
assert arrays[16].name.lower() == 'vector_2'
if kidia_loop:
assert arrays[16].dimensions == ('jl',)
else:
assert arrays[16].dimensions == ('i_vector_2_0',)

# Test promoted routine
resolved_filepath = tmp_path/(f'{routine.name}_resolved_{frontend}.f90')
Expand All @@ -1169,16 +1191,18 @@ def test_transform_resolve_vector_notation_common_loops(tmp_path, frontend):
n = 3
m = 2
l = 3
kidia = 1
kfdia = n
scalar = np.zeros(shape=(1,), order='F', dtype=np.int32)
vector = np.zeros(shape=(n,), order='F', dtype=np.int32)
vector_2 = np.zeros(shape=(n,), order='F', dtype=np.int32)
matrix = np.zeros(shape=(n, n), order='F', dtype=np.int32)
resolved_function(scalar, vector, matrix, n, m, l)
resolved_function(scalar, vector, vector_2, matrix, n, m, l, kidia, kfdia)

assert all(scalar == 3)
assert np.all(vector == np.arange(1, n + 1)*2)
assert np.all(matrix == np.sum(np.mgrid[1:4,2:8:2], axis=0))


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('calls_only', (False, True))
def test_transform_explicit_dimensions(tmp_path, frontend, builder, calls_only):
Expand Down

0 comments on commit e9760e8

Please sign in to comment.