From d5533ad540ea04709bfa37fa7ac86cfb35cfd528 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Tue, 8 Oct 2024 14:04:49 -0400 Subject: [PATCH] Manually fuse broadcast expressions in diag edmf --- .../diagnostic_edmf_precomputed_quantities.jl | 71 +++++++++++-------- src/prognostic_equations/edmfx_entr_detr.jl | 4 +- 2 files changed, 45 insertions(+), 30 deletions(-) diff --git a/src/cache/diagnostic_edmf_precomputed_quantities.jl b/src/cache/diagnostic_edmf_precomputed_quantities.jl index cd03f0e52d8..d35a025c144 100644 --- a/src/cache/diagnostic_edmf_precomputed_quantities.jl +++ b/src/cache/diagnostic_edmf_precomputed_quantities.jl @@ -260,6 +260,32 @@ function compute_u³ʲ_u³ʲ( return u³ʲ_u³ʲ end +function compute_ρaʲu³ʲ( + J_halflevel, + J_prev_level, + J_prev_halflevel, + ρaʲ_prev_level, + entrʲ_prev_level, + detrʲ_prev_level, + u³ʲ_data_prev_halflevel, + S_q_totʲ_prev_level, + precip_model, +) + + ρaʲu³ʲ_data = + (1 / J_halflevel) * + (J_prev_halflevel * ρaʲ_prev_level * u³ʲ_data_prev_halflevel) + + ρaʲu³ʲ_data += + (1 / J_halflevel) * + (J_prev_level * ρaʲ_prev_level * (entrʲ_prev_level - detrʲ_prev_level)) + if precip_model isa Union{Microphysics0Moment, Microphysics1Moment} + ρaʲu³ʲ_data += + (1 / J_halflevel) * + (J_prev_level * ρaʲ_prev_level * S_q_totʲ_prev_level) + end + return ρaʲu³ʲ_data +end NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!( Y, @@ -410,10 +436,13 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!( scale_height = CAP.R_d(params) * CAP.T_surf_ref(params) / CAP.grav(params) - if precip_model isa Union{Microphysics0Moment, Microphysics1Moment} - S_q_totʲ_prev_level = + S_q_totʲ_prev_level = + if precip_model isa + Union{Microphysics0Moment, Microphysics1Moment} Fields.field_values(Fields.level(ᶜS_q_totʲ, i - 1)) - end + else + () + end if precip_model isa Microphysics1Moment S_q_raiʲ_prev_level = Fields.field_values(Fields.level(ᶜS_q_raiʲ, i - 1)) @@ -597,37 +626,23 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!( local_geometry_prev_halflevel, ), dz_prev_level, - ) - @. detrʲ_prev_level = limit_detrainment( - detrʲ_prev_level, - draft_area(ρaʲ_prev_level, ρʲ_prev_level), dt, ) ρaʲu³ʲ_data = p.scratch.temp_data_level_2 ρaʲu³ʲ_datamse = ρaʲu³ʲ_dataq_tot = p.scratch.temp_data_level_3 - @. ρaʲu³ʲ_data = - (1 / local_geometry_halflevel.J) * ( - local_geometry_prev_halflevel.J * - ρaʲ_prev_level * - u³ʲ_data_prev_halflevel - ) - - @. ρaʲu³ʲ_data += - (1 / local_geometry_halflevel.J) * ( - local_geometry_prev_level.J * - ρaʲ_prev_level * - (entrʲ_prev_level - detrʲ_prev_level) - ) - if precip_model isa Union{Microphysics0Moment, Microphysics1Moment} - @. ρaʲu³ʲ_data += - (1 / local_geometry_halflevel.J) * ( - local_geometry_prev_level.J * - ρaʲ_prev_level * - S_q_totʲ_prev_level - ) - end + @. ρaʲu³ʲ_data = compute_ρaʲu³ʲ( + local_geometry_halflevel.J, + local_geometry_prev_level.J, + local_geometry_prev_halflevel.J, + ρaʲ_prev_level, + entrʲ_prev_level, + detrʲ_prev_level, + u³ʲ_data_prev_halflevel, + S_q_totʲ_prev_level, + precip_model, + ) @. u³ʲ_halflevel = ifelse( ( diff --git a/src/prognostic_equations/edmfx_entr_detr.jl b/src/prognostic_equations/edmfx_entr_detr.jl index 9c9e5dba864..845a88a9837 100644 --- a/src/prognostic_equations/edmfx_entr_detr.jl +++ b/src/prognostic_equations/edmfx_entr_detr.jl @@ -375,8 +375,8 @@ end limit_entrainment(entr::FT, a, w, dz) where {FT} = max(min(entr, FT(0.9) * w / dz), 0) -limit_detrainment(detr::FT, a, w, dz) where {FT} = - max(min(detr, FT(0.9) * w / dz), 0) +limit_detrainment(detr::FT, a, w, dz, dt) where {FT} = + limit_detrainment(max(min(detr, FT(0.9) * w / dz), 0), a, dt) function limit_turb_entrainment(dyn_entr::FT, turb_entr::FT, w, dz) where {FT} return max(min((FT(0.9) * w / dz) - dyn_entr, turb_entr), 0)