Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add check for accumulated diagnostics and restarts #3396

Merged
merged 1 commit into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/src/restarts.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,19 @@ is started.
If is also possible to manually specify a restart file. In this case, this will
override any file automatically detected.

### Accumulated Diagnostics

At the moment, `ClimaAtmos` does not support working with accumulated
diagnostics across restarts. The present limitations are best illustrated with
an example.

Suppose you are saving 30-day averages and stop the simulation at day 45. If you
do so, you'll find output for day 30 and the checkpoint at day 45. Then, if you
restart the simulation, you'll see that the next diagnostic output will be at
day 75, and not day 60. In other words, the counter starts from 0 with every
restart.

!!! note

If you care about accurate accumulated diagnostics, make sure to line up your
checkpoint and diagnostic frequencies.
89 changes: 54 additions & 35 deletions src/callbacks/get_callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,21 @@ function get_diagnostics(parsed_args, atmos_model, Y, p, dt, t_start)
"$(period_str) has to be of the form <NUM>months, e.g. 2months for 2 months",
)
period_dates = Dates.Month(parse(Int, first(months)))
output_schedule = CAD.EveryCalendarDtSchedule(
period_dates;
reference_date = p.start_date,
)
compute_schedule = CAD.EveryCalendarDtSchedule(
period_dates;
reference_date = p.start_date,
)
else
period_seconds = FT(time_to_seconds(period_str))
output_schedule = CAD.EveryDtSchedule(period_seconds)
compute_schedule = CAD.EveryDtSchedule(period_seconds)
period_dates =
CA.promote_period.(Dates.Second(period_seconds))
end

output_schedule = CAD.EveryCalendarDtSchedule(
period_dates;
reference_date = p.start_date,
)
compute_schedule = CAD.EveryCalendarDtSchedule(
period_dates;
reference_date = p.start_date,
)

if isnothing(output_name)
output_short_name = CAD.descriptive_short_name(
CAD.get_diagnostic_variable(short_name),
Expand Down Expand Up @@ -159,6 +160,25 @@ function get_diagnostics(parsed_args, atmos_model, Y, p, dt, t_start)
end
diagnostics = collect(diagnostics)

periods_reductions = Set()
for diag in diagnostics
isa_reduction = !isnothing(diag.reduction_time_func)
isa_reduction || continue

if diag.output_schedule_func isa CAD.EveryDtSchedule
period = Dates.Second(diag.output_schedule_func.dt)
elseif diag.output_schedule_func isa CAD.EveryCalendarDtSchedule
period = diag.output_schedule_func.dt
else
continue
end

push!(periods_reductions, period)
end

periods_str = join(CA.promote_period.(periods_reductions), ", ")
@info "Saving accumulated diagnostics to disk with frequency: $(periods_str)"

for writer in writers
writer_str = nameof(typeof(writer))
diags_with_writer =
Expand All @@ -169,9 +189,28 @@ function get_diagnostics(parsed_args, atmos_model, Y, p, dt, t_start)
@info "$writer_str: $diags_outputs"
end

return diagnostics, writers
return diagnostics, writers, periods_reductions
end

function checkpoint_frequency_from_parsed_args(dt_save_state_to_disk::String)
if occursin("months", dt_save_state_to_disk)
months = match(r"^(\d+)months$", dt_save_state_to_disk)
isnothing(months) && error(
"$(period_str) has to be of the form <NUM>months, e.g. 2months for 2 months",
)
return Dates.Month(parse(Int, first(months)))
else
dt_save_state_to_disk = time_to_seconds(dt_save_state_to_disk)
if !(dt_save_state_to_disk == Inf)
# We use Millisecond to support fractional seconds, eg. 0.1
return Dates.Millisecond(1000dt_save_state_to_disk)
else
return Inf
end
end
end


function get_callbacks(config, sim_info, atmos, params, Y, p, t_start)
(; parsed_args, comms_ctx) = config
FT = eltype(params)
Expand Down Expand Up @@ -211,13 +250,10 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start)

# Save dt_save_state_to_disk as a Dates.Period object. This is used to check
# if it is an integer multiple of other frequencies.
dt_save_state_to_disk_dates = Dates.today() # Value will be overwritten
if occursin("months", parsed_args["dt_save_state_to_disk"])
months = match(r"^(\d+)months$", parsed_args["dt_save_state_to_disk"])
isnothing(months) && error(
"$(period_str) has to be of the form <NUM>months, e.g. 2months for 2 months",
)
dt_save_state_to_disk_dates = Dates.Month(parse(Int, first(months)))
dt_save_state_to_disk_dates = checkpoint_frequency_from_parsed_args(
parsed_args["dt_save_state_to_disk"],
)
if dt_save_state_to_disk_dates != Inf
schedule = CAD.EveryCalendarDtSchedule(
dt_save_state_to_disk_dates;
reference_date = p.start_date,
Expand All @@ -230,23 +266,6 @@ function get_callbacks(config, sim_info, atmos, params, Y, p, t_start)
(integrator) -> save_state_to_disk_func(integrator, output_dir)
end
callbacks = (callbacks..., SciMLBase.DiscreteCallback(cond, affect!))
else
dt_save_state_to_disk =
time_to_seconds(parsed_args["dt_save_state_to_disk"])
if !(dt_save_state_to_disk == Inf)
# We use Millisecond to support fractional seconds, eg. 0.1
dt_save_state_to_disk_dates =
Dates.Millisecond(dt_save_state_to_disk)
callbacks = (
callbacks...,
call_every_dt(
(integrator) ->
save_state_to_disk_func(integrator, output_dir),
dt_save_state_to_disk;
skip_first = sim_info.restart,
),
)
end
end

if is_distributed(comms_ctx)
Expand Down
31 changes: 23 additions & 8 deletions src/solver/type_getters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -757,16 +757,31 @@ function get_simulation(config::AtmosConfig)
# Initialize diagnostics
if config.parsed_args["enable_diagnostics"]
s = @timed_str begin
scheduled_diagnostics, writers = get_diagnostics(
config.parsed_args,
atmos,
Y,
p,
sim_info.dt,
t_start,
)
scheduled_diagnostics, writers, periods_reductions =
get_diagnostics(
config.parsed_args,
atmos,
Y,
p,
sim_info.dt,
t_start,
)
end
@info "initializing diagnostics: $s"

# Check for consistency between diagnostics and checkpoints
checkpoint_frequency = checkpoint_frequency_from_parsed_args(
config.parsed_args["dt_save_state_to_disk"],
)

if checkpoint_frequency != Inf
if any(
x -> !CA.isdivisible(checkpoint_frequency, x),
periods_reductions,
)
@warn "Some accumulated diagnostics might not be evenly divisible by the checkpointing frequency ($(CA.promote_period(checkpoint_frequency)))"
end
end
else
writers = nothing
end
Expand Down
49 changes: 49 additions & 0 deletions src/utils/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -454,3 +454,52 @@ function isdivisible(
# have any common divisor)
return isinteger(Dates.Day(1) / dt_small)
end

"""
promote_period(period::Dates.Period)
Promote a period to the largest possible period type.
This function attempts to represent a given `Period` using the largest possible
unit of time. For example, a period of 24 hours will be promoted to 1 day. If
a clean promotion is not possible, return the input as it is.
# Examples
```julia-repl
julia> promote_period(Hour(24))
1 day
julia> promote_period(Day(14))
2 weeks
julia> promote_period(Second(86401))
86401 seconds
julia> promote_period(Millisecond(1))
1 millisecond
```
"""
function promote_period(period::Dates.Period)
charleskawczynski marked this conversation as resolved.
Show resolved Hide resolved
ms = Int(Dates.toms(period))
# Hard to do this with varying periods like Month/Year...
PeriodTypes = [
Dates.Week,
Dates.Day,
Dates.Hour,
Dates.Minute,
Dates.Second,
Dates.Millisecond,
]
for PeriodType in PeriodTypes
charleskawczynski marked this conversation as resolved.
Show resolved Hide resolved
period_ms = Int(Dates.toms(PeriodType(1)))
if ms % period_ms == 0
# Millisecond will always match, if nothing else matches
return PeriodType(ms // period_ms)
end
end
end

function promote_period(period::Dates.OtherPeriod)
# For varying periods, we just return them as they are
return period
end
8 changes: 8 additions & 0 deletions test/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,11 @@ end
@test test_cent_space == cent_space
@test test_face_space == face_space
end

@testset "promote_period" begin
@test CA.promote_period(Dates.Hour(24)) == Dates.Day(1)
@test CA.promote_period(Dates.Day(14)) == Dates.Week(2)
@test CA.promote_period(Dates.Millisecond(1)) == Dates.Millisecond(1)
@test CA.promote_period(Dates.Minute(120)) == Dates.Hour(2)
@test CA.promote_period(Dates.Second(3600)) == Dates.Hour(1)
end
Loading