diff --git a/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl b/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl index d78dc97aca..b8e59278c5 100644 --- a/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl +++ b/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl @@ -566,53 +566,57 @@ function waveforcing_column_accumulate!( if level >= source_level - 1 # check break condition for each gravity waves and calculate momentum flux of breaking gravity waves at each level # We use the unrolled_reduce function here because it performs better for parallel execution on the GPU, avoiding type instabilities. - (mask, fm) = - unrolled_reduce(Val(nc), (mask, FT1(0.0))) do (mask, fm), (n) - if (mask[n]) == true - c_hat = c[n] - u_kp1 # c0mu - # f phase speed matches the wind speed, remove c(n) from the set of propagating waves. - if c_hat == 0.0 + # However, we need to prevent it from being inlined on the CPU to avoid large compilation times for several test cases in CI. + # Note that @noinline has no effect on the GPU, which requires all kernel code to be inlined. + (mask, fm) = @noinline unrolled_reduce( + StaticOneTo(nc), + (mask, FT1(0.0)), + ) do (mask, fm), (n) + if (mask[n]) == true + c_hat = c[n] - u_kp1 # c0mu + # f phase speed matches the wind speed, remove c(n) from the set of propagating waves. + if c_hat == 0.0 + mask = Base.setindex(mask, false, n) + else + c_hat0 = c[n] - u_source + # define the criterion which determines if wave is reflected at this level (test). + test = abs(c_hat) * kwv - ω_r + if test >= 0.0 + # wave has undergone total internal reflection. remove it from the propagating set. mask = Base.setindex(mask, false, n) else - c_hat0 = c[n] - u_source - # define the criterion which determines if wave is reflected at this level (test). - test = abs(c_hat) * kwv - ω_r - if test >= 0.0 - # wave has undergone total internal reflection. remove it from the propagating set. + if level == level_end + # this is added in MiMA implementation: + # all momentum flux that escapes across the model top + # is deposited to the extra level being added so that + # momentum flux is conserved mask = Base.setindex(mask, false, n) + if level >= source_level + fm = fm + B0[n] + end else - if level == level_end - # this is added in MiMA implementation: - # all momentum flux that escapes across the model top - # is deposited to the extra level being added so that - # momentum flux is conserved + # if wave is not reflected at this level, determine if it is + # breaking at this level (Foc >= 0), or if wave speed relative to + # windspeed has changed sign from its value at the source level + # (c_hat0[n] * c_hat <= 0). if it is above the source level and is + # breaking, then add its momentum flux to the accumulated sum at + # this level. + # set mask=0.0 to remove phase speed band c[n] from the set of active + # waves moving upwards to the next level. + Foc = B0[n] / (c_hat)^3 - fac + if Foc >= 0.0 || (c_hat0 * c_hat <= 0.0) mask = Base.setindex(mask, false, n) if level >= source_level fm = fm + B0[n] end - else - # if wave is not reflected at this level, determine if it is - # breaking at this level (Foc >= 0), or if wave speed relative to - # windspeed has changed sign from its value at the source level - # (c_hat0[n] * c_hat <= 0). if it is above the source level and is - # breaking, then add its momentum flux to the accumulated sum at - # this level. - # set mask=0.0 to remove phase speed band c[n] from the set of active - # waves moving upwards to the next level. - Foc = B0[n] / (c_hat)^3 - fac - if Foc >= 0.0 || (c_hat0 * c_hat <= 0.0) - mask = Base.setindex(mask, false, n) - if level >= source_level - fm = fm + B0[n] - end - end end - end # (test >= 0.0) + end + end # (test >= 0.0) - end #(c_hat == 0.0) - end # mask = 0 - return (mask, fm) - end + end #(c_hat == 0.0) + end # mask = 0 + return (mask, fm) + end # compute the gravity wave momentum flux forcing # obtained across the entire wave spectrum at this level.