Skip to content

Commit

Permalink
Replace ClimaAtmos callbacks with CTS callbacks
Browse files Browse the repository at this point in the history
This commit replaces the callbacks to with the callbacks in
ClimaTimeSteppers.
  • Loading branch information
ph-kev committed Jan 6, 2025
1 parent 3a38c4f commit 5d3e674
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 116 deletions.
52 changes: 0 additions & 52 deletions src/callbacks/callback_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,58 +4,6 @@ import SciMLBase
##### Callback helpers
#####

function call_every_n_steps(
f!,
n = 1;
skip_first = false,
call_at_end = false,
condition = nothing,
)
@assert n Inf "Adding callback that never gets called!"
cond = if isnothing(condition)
previous_step = Ref(0)
(u, t, integrator) ->
(previous_step[] += 1) % n == 0 ||
(call_at_end && t == integrator.sol.prob.tspan[2])
else
condition
end
cb! = AtmosCallback(f!, EveryNSteps(n))
return SciMLBase.DiscreteCallback(
cond,
cb!;
initialize = (cb, u, t, integrator) -> skip_first || cb!(integrator),
save_positions = (false, false),
)
end

function call_every_dt(f!, dt; skip_first = false, call_at_end = false)
cb! = AtmosCallback(f!, EveryΔt(dt))
@assert dt Inf "Adding callback that never gets called!"
next_t = Ref{typeof(dt)}()
affect! = function (integrator)
cb!(integrator)

t = integrator.t
t_end = integrator.sol.prob.tspan[2]
next_t[] = max(t, next_t[] + dt)
if call_at_end
next_t[] = min(next_t[], t_end)
end
end
return SciMLBase.DiscreteCallback(
(u, t, integrator) -> t >= next_t[],
affect!;
initialize = (cb, u, t, integrator) -> begin
skip_first || cb!(integrator)
t_end = integrator.sol.prob.tspan[2]
next_t[] =
(call_at_end && t < t_end) ? min(t_end, t + dt) : t + dt
end,
save_positions = (false, false),
)
end

callback_from_affect(x::AtmosCallback) = x
function callback_from_affect(affect!)
for p in propertynames(affect!)
Expand Down
63 changes: 44 additions & 19 deletions src/callbacks/get_callbacks.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import ClimaTimeSteppers.Callbacks: EveryXSimulationTime, EveryXSimulationSteps

function get_diagnostics(
parsed_args,
atmos_model,
Expand Down Expand Up @@ -253,21 +255,26 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start)
@info "Checking NaNs in the state every $(check_nan_every) steps"
callbacks = (
callbacks...,
call_every_n_steps(
(integrator) -> check_nans(integrator),
EveryXSimulationSteps(
AtmosCallback(
(integrator) -> check_nans(integrator),
EveryNSteps(check_nan_every),
),
check_nan_every,
atinit = true,
save_positions = (false, false),
),
)
end
cond = let output_dir = output_dir
(u, t, integrator) -> maybe_graceful_exit(output_dir, integrator)
end
callbacks = (
callbacks...,
call_every_n_steps(
SciMLBase.DiscreteCallback(
cond,
terminate!;
skip_first = true,
condition = let output_dir = output_dir
(u, t, integrator) ->
maybe_graceful_exit(output_dir, integrator)
end,
save_positions = (false, false),
),
)

Expand All @@ -292,31 +299,42 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start)
end

if is_distributed(comms_ctx)
n_steps = parse(Int, get(ENV, "CLIMAATMOS_GC_NSTEPS", "1000"))
callbacks = (
callbacks...,
call_every_n_steps(
gc_func,
parse(Int, get(ENV, "CLIMAATMOS_GC_NSTEPS", "1000")),
skip_first = true,
EveryXSimulationSteps(
AtmosCallback(gc_func, EveryNSteps(n_steps)),
n_steps,
atinit = false,
save_positions = (false, false),
),
)
end

if parsed_args["check_conservation"]
callbacks = (
callbacks...,
call_every_n_steps(
flux_accumulation!;
skip_first = true,
EveryXSimulationSteps(
AtmosCallback(flux_accumulation!, EveryNSteps(1)),
1;
atinit = false,
call_at_end = true,
save_positions = (false, false),
),
)
end

if !parsed_args["call_cloud_diagnostics_per_stage"]
dt_cf = FT(time_to_seconds(parsed_args["dt_cloud_fraction"]))
callbacks =
(callbacks..., call_every_dt(cloud_fraction_model_callback!, dt_cf))
callbacks = (
callbacks...,
EveryXSimulationTime(
AtmosCallback(cloud_fraction_model_callback!, EveryΔt(dt_cf)),
dt_cf,
atinit = true,
save_positions = (false, false),
),
)
end

if atmos.radiation_mode isa RRTMGPI.AbstractRRTMGPMode
Expand All @@ -329,8 +347,15 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start)
@warn "This simulation will not be reproducible when restarted"
end

callbacks =
(callbacks..., call_every_dt(rrtmgp_model_callback!, dt_rad))
callbacks = (
callbacks...,
EveryXSimulationTime(
AtmosCallback(rrtmgp_model_callback!, EveryΔt(dt_rad)),
dt_rad,
atinit = true,
save_positions = (false, false),
),
)
end

return callbacks
Expand Down
60 changes: 15 additions & 45 deletions test/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,63 +3,33 @@ import ClimaComms
ClimaComms.@import_required_backends
import ClimaAtmos as CA
import SciMLBase as SMB
import ClimaTimeSteppers.Callbacks as CB

testfun!() = π
cb_default = CA.call_every_n_steps(testfun!;)
test_nsteps = 999
test_dt = 1
test_tend = 999.0

cb_1 = CA.call_every_n_steps(
testfun!,
test_nsteps;
skip_first = false,
cb_1 = CB.EveryXSimulationSteps(
CA.AtmosCallback(testfun!, CA.EveryNSteps(test_nsteps)),
test_nsteps,
atinit = true,
call_at_end = false,
)
cb_2 = CB.EveryXSimulationTime(
CA.AtmosCallback(testfun!, CA.EveryΔt(test_dt)),
test_dt;
atinit = true,
call_at_end = false,
condition = nothing,
)
cb_2 =
CA.call_every_dt(testfun!, test_dt; skip_first = false, call_at_end = false)
cb_3 = CA.callback_from_affect(cb_2.affect!)
cb_4 = CA.call_every_n_steps(
testfun!,
3;
skip_first = false,
cb_4 = CB.EveryXSimulationSteps(
CA.AtmosCallback(testfun!, CA.EveryNSteps(3)),
3,
atinit = true,
call_at_end = false,
condition = nothing,
)
cb_set = SMB.CallbackSet(cb_1, cb_2, cb_4)

@testset "simple default callback" begin
@test cb_default.condition.n == 1
@test cb_default.affect!.f!() == π
end

# per n steps
@testset "every n-steps callback" begin
@test cb_1.initialize.skip_first == false
@test cb_1.condition.n == test_nsteps
@test cb_1.affect!.f!() == π
@test_throws AssertionError CA.call_every_n_steps(
testfun!,
Inf;
skip_first = false,
call_at_end = false,
)
end

# per dt interval
@testset "dt interval callback" begin
@test cb_2 isa SMB.DiscreteCallback
@test cb_2.affect!.dt == test_dt
@test cb_2.affect!.cb!.f!() == π
@test_throws AssertionError CA.call_every_dt(
testfun!,
Inf;
skip_first = false,
call_at_end = false,
)
end

@testset "atmos callbacks and callback sets" begin
# atmoscallbacks from discrete callbacks
@test cb_3.f!() == π
Expand Down

0 comments on commit 5d3e674

Please sign in to comment.