Skip to content

Commit

Permalink
fixed optional arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
rolfhm committed Jan 17, 2024
1 parent e3234d8 commit 8b785b9
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 6 deletions.
54 changes: 49 additions & 5 deletions transformations/tests/test_raw_stack_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ def test_raw_stack_allocator_temporaries(frontend, block_dim, horizontal, direct
real(kind=selected_real_kind(13,300)), dimension(nlon, klev) :: zzy
logical, dimension(nlon, klev) :: zzl
integer :: jl, jlev
integer(kind=jpim) :: testint
integer(kind=jpim) :: jl, jlev
zzl = .false.
do jl =1, nlon
Expand All @@ -142,7 +143,7 @@ def test_raw_stack_allocator_temporaries(frontend, block_dim, horizontal, direct
enddo
enddo
call kernel2(ydml_phy_mf%yrphy, nlon, klev, jstart, jend)
call kernel2(ydml_phy_mf%yrphy, nlon, klev, jstart, jend, testint)
call kernel3(ydml_phy_mf%yrphy, nlon, klev, jstart, jend, pzz)
end subroutine kernel1
Expand All @@ -152,7 +153,7 @@ def test_raw_stack_allocator_temporaries(frontend, block_dim, horizontal, direct
fcode_kernel2 = """
module kernel2_mod
contains
subroutine kernel2(ydphy, nlon, klev, jstart, jend)
subroutine kernel2(ydphy, nlon, klev, jstart, jend, testint)
use parkind1, only: jpim, jprb
Expand All @@ -166,6 +167,7 @@ def test_raw_stack_allocator_temporaries(frontend, block_dim, horizontal, direct
integer(kind=jpim), intent(in) :: klev
integer(kind=jpim), intent(in) :: jstart
integer(kind=jpim), intent(in) :: jend
integer(kind=jpim), optional, intent(in) :: testint
integer(kind=jpim) :: jb, jlev, jl
Expand Down Expand Up @@ -213,6 +215,7 @@ def test_raw_stack_allocator_temporaries(frontend, block_dim, horizontal, direct
real(kind=jprb) :: zde2(nlon, klev, ydphy%n_spband)
real(kind=jprb) :: zde3(nlon, 1:klev)
!$acc data present(pzz)
do jb = 1, ydphy%n_spband
zde1(:, 0, jb) = 0._jprb
Expand All @@ -229,6 +232,8 @@ def test_raw_stack_allocator_temporaries(frontend, block_dim, horizontal, direct
zde3(1:nlon,1:klev) = pzz
!$acc iend data
end subroutine kernel3
end module kernel3_mod
""".strip()
Expand Down Expand Up @@ -275,6 +280,8 @@ def test_raw_stack_allocator_temporaries(frontend, block_dim, horizontal, direct

driver_item = scheduler['driver_mod#driver']
kernel1_item = scheduler['kernel1_mod#kernel1']
kernel2_item = scheduler['kernel2_mod#kernel2']
kernel3_item = scheduler['kernel3_mod#kernel3']

assert transformation._key in kernel1_item.trafo_data

Expand Down Expand Up @@ -309,6 +316,8 @@ def test_raw_stack_allocator_temporaries(frontend, block_dim, horizontal, direct

driver = driver_item.routine
kernel1 = kernel1_item.routine
kernel2 = kernel2_item.routine
kernel3 = kernel3_item.routine

assert 'j_ll_stack_size' in driver.variable_map
assert 'll_stack' in driver.variable_map
Expand Down Expand Up @@ -399,15 +408,36 @@ def test_raw_stack_allocator_temporaries(frontend, block_dim, horizontal, direct
'jend\n'\
'k_p_selected_real_kind_13_300_stack_size - j_p_selected_real_kind_13_300_stack_used\n'\
'p_selected_real_kind_13_300_stack'\
'(1:nlon, j_p_selected_real_kind_13_300_stack_used + 1:k_p_selected_real_kind_13_300_stack_size)'
'(1:nlon, j_p_selected_real_kind_13_300_stack_used + 1:k_p_selected_real_kind_13_300_stack_size)\n'\
'testint'
else:
assert fgen(calls[0].arguments).lower() == 'ydml_phy_mf%yrphy\n'\
'nlon\n'\
'klev\n'\
'jstart\n'\
'jend\n'\
'k_p_jprb_stack_size - j_p_jprb_stack_used\n'\
'p_jprb_stack(1:nlon, j_p_jprb_stack_used + 1:k_p_jprb_stack_size)'
'p_jprb_stack(1:nlon, j_p_jprb_stack_used + 1:k_p_jprb_stack_size)\n'\
'testint'

if frontend == OMNI:
assert fgen(kernel2.arguments).lower() == 'ydphy\n'\
'nlon\n'\
'klev\n'\
'jstart\n'\
'jend\n'\
'k_p_selected_real_kind_13_300_stack_size\n'\
'p_selected_real_kind_13_300_stack(nlon, k_p_selected_real_kind_13_300_stack_size)\n'\
'testint'
else:
assert fgen(kernel2.arguments).lower() == 'ydphy\n'\
'nlon\n'\
'klev\n'\
'jstart\n'\
'jend\n'\
'k_p_jprb_stack_size\n'\
'p_jprb_stack(nlon, k_p_jprb_stack_size)\n'\
'testint'

assignments = FindNodes(Assignment).visit(driver.body)

Expand Down Expand Up @@ -449,4 +479,18 @@ def test_raw_stack_allocator_temporaries(frontend, block_dim, horizontal, direct
assert pragmas[0].content.lower() == 'target allocate(z_jprb_stack, '\
'z_selected_real_kind_13_300_stack, ll_stack)'

if directive == 'openacc':
pragmas = FindNodes(Pragma).visit(kernel1.body)
if frontend == OMNI:
assert pragmas[0].content.lower() == 'data present(p_selected_real_kind_13_300_stack, ld_stack)'
else:
assert pragmas[0].content.lower() == 'data present(p_jprb_stack, '\
'p_selected_real_kind_13_300_stack, ld_stack)'

pragmas = FindNodes(Pragma).visit(kernel3.body)
if frontend == OMNI:
assert pragmas[0].content.lower() == 'data present(p_selected_real_kind_13_300_stack, pzz)'
else:
assert pragmas[0].content.lower() == 'data present(p_jprb_stack, pzz)'

rmtree(basedir)
3 changes: 2 additions & 1 deletion transformations/transformations/raw_stack_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ def insert_stack_in_calls(self, routine, stack_arg_dict, successors):

arguments = call.arguments
if arg_pos:
arguments = arguments[:arg_pos[0]] + as_tuple(call_stack_args) + arguments[arg_pos[0]:]
arg_pos = min(arg_pos) - len(call_stack_args)
arguments = arguments[:arg_pos] + as_tuple(call_stack_args) + arguments[arg_pos:]
else:
arguments += as_tuple(call_stack_args)

Expand Down

0 comments on commit 8b785b9

Please sign in to comment.