Skip to content

Commit

Permalink
Add check for accumulated diagnostics and restarts
Browse files Browse the repository at this point in the history
This commit adds a warning when we can detect checkpointing and
diagnostics being out of sync. ClimaAtmos now detects what frequencies
are involved and print them out, e.g.
```
[ Info: Saving accumulated diagnostics to disk with frequency: 2 days, 1 day
```
Then, it checks that all these frequencies evenly divide the
checkpointing frequency. If this does not happen, a warning is produced.
  • Loading branch information
Sbozzolo committed Oct 23, 2024
1 parent 5841b90 commit ea8042d
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 43 deletions.
15 changes: 15 additions & 0 deletions docs/src/restarts.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,18 @@ is started.
If is also possible to manually specify a restart file. In this case, this will
override any file automatically detected.

### Accumulated Diagnostics

At the moment, `ClimaAtmos` does not support working accumulated diagnostics
across restarts. The present limitations are best illustrated with an example.

Suppose you are saving 30-day averages and stop the simulation at day 45. If you
do so, you'll find output for day 30 and the checkpoint at day 45. Then, if you
restart the simulation, you'll see that the next diagnostic output will be at
day 75, and not day 60. In other words, the counter starts from 0 with every
restart.

!!! note

If you care about accurate accumulated diagnostics, make sure to line up your
checkpoint and diagnostic frequencies.
89 changes: 54 additions & 35 deletions src/callbacks/get_callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,21 @@ function get_diagnostics(parsed_args, atmos_model, Y, p, dt, t_start)
"$(period_str) has to be of the form <NUM>months, e.g. 2months for 2 months",
)
period_dates = Dates.Month(parse(Int, first(months)))
output_schedule = CAD.EveryCalendarDtSchedule(
period_dates;
reference_date = p.start_date,
)
compute_schedule = CAD.EveryCalendarDtSchedule(
period_dates;
reference_date = p.start_date,
)
else
period_seconds = FT(time_to_seconds(period_str))
output_schedule = CAD.EveryDtSchedule(period_seconds)
compute_schedule = CAD.EveryDtSchedule(period_seconds)
period_dates =
CA.promote_period.(Dates.Second(period_seconds))
end

output_schedule = CAD.EveryCalendarDtSchedule(
period_dates;
reference_date = p.start_date,
)
compute_schedule = CAD.EveryCalendarDtSchedule(
period_dates;
reference_date = p.start_date,
)

if isnothing(output_name)
output_short_name = CAD.descriptive_short_name(
CAD.get_diagnostic_variable(short_name),
Expand Down Expand Up @@ -159,6 +160,25 @@ function get_diagnostics(parsed_args, atmos_model, Y, p, dt, t_start)
end
diagnostics = collect(diagnostics)

periods_reductions = Set()
for diag in diagnostics
isa_reduction = !isnothing(diag.reduction_time_func)
isa_reduction || continue

if diag.output_schedule_func isa CAD.EveryDtSchedule
period = Dates.Second(diag.output_schedule_func.dt)
elseif diag.output_schedule_func isa CAD.EveryCalendarDtSchedule
period = diag.output_schedule_func.dt
else
continue
end

push!(periods_reductions, period)
end

periods_str = join(CA.promote_period.(periods_reductions), ", ")
@info "Saving accumulated diagnostics to disk with frequency: $(periods_str)"

for writer in writers
writer_str = nameof(typeof(writer))
diags_with_writer =
Expand All @@ -169,9 +189,28 @@ function get_diagnostics(parsed_args, atmos_model, Y, p, dt, t_start)
@info "$writer_str: $diags_outputs"
end

return diagnostics, writers
return diagnostics, writers, periods_reductions
end

function checkpoint_frequency_from_parsed_args(dt_save_state_to_disk)
if occursin("months", dt_save_state_to_disk)
months = match(r"^(\d+)months$", dt_save_state_to_disk)
isnothing(months) && error(
"$(period_str) has to be of the form <NUM>months, e.g. 2months for 2 months",
)
return Dates.Month(parse(Int, first(months)))
else
dt_save_state_to_disk = time_to_seconds(dt_save_state_to_disk)
if !(dt_save_state_to_disk == Inf)
# We use Millisecond to support fractional seconds, eg. 0.1
return Dates.Millisecond(1000dt_save_state_to_disk)
else
return Inf
end
end
end


function get_callbacks(config, sim_info, atmos, params, Y, p, t_start)
(; parsed_args, comms_ctx) = config
FT = eltype(params)
Expand Down Expand Up @@ -211,13 +250,10 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start)

# Save dt_save_state_to_disk as a Dates.Period object. This is used to check
# if it is an integer multiple of other frequencies.
dt_save_state_to_disk_dates = Dates.today() # Value will be overwritten
if occursin("months", parsed_args["dt_save_state_to_disk"])
months = match(r"^(\d+)months$", parsed_args["dt_save_state_to_disk"])
isnothing(months) && error(
"$(period_str) has to be of the form <NUM>months, e.g. 2months for 2 months",
)
dt_save_state_to_disk_dates = Dates.Month(parse(Int, first(months)))
dt_save_state_to_disk_dates = checkpoint_frequency_from_parsed_args(
parsed_args["dt_save_state_to_disk"],
)
if dt_save_state_to_disk_dates != Inf
schedule = CAD.EveryCalendarDtSchedule(
dt_save_state_to_disk_dates;
reference_date = p.start_date,
Expand All @@ -230,23 +266,6 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start)
(integrator) -> save_state_to_disk_func(integrator, output_dir)
end
callbacks = (callbacks..., SciMLBase.DiscreteCallback(cond, affect!))
else
dt_save_state_to_disk =
time_to_seconds(parsed_args["dt_save_state_to_disk"])
if !(dt_save_state_to_disk == Inf)
# We use Millisecond to support fractional seconds, eg. 0.1
dt_save_state_to_disk_dates =
Dates.Millisecond(dt_save_state_to_disk)
callbacks = (
callbacks...,
call_every_dt(
(integrator) ->
save_state_to_disk_func(integrator, output_dir),
dt_save_state_to_disk;
skip_first = sim_info.restart,
),
)
end
end

if is_distributed(comms_ctx)
Expand Down
31 changes: 23 additions & 8 deletions src/solver/type_getters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -757,16 +757,31 @@ function get_simulation(config::AtmosConfig)
# Initialize diagnostics
if config.parsed_args["enable_diagnostics"]
s = @timed_str begin
scheduled_diagnostics, writers = get_diagnostics(
config.parsed_args,
atmos,
Y,
p,
sim_info.dt,
t_start,
)
scheduled_diagnostics, writers, periods_reductions =
get_diagnostics(
config.parsed_args,
atmos,
Y,
p,
sim_info.dt,
t_start,
)
end
@info "initializing diagnostics: $s"

# Check for consistency between diagnostics and checkpoints
checkpoint_frequency = checkpoint_frequency_from_parsed_args(
config.parsed_args["dt_save_state_to_disk"],
)

if checkpoint_frequency != Inf
if any(
x -> !CA.isdivisible(checkpoint_frequency, x),
periods_reductions,
)
@warn "Some accumulated diagnostics might not be evenly divisible by the checkpointing frequency ($(CA.promote_period(checkpoint_frequency)))"
end
end
else
writers = nothing
end
Expand Down
42 changes: 42 additions & 0 deletions src/utils/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -454,3 +454,45 @@ function isdivisible(
# have any common divisor)
return isinteger(Dates.Day(1) / dt_small)
end

"""
promote_period(period::Dates.Period)
Promote a period to the largest possible period type.
This function attempts to represent a given `Period` using the largest possible
unit of time. For example, a period of 24 hours will be promoted to 1 day.
# Examples
```julia-repl
julia> promote_period(Hour(24))
1 day
julia> promote_period(Day(14))
2 weeks
julia> promote_period(Day(365))
1 year
julia> promote_period(Millisecond(1))
1 millisecond
```
"""
function promote_period(period::Dates.Period)
ms = Int(Dates.toms(period))
# Hard to do this with varying periods like Month/Year...
PeriodTypes = [
Dates.Week,
Dates.Day,
Dates.Hour,
Dates.Minute,
Dates.Second,
Dates.Millisecond,
]
for PeriodType in PeriodTypes
period_ms = Int(Dates.toms(PeriodType(1)))
if ms % period_ms == 0
return PeriodType(ms // period_ms)
end
end
end
8 changes: 8 additions & 0 deletions test/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,11 @@ end
@test test_cent_space == cent_space
@test test_face_space == face_space
end

@testset "promote_period" begin
@test CA.promote_period(Dates.Hour(24)) == Dates.Day(1)
@test CA.promote_period(Dates.Day(14)) == Dates.Week(2)
@test CA.promote_period(Dates.Millisecond(1)) == Dates.Millisecond(1)
@test CA.promote_period(Dates.Minute(120)) == Dates.Hour(2)
@test CA.promote_period(Dates.Second(3600)) == Dates.Hour(1)
end

0 comments on commit ea8042d

Please sign in to comment.