Skip to content

Commit

Permalink
Prevent inlining of unrolled_reduce in nonorographic gravity wave code
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Oct 16, 2024
1 parent 3caa650 commit 0753c18
Showing 1 changed file with 41 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -566,53 +566,57 @@ function waveforcing_column_accumulate!(
if level >= source_level - 1
# check the breaking condition for each gravity wave and calculate the momentum flux of breaking gravity waves at each level
# We use the unrolled_reduce function here because it performs better for parallel execution on the GPU, avoiding type instabilities.
(mask, fm) =
unrolled_reduce(Val(nc), (mask, FT1(0.0))) do (mask, fm), (n)
if (mask[n]) == true
c_hat = c[n] - u_kp1 # c0mu
# if phase speed matches the wind speed, remove c[n] from the set of propagating waves.
if c_hat == 0.0
# However, we need to prevent it from being inlined on the CPU to avoid large compilation times for several test cases in CI.
# Note that @noinline has no effect on the GPU, which requires all kernel code to be inlined.
(mask, fm) = @noinline unrolled_reduce(
StaticOneTo(nc),
(mask, FT1(0.0)),
) do (mask, fm), (n)
if (mask[n]) == true
c_hat = c[n] - u_kp1 # c0mu
# if phase speed matches the wind speed, remove c[n] from the set of propagating waves.
if c_hat == 0.0
mask = Base.setindex(mask, false, n)
else
c_hat0 = c[n] - u_source
# define the criterion which determines if wave is reflected at this level (test).
test = abs(c_hat) * kwv - ω_r
if test >= 0.0
# wave has undergone total internal reflection. remove it from the propagating set.
mask = Base.setindex(mask, false, n)
else
c_hat0 = c[n] - u_source
# define the criterion which determines if wave is reflected at this level (test).
test = abs(c_hat) * kwv - ω_r
if test >= 0.0
# wave has undergone total internal reflection. remove it from the propagating set.
if level == level_end
# this is added in MiMA implementation:
# all momentum flux that escapes across the model top
# is deposited to the extra level being added so that
# momentum flux is conserved
mask = Base.setindex(mask, false, n)
if level >= source_level
fm = fm + B0[n]
end
else
if level == level_end
# this is added in MiMA implementation:
# all momentum flux that escapes across the model top
# is deposited to the extra level being added so that
# momentum flux is conserved
# if wave is not reflected at this level, determine if it is
# breaking at this level (Foc >= 0), or if wave speed relative to
# windspeed has changed sign from its value at the source level
# (c_hat0 * c_hat <= 0). if it is at or above the source level and is
# breaking, then add its momentum flux to the accumulated sum at
# this level.
# set mask[n] = false to remove phase speed band c[n] from the set of active
# waves moving upwards to the next level.
Foc = B0[n] / (c_hat)^3 - fac
if Foc >= 0.0 || (c_hat0 * c_hat <= 0.0)
mask = Base.setindex(mask, false, n)
if level >= source_level
fm = fm + B0[n]
end
else
# if wave is not reflected at this level, determine if it is
# breaking at this level (Foc >= 0), or if wave speed relative to
# windspeed has changed sign from its value at the source level
# (c_hat0 * c_hat <= 0). if it is at or above the source level and is
# breaking, then add its momentum flux to the accumulated sum at
# this level.
# set mask[n] = false to remove phase speed band c[n] from the set of active
# waves moving upwards to the next level.
Foc = B0[n] / (c_hat)^3 - fac
if Foc >= 0.0 || (c_hat0 * c_hat <= 0.0)
mask = Base.setindex(mask, false, n)
if level >= source_level
fm = fm + B0[n]
end
end
end
end # (test >= 0.0)
end
end # (test >= 0.0)

end #(c_hat == 0.0)
end # mask = 0
return (mask, fm)
end
end #(c_hat == 0.0)
end # mask = 0
return (mask, fm)
end

# compute the gravity wave momentum flux forcing
# obtained across the entire wave spectrum at this level.
Expand Down

0 comments on commit 0753c18

Please sign in to comment.