Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace ClimaAtmos callbacks with CTS callbacks #3514

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 0 additions & 52 deletions src/callbacks/callback_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,58 +4,6 @@ import SciMLBase
##### Callback helpers
#####

function call_every_n_steps(
f!,
n = 1;
skip_first = false,
call_at_end = false,
condition = nothing,
)
@assert n ≠ Inf "Adding callback that never gets called!"
cond = if isnothing(condition)
previous_step = Ref(0)
(u, t, integrator) ->
(previous_step[] += 1) % n == 0 ||
(call_at_end && t == integrator.sol.prob.tspan[2])
else
condition
end
cb! = AtmosCallback(f!, EveryNSteps(n))
return SciMLBase.DiscreteCallback(
cond,
cb!;
initialize = (cb, u, t, integrator) -> skip_first || cb!(integrator),
save_positions = (false, false),
)
end

function call_every_dt(f!, dt; skip_first = false, call_at_end = false)
cb! = AtmosCallback(f!, EveryΔt(dt))
@assert dt ≠ Inf "Adding callback that never gets called!"
next_t = Ref{typeof(dt)}()
affect! = function (integrator)
cb!(integrator)

t = integrator.t
t_end = integrator.sol.prob.tspan[2]
next_t[] = max(t, next_t[] + dt)
if call_at_end
next_t[] = min(next_t[], t_end)
end
end
return SciMLBase.DiscreteCallback(
(u, t, integrator) -> t >= next_t[],
affect!;
initialize = (cb, u, t, integrator) -> begin
skip_first || cb!(integrator)
t_end = integrator.sol.prob.tspan[2]
next_t[] =
(call_at_end && t < t_end) ? min(t_end, t + dt) : t + dt
end,
save_positions = (false, false),
)
end

callback_from_affect(x::AtmosCallback) = x
function callback_from_affect(affect!)
for p in propertynames(affect!)
Expand Down
63 changes: 44 additions & 19 deletions src/callbacks/get_callbacks.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import ClimaTimeSteppers.Callbacks: EveryXSimulationTime, EveryXSimulationSteps

function get_diagnostics(
parsed_args,
atmos_model,
Expand Down Expand Up @@ -253,21 +255,26 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start)
@info "Checking NaNs in the state every $(check_nan_every) steps"
callbacks = (
callbacks...,
call_every_n_steps(
(integrator) -> check_nans(integrator),
EveryXSimulationSteps(
AtmosCallback(
(integrator) -> check_nans(integrator),
EveryNSteps(check_nan_every),
),
check_nan_every,
atinit = true,
save_positions = (false, false),
),
)
end
cond = let output_dir = output_dir
(u, t, integrator) -> maybe_graceful_exit(output_dir, integrator)
end
callbacks = (
callbacks...,
call_every_n_steps(
SciMLBase.DiscreteCallback(
cond,
terminate!;
skip_first = true,
condition = let output_dir = output_dir
(u, t, integrator) ->
maybe_graceful_exit(output_dir, integrator)
end,
save_positions = (false, false),
),
)

Expand All @@ -292,31 +299,42 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start)
end

if is_distributed(comms_ctx)
n_steps = parse(Int, get(ENV, "CLIMAATMOS_GC_NSTEPS", "1000"))
callbacks = (
callbacks...,
call_every_n_steps(
gc_func,
parse(Int, get(ENV, "CLIMAATMOS_GC_NSTEPS", "1000")),
skip_first = true,
EveryXSimulationSteps(
AtmosCallback(gc_func, EveryNSteps(n_steps)),
n_steps,
atinit = false,
save_positions = (false, false),
),
)
end

if parsed_args["check_conservation"]
callbacks = (
callbacks...,
call_every_n_steps(
flux_accumulation!;
skip_first = true,
EveryXSimulationSteps(
AtmosCallback(flux_accumulation!, EveryNSteps(1)),
1;
atinit = false,
call_at_end = true,
save_positions = (false, false),
),
)
end

if !parsed_args["call_cloud_diagnostics_per_stage"]
dt_cf = FT(time_to_seconds(parsed_args["dt_cloud_fraction"]))
callbacks =
(callbacks..., call_every_dt(cloud_fraction_model_callback!, dt_cf))
callbacks = (
callbacks...,
EveryXSimulationTime(
AtmosCallback(cloud_fraction_model_callback!, EveryΔt(dt_cf)),
dt_cf,
atinit = true,
save_positions = (false, false),
),
)
end

if atmos.radiation_mode isa RRTMGPI.AbstractRRTMGPMode
Expand All @@ -329,8 +347,15 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start)
@warn "This simulation will not be reproducible when restarted"
end

callbacks =
(callbacks..., call_every_dt(rrtmgp_model_callback!, dt_rad))
callbacks = (
callbacks...,
EveryXSimulationTime(
AtmosCallback(rrtmgp_model_callback!, EveryΔt(dt_rad)),
dt_rad,
atinit = true,
save_positions = (false, false),
),
)
end

return callbacks
Expand Down
60 changes: 15 additions & 45 deletions test/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,63 +3,33 @@ import ClimaComms
ClimaComms.@import_required_backends
import ClimaAtmos as CA
import SciMLBase as SMB
import ClimaTimeSteppers.Callbacks as CB

testfun!() = π
cb_default = CA.call_every_n_steps(testfun!;)
test_nsteps = 999
test_dt = 1
test_tend = 999.0

cb_1 = CA.call_every_n_steps(
testfun!,
test_nsteps;
skip_first = false,
cb_1 = CB.EveryXSimulationSteps(
CA.AtmosCallback(testfun!, CA.EveryNSteps(test_nsteps)),
test_nsteps,
atinit = true,
call_at_end = false,
)
cb_2 = CB.EveryXSimulationTime(
CA.AtmosCallback(testfun!, CA.EveryΔt(test_dt)),
test_dt;
atinit = true,
call_at_end = false,
condition = nothing,
)
cb_2 =
CA.call_every_dt(testfun!, test_dt; skip_first = false, call_at_end = false)
cb_3 = CA.callback_from_affect(cb_2.affect!)
cb_4 = CA.call_every_n_steps(
testfun!,
3;
skip_first = false,
cb_4 = CB.EveryXSimulationSteps(
CA.AtmosCallback(testfun!, CA.EveryNSteps(3)),
3,
atinit = true,
call_at_end = false,
condition = nothing,
)
cb_set = SMB.CallbackSet(cb_1, cb_2, cb_4)

@testset "simple default callback" begin
@test cb_default.condition.n == 1
@test cb_default.affect!.f!() == π
end

# per n steps
@testset "every n-steps callback" begin
@test cb_1.initialize.skip_first == false
@test cb_1.condition.n == test_nsteps
@test cb_1.affect!.f!() == π
@test_throws AssertionError CA.call_every_n_steps(
testfun!,
Inf;
skip_first = false,
call_at_end = false,
)
end

# per dt interval
@testset "dt interval callback" begin
@test cb_2 isa SMB.DiscreteCallback
@test cb_2.affect!.dt == test_dt
@test cb_2.affect!.cb!.f!() == π
@test_throws AssertionError CA.call_every_dt(
testfun!,
Inf;
skip_first = false,
call_at_end = false,
)
end

@testset "atmos callbacks and callback sets" begin
# atmoscallbacks from discrete callbacks
@test cb_3.f!() == π
Expand Down
Loading