From 8bb56be940478f8d23f053e853dcf29227c6297b Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Wed, 23 Oct 2024 09:24:59 -0700 Subject: [PATCH] Add check for accumulated diagnostics and restarts This commit adds a warning when we can detect checkpointing and diagnostics being out of sync. ClimaAtmos now detects what frequencies are involved and prints them out, e.g. ``` [ Info: Saving accumulated diagnostics to disk with frequency: 2 days, 1 day ``` Then, it checks that all these frequencies evenly divide the checkpointing frequency. If this does not happen, a warning is produced. --- docs/src/restarts.md | 15 ++++++ src/callbacks/get_callbacks.jl | 89 +++++++++++++++++++++------------- src/solver/type_getters.jl | 31 +++++++++--- src/utils/utilities.jl | 43 ++++++++++++++++ test/utilities.jl | 8 +++ 5 files changed, 143 insertions(+), 43 deletions(-) diff --git a/docs/src/restarts.md b/docs/src/restarts.md index 01108943571..7d96022a130 100644 --- a/docs/src/restarts.md +++ b/docs/src/restarts.md @@ -47,3 +47,18 @@ is started. If is also possible to manually specify a restart file. In this case, this will override any file automatically detected. +### Accumulated Diagnostics + +At the moment, `ClimaAtmos` does not preserve accumulated diagnostics +across restarts. The present limitations are best illustrated with an example. + +Suppose you are saving 30-day averages and stop the simulation at day 45. If you +do so, you'll find output for day 30 and the checkpoint at day 45. Then, if you +restart the simulation, you'll see that the next diagnostic output will be at +day 75, and not day 60. In other words, the counter starts from 0 with every +restart. + +!!! note + + If you care about accurate accumulated diagnostics, make sure to line up your + checkpoint and diagnostic frequencies. 
diff --git a/src/callbacks/get_callbacks.jl b/src/callbacks/get_callbacks.jl index 076687c62e9..fb8f6f36a7d 100644 --- a/src/callbacks/get_callbacks.jl +++ b/src/callbacks/get_callbacks.jl @@ -102,20 +102,21 @@ function get_diagnostics(parsed_args, atmos_model, Y, p, dt, t_start) "$(period_str) has to be of the form months, e.g. 2months for 2 months", ) period_dates = Dates.Month(parse(Int, first(months))) - output_schedule = CAD.EveryCalendarDtSchedule( - period_dates; - reference_date = p.start_date, - ) - compute_schedule = CAD.EveryCalendarDtSchedule( - period_dates; - reference_date = p.start_date, - ) else period_seconds = FT(time_to_seconds(period_str)) - output_schedule = CAD.EveryDtSchedule(period_seconds) - compute_schedule = CAD.EveryDtSchedule(period_seconds) + period_dates = + CA.promote_period.(Dates.Second(period_seconds)) end + output_schedule = CAD.EveryCalendarDtSchedule( + period_dates; + reference_date = p.start_date, + ) + compute_schedule = CAD.EveryCalendarDtSchedule( + period_dates; + reference_date = p.start_date, + ) + if isnothing(output_name) output_short_name = CAD.descriptive_short_name( CAD.get_diagnostic_variable(short_name), @@ -159,6 +160,25 @@ function get_diagnostics(parsed_args, atmos_model, Y, p, dt, t_start) end diagnostics = collect(diagnostics) + periods_reductions = Set() + for diag in diagnostics + isa_reduction = !isnothing(diag.reduction_time_func) + isa_reduction || continue + + if diag.output_schedule_func isa CAD.EveryDtSchedule + period = Dates.Second(diag.output_schedule_func.dt) + elseif diag.output_schedule_func isa CAD.EveryCalendarDtSchedule + period = diag.output_schedule_func.dt + else + continue + end + + push!(periods_reductions, period) + end + + periods_str = join(CA.promote_period.(periods_reductions), ", ") + @info "Saving accumulated diagnostics to disk with frequency: $(periods_str)" + for writer in writers writer_str = nameof(typeof(writer)) diags_with_writer = @@ -169,9 +189,28 @@ function 
get_diagnostics(parsed_args, atmos_model, Y, p, dt, t_start) @info "$writer_str: $diags_outputs" end - return diagnostics, writers + return diagnostics, writers, periods_reductions end +function checkpoint_frequency_from_parsed_args(dt_save_state_to_disk::String) + if occursin("months", dt_save_state_to_disk) + months = match(r"^(\d+)months$", dt_save_state_to_disk) + isnothing(months) && error( + "$(dt_save_state_to_disk) has to be of the form months, e.g. 2months for 2 months", + ) + return Dates.Month(parse(Int, first(months))) + else + dt_save_state_to_disk = time_to_seconds(dt_save_state_to_disk) + if !(dt_save_state_to_disk == Inf) + # We use Millisecond to support fractional seconds, eg. 0.1 + return Dates.Millisecond(1000dt_save_state_to_disk) + else + return Inf + end + end +end + + function get_callbacks(config, sim_info, atmos, params, Y, p, t_start) (; parsed_args, comms_ctx) = config FT = eltype(params) @@ -211,13 +250,10 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start) # Save dt_save_state_to_disk as a Dates.Period object. This is used to check # if it is an integer multiple of other frequencies. - dt_save_state_to_disk_dates = Dates.today() # Value will be overwritten - if occursin("months", parsed_args["dt_save_state_to_disk"]) - months = match(r"^(\d+)months$", parsed_args["dt_save_state_to_disk"]) - isnothing(months) && error( - "$(period_str) has to be of the form months, e.g. 
2months for 2 months", - ) - dt_save_state_to_disk_dates = Dates.Month(parse(Int, first(months))) + dt_save_state_to_disk_dates = checkpoint_frequency_from_parsed_args( + parsed_args["dt_save_state_to_disk"], + ) + if dt_save_state_to_disk_dates != Inf schedule = CAD.EveryCalendarDtSchedule( dt_save_state_to_disk_dates; reference_date = p.start_date, @@ -230,23 +266,6 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start) (integrator) -> save_state_to_disk_func(integrator, output_dir) end callbacks = (callbacks..., SciMLBase.DiscreteCallback(cond, affect!)) - else - dt_save_state_to_disk = - time_to_seconds(parsed_args["dt_save_state_to_disk"]) - if !(dt_save_state_to_disk == Inf) - # We use Millisecond to support fractional seconds, eg. 0.1 - dt_save_state_to_disk_dates = - Dates.Millisecond(dt_save_state_to_disk) - callbacks = ( - callbacks..., - call_every_dt( - (integrator) -> - save_state_to_disk_func(integrator, output_dir), - dt_save_state_to_disk; - skip_first = sim_info.restart, - ), - ) - end end if is_distributed(comms_ctx) diff --git a/src/solver/type_getters.jl b/src/solver/type_getters.jl index b267cb5d1d5..0a8fdacad0f 100644 --- a/src/solver/type_getters.jl +++ b/src/solver/type_getters.jl @@ -757,16 +757,31 @@ function get_simulation(config::AtmosConfig) # Initialize diagnostics if config.parsed_args["enable_diagnostics"] s = @timed_str begin - scheduled_diagnostics, writers = get_diagnostics( - config.parsed_args, - atmos, - Y, - p, - sim_info.dt, - t_start, - ) + scheduled_diagnostics, writers, periods_reductions = + get_diagnostics( + config.parsed_args, + atmos, + Y, + p, + sim_info.dt, + t_start, + ) end @info "initializing diagnostics: $s" + + # Check for consistency between diagnostics and checkpoints + checkpoint_frequency = checkpoint_frequency_from_parsed_args( + config.parsed_args["dt_save_state_to_disk"], + ) + + if checkpoint_frequency != Inf + if any( + x -> !CA.isdivisible(checkpoint_frequency, x), + 
periods_reductions, + ) + @warn "Some accumulated diagnostics might not be evenly divisible by the checkpointing frequency ($(CA.promote_period(checkpoint_frequency)))" + end + end else writers = nothing end diff --git a/src/utils/utilities.jl b/src/utils/utilities.jl index fffbc4a7b21..6e3047672f5 100644 --- a/src/utils/utilities.jl +++ b/src/utils/utilities.jl @@ -454,3 +454,46 @@ function isdivisible( # have any common divisor) return isinteger(Dates.Day(1) / dt_small) end + +""" + promote_period(period::Dates.Period) + +Promote a period to the largest possible period type. + +This function attempts to represent a given `Period` using the largest possible +unit of time, e.g., 24 hours are promoted to 1 day. If a clean promotion is not +possible, or the period has no fixed length (`Month`, `Year`), return the input. + +# Examples +```julia-repl +julia> promote_period(Hour(24)) +1 day + +julia> promote_period(Day(14)) +2 weeks + +julia> promote_period(Second(86401)) +86401 seconds + +julia> promote_period(Millisecond(1)) +1 millisecond +``` +""" +function promote_period(period::Dates.Period) + # Month/Quarter/Year have no fixed length in milliseconds: return them as they are + period isa Dates.OtherPeriod && return period + ms = Int(Dates.toms(period)) 
+ PeriodTypes = [ + Dates.Week, + Dates.Day, + Dates.Hour, + Dates.Minute, + Dates.Second, + Dates.Millisecond, + ] + for PeriodType in PeriodTypes + period_ms = Int(Dates.toms(PeriodType(1))) + ms % period_ms == 0 && return PeriodType(div(ms, period_ms)) + end + return period +end diff --git a/test/utilities.jl b/test/utilities.jl index c7247ce5a39..af51a623c9c 100644 --- a/test/utilities.jl +++ b/test/utilities.jl @@ -266,3 +266,12 @@ end @test test_cent_space == cent_space @test test_face_space == face_space end + +@testset "promote_period" begin + @test CA.promote_period(Dates.Hour(24)) === Dates.Day(1) + @test CA.promote_period(Dates.Day(14)) === Dates.Week(2) + @test CA.promote_period(Dates.Millisecond(1)) === Dates.Millisecond(1) + @test CA.promote_period(Dates.Minute(120)) === Dates.Hour(2) + @test CA.promote_period(Dates.Second(3600)) === Dates.Hour(1) + @test CA.promote_period(Dates.Month(2)) === Dates.Month(2) +end