From 5d3e674faff763cf0232cf178aab47050a8492a0 Mon Sep 17 00:00:00 2001 From: Kevin Phan <98072684+ph-kev@users.noreply.github.com> Date: Fri, 3 Jan 2025 12:13:30 -0800 Subject: [PATCH] Replace ClimaAtmos callbacks with CTS callbacks This commit replaces the callbacks to with the callbacks in ClimaTimeSteppers. --- src/callbacks/callback_helpers.jl | 52 ------------------------- src/callbacks/get_callbacks.jl | 63 +++++++++++++++++++++---------- test/callbacks.jl | 60 ++++++++--------------------- 3 files changed, 59 insertions(+), 116 deletions(-) diff --git a/src/callbacks/callback_helpers.jl b/src/callbacks/callback_helpers.jl index 94474ead62..1cf6df7a5f 100644 --- a/src/callbacks/callback_helpers.jl +++ b/src/callbacks/callback_helpers.jl @@ -4,58 +4,6 @@ import SciMLBase ##### Callback helpers ##### -function call_every_n_steps( - f!, - n = 1; - skip_first = false, - call_at_end = false, - condition = nothing, -) - @assert n ≠ Inf "Adding callback that never gets called!" - cond = if isnothing(condition) - previous_step = Ref(0) - (u, t, integrator) -> - (previous_step[] += 1) % n == 0 || - (call_at_end && t == integrator.sol.prob.tspan[2]) - else - condition - end - cb! = AtmosCallback(f!, EveryNSteps(n)) - return SciMLBase.DiscreteCallback( - cond, - cb!; - initialize = (cb, u, t, integrator) -> skip_first || cb!(integrator), - save_positions = (false, false), - ) -end - -function call_every_dt(f!, dt; skip_first = false, call_at_end = false) - cb! = AtmosCallback(f!, EveryΔt(dt)) - @assert dt ≠ Inf "Adding callback that never gets called!" - next_t = Ref{typeof(dt)}() - affect! = function (integrator) - cb!(integrator) - - t = integrator.t - t_end = integrator.sol.prob.tspan[2] - next_t[] = max(t, next_t[] + dt) - if call_at_end - next_t[] = min(next_t[], t_end) - end - end - return SciMLBase.DiscreteCallback( - (u, t, integrator) -> t >= next_t[], - affect!; - initialize = (cb, u, t, integrator) -> begin - skip_first || cb!(integrator) - t_end = integrator.sol.prob.tspan[2] - next_t[] = - (call_at_end && t < t_end) ? min(t_end, t + dt) : t + dt - end, - save_positions = (false, false), - ) -end - callback_from_affect(x::AtmosCallback) = x function callback_from_affect(affect!) for p in propertynames(affect!) diff --git a/src/callbacks/get_callbacks.jl b/src/callbacks/get_callbacks.jl index 334cfb11ba..a3b84226a0 100644 --- a/src/callbacks/get_callbacks.jl +++ b/src/callbacks/get_callbacks.jl @@ -1,3 +1,5 @@ +import ClimaTimeSteppers.Callbacks: EveryXSimulationTime, EveryXSimulationSteps + function get_diagnostics( parsed_args, atmos_model, @@ -253,21 +255,26 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start) @info "Checking NaNs in the state every $(check_nan_every) steps" callbacks = ( callbacks..., - call_every_n_steps( - (integrator) -> check_nans(integrator), + EveryXSimulationSteps( + AtmosCallback( + (integrator) -> check_nans(integrator), + EveryNSteps(check_nan_every), + ), check_nan_every, + atinit = true, + save_positions = (false, false), ), ) end + cond = let output_dir = output_dir + (u, t, integrator) -> maybe_graceful_exit(output_dir, integrator) + end callbacks = ( callbacks..., - call_every_n_steps( + SciMLBase.DiscreteCallback( + cond, terminate!; - skip_first = true, - condition = let output_dir = output_dir - (u, t, integrator) -> - maybe_graceful_exit(output_dir, integrator) - end, + save_positions = (false, false), ), ) @@ -292,12 +299,14 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start) end if is_distributed(comms_ctx) + n_steps = parse(Int, get(ENV, "CLIMAATMOS_GC_NSTEPS", "1000")) callbacks = ( callbacks..., - call_every_n_steps( - gc_func, - parse(Int, get(ENV, "CLIMAATMOS_GC_NSTEPS", "1000")), - skip_first = true, + EveryXSimulationSteps( + AtmosCallback(gc_func, EveryNSteps(n_steps)), + n_steps, + atinit = false, + save_positions = (false, false), ), ) end @@ -305,18 +314,27 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start) if parsed_args["check_conservation"] callbacks = ( callbacks..., - call_every_n_steps( - flux_accumulation!; - skip_first = true, + EveryXSimulationSteps( + AtmosCallback(flux_accumulation!, EveryNSteps(1)), + 1; + atinit = false, call_at_end = true, + save_positions = (false, false), ), ) end if !parsed_args["call_cloud_diagnostics_per_stage"] dt_cf = FT(time_to_seconds(parsed_args["dt_cloud_fraction"])) - callbacks = - (callbacks..., call_every_dt(cloud_fraction_model_callback!, dt_cf)) + callbacks = ( + callbacks..., + EveryXSimulationTime( + AtmosCallback(cloud_fraction_model_callback!, EveryΔt(dt_cf)), + dt_cf, + atinit = true, + save_positions = (false, false), + ), + ) end if atmos.radiation_mode isa RRTMGPI.AbstractRRTMGPMode @@ -329,8 +347,15 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start) @warn "This simulation will not be reproducible when restarted" end - callbacks = - (callbacks..., call_every_dt(rrtmgp_model_callback!, dt_rad)) + callbacks = ( + callbacks..., + EveryXSimulationTime( + AtmosCallback(rrtmgp_model_callback!, EveryΔt(dt_rad)), + dt_rad, + atinit = true, + save_positions = (false, false), + ), + ) end return callbacks diff --git a/test/callbacks.jl b/test/callbacks.jl index 6eb1afa033..39e57acb19 100644 --- a/test/callbacks.jl +++ b/test/callbacks.jl @@ -3,63 +3,33 @@ import ClimaComms ClimaComms.@import_required_backends import ClimaAtmos as CA import SciMLBase as SMB +import ClimaTimeSteppers.Callbacks as CB testfun!() = π -cb_default = CA.call_every_n_steps(testfun!;) test_nsteps = 999 test_dt = 1 test_tend = 999.0 - -cb_1 = CA.call_every_n_steps( - testfun!, - test_nsteps; - skip_first = false, +cb_1 = CB.EveryXSimulationSteps( + CA.AtmosCallback(testfun!, CA.EveryNSteps(test_nsteps)), + test_nsteps, + atinit = true, + call_at_end = false, +) +cb_2 = CB.EveryXSimulationTime( + CA.AtmosCallback(testfun!, CA.EveryΔt(test_dt)), + test_dt; + atinit = true, call_at_end = false, - condition = nothing, ) -cb_2 = - CA.call_every_dt(testfun!, test_dt; skip_first = false, call_at_end = false) cb_3 = CA.callback_from_affect(cb_2.affect!) -cb_4 = CA.call_every_n_steps( - testfun!, - 3; - skip_first = false, +cb_4 = CB.EveryXSimulationSteps( + CA.AtmosCallback(testfun!, CA.EveryNSteps(3)), + 3, + atinit = true, call_at_end = false, - condition = nothing, ) cb_set = SMB.CallbackSet(cb_1, cb_2, cb_4) -@testset "simple default callback" begin - @test cb_default.condition.n == 1 - @test cb_default.affect!.f!() == π -end - -# per n steps -@testset "every n-steps callback" begin - @test cb_1.initialize.skip_first == false - @test cb_1.condition.n == test_nsteps - @test cb_1.affect!.f!() == π - @test_throws AssertionError CA.call_every_n_steps( - testfun!, - Inf; - skip_first = false, - call_at_end = false, - ) -end - -# per dt interval -@testset "dt interval callback" begin - @test cb_2 isa SMB.DiscreteCallback - @test cb_2.affect!.dt == test_dt - @test cb_2.affect!.cb!.f!() == π - @test_throws AssertionError CA.call_every_dt( - testfun!, - Inf; - skip_first = false, - call_at_end = false, - ) -end - @testset "atmos callbacks and callback sets" begin # atmoscallbacks from discrete callbacks @test cb_3.f!() == π