Skip to content

Commit

Permalink
Manually fuse broadcast expressions in diag edmf
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Oct 8, 2024
1 parent 59fd53d commit d5533ad
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 30 deletions.
71 changes: 43 additions & 28 deletions src/cache/diagnostic_edmf_precomputed_quantities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,32 @@ function compute_u³ʲ_u³ʲ(
return u³ʲ_u³ʲ
end

function compute_ρaʲu³ʲ(
J_halflevel,
J_prev_level,
J_prev_halflevel,
ρaʲ_prev_level,
entrʲ_prev_level,
detrʲ_prev_level,
u³ʲ_data_prev_halflevel,
S_q_totʲ_prev_level,
precip_model,
)

ρaʲu³ʲ_data =
(1 / J_halflevel) *
(J_prev_halflevel * ρaʲ_prev_level * u³ʲ_data_prev_halflevel)

ρaʲu³ʲ_data +=
(1 / J_halflevel) *
(J_prev_level * ρaʲ_prev_level * (entrʲ_prev_level - detrʲ_prev_level))
if precip_model isa Union{Microphysics0Moment, Microphysics1Moment}
ρaʲu³ʲ_data +=
(1 / J_halflevel) *
(J_prev_level * ρaʲ_prev_level * S_q_totʲ_prev_level)
end
return ρaʲu³ʲ_data
end

NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!(
Y,
Expand Down Expand Up @@ -410,10 +436,13 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!(
scale_height =
CAP.R_d(params) * CAP.T_surf_ref(params) / CAP.grav(params)

if precip_model isa Union{Microphysics0Moment, Microphysics1Moment}
S_q_totʲ_prev_level =
S_q_totʲ_prev_level =
if precip_model isa
Union{Microphysics0Moment, Microphysics1Moment}
Fields.field_values(Fields.level(ᶜS_q_totʲ, i - 1))
end
else
()
end
if precip_model isa Microphysics1Moment
S_q_raiʲ_prev_level =
Fields.field_values(Fields.level(ᶜS_q_raiʲ, i - 1))
Expand Down Expand Up @@ -597,37 +626,23 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!(
local_geometry_prev_halflevel,
),
dz_prev_level,
)
@. detrʲ_prev_level = limit_detrainment(
detrʲ_prev_level,
draft_area(ρaʲ_prev_level, ρʲ_prev_level),
dt,
)

ρaʲu³ʲ_data = p.scratch.temp_data_level_2
ρaʲu³ʲ_datamse = ρaʲu³ʲ_dataq_tot = p.scratch.temp_data_level_3

@. ρaʲu³ʲ_data =
(1 / local_geometry_halflevel.J) * (
local_geometry_prev_halflevel.J *
ρaʲ_prev_level *
u³ʲ_data_prev_halflevel
)

@. ρaʲu³ʲ_data +=
(1 / local_geometry_halflevel.J) * (
local_geometry_prev_level.J *
ρaʲ_prev_level *
(entrʲ_prev_level - detrʲ_prev_level)
)
if precip_model isa Union{Microphysics0Moment, Microphysics1Moment}
@. ρaʲu³ʲ_data +=
(1 / local_geometry_halflevel.J) * (
local_geometry_prev_level.J *
ρaʲ_prev_level *
S_q_totʲ_prev_level
)
end
@. ρaʲu³ʲ_data = compute_ρaʲu³ʲ(
local_geometry_halflevel.J,
local_geometry_prev_level.J,
local_geometry_prev_halflevel.J,
ρaʲ_prev_level,
entrʲ_prev_level,
detrʲ_prev_level,
u³ʲ_data_prev_halflevel,
S_q_totʲ_prev_level,
precip_model,
)

@. u³ʲ_halflevel = ifelse(
(
Expand Down
4 changes: 2 additions & 2 deletions src/prognostic_equations/edmfx_entr_detr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,8 @@ end
limit_entrainment(entr::FT, a, w, dz) where {FT} =
max(min(entr, FT(0.9) * w / dz), 0)

limit_detrainment(detr::FT, a, w, dz) where {FT} =
max(min(detr, FT(0.9) * w / dz), 0)
limit_detrainment(detr::FT, a, w, dz, dt) where {FT} =
limit_detrainment(max(min(detr, FT(0.9) * w / dz), 0), a, dt)

function limit_turb_entrainment(dyn_entr::FT, turb_entr::FT, w, dz) where {FT}
return max(min((FT(0.9) * w / dz) - dyn_entr, turb_entr), 0)
Expand Down

0 comments on commit d5533ad

Please sign in to comment.