Skip to content

Commit

Permalink
Add best/worst single models in leaderboard
Browse files Browse the repository at this point in the history
  • Loading branch information
Sbozzolo committed Jun 17, 2024
1 parent 3342dc3 commit 16c84d7
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 58 deletions.
144 changes: 89 additions & 55 deletions experiments/ClimaEarth/user_io/leaderboard/cmip_rmse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,67 +35,101 @@ function best_single_model(RMSEs)
end

"""
RSME_stats(RMSEs)
worst_single_model(RMSEs)
Return the one model that has the overall largest error.
"""
function worst_single_model(RMSEs)
    # Score every entry by the element-wise absolute value of its `values`.
    # The scores are compared as whole collections (lexicographically), so the
    # first value dominates and the following ones only break ties.
    abs_errors(model) = abs.(values(model))
    worst_index = last(findmax(abs_errors, RMSEs))
    return RMSEs[worst_index]
end

"""
RMSE_stats(RMSEs)
RMSEs is the dictionary OTHER_MODELS_RMSEs.
Return:
- best single model
- worst single model
- "model" with all the medians
- "model" with all the best values
- "model" with all the worst values
"""
function RSME_stats(vecRMSEs)
    # Summarize a collection of per-model `RMSEs` into synthetic "models"
    # (median, worst, best, quartiles) plus the best single real model.

    # One column per model, one row per value; the five rows are unpacked in
    # the order ANN, DJF, JJA, MAM, SON.
    per_model = stack(values.(vecRMSEs))
    ANN, DJF, JJA, MAM, SON = ntuple(row -> per_model[row, :], 5)

    # Build an `RMSEs` whose entries are `reduce_fn` applied to each row.
    summarize(model_name, reduce_fn) = RMSEs(;
        model_name,
        ANN = reduce_fn(ANN),
        DJF = reduce_fn(DJF),
        JJA = reduce_fn(JJA),
        MAM = reduce_fn(MAM),
        SON = reduce_fn(SON),
    )

    median_model = summarize("Median", median)
    worst_model = summarize("Worst", v -> maximum(abs.(v)))
    best_model = summarize("Best", v -> minimum(abs.(v)))
    quantile25 = summarize("Quantile 0.25", v -> quantile(v, 0.25))
    quantile75 = summarize("Quantile 0.75", v -> quantile(v, 0.75))

    return (;
        best_single_model = best_single_model(vecRMSEs),
        median_model,
        worst_model,
        best_model,
        quantile25,
        quantile75,
    )
end
"""
    RMSE_stats(dict_vecRMSEs)

Compute summary statistics for every entry of `dict_vecRMSEs`, a dictionary that
maps a key to a collection of `RMSEs` (one per model).

Return a named tuple with:
- `stats`: dictionary mapping each key of `dict_vecRMSEs` to a named tuple with
  `best_single_model`, `worst_single_model`, `median_model`, `worst_model`,
  `best_model`, `quantile25`, and `quantile75`;
- `absolute_best_model`: name of the model with the smallest cumulative error;
- `absolute_worst_model`: name of the model with the largest cumulative error.

The cumulative error of a model is the sum, over all keys and all values, of the
model's RMSE divided by the corresponding median RMSE.
"""
function RMSE_stats(dict_vecRMSEs)
    stats = Dict()
    # cumulative_error maps model_names with the total RMSE across metrics
    # normalized by median(RMSE)
    cumulative_error = Dict()
    for (key, vecRMSEs) in dict_vecRMSEs
        # Collect into vectors that we can process independently
        all_values = stack(values.(vecRMSEs))
        ANN, DJF, JJA, MAM, SON = ntuple(i -> all_values[i, :], 5)

        median_model = RMSEs(;
            model_name = "Median",
            ANN = median(ANN),
            DJF = median(DJF),
            JJA = median(JJA),
            MAM = median(MAM),
            SON = median(SON),
        )

        worst_model = RMSEs(;
            model_name = "Worst",
            ANN = maximum(abs.(ANN)),
            DJF = maximum(abs.(DJF)),
            JJA = maximum(abs.(JJA)),
            MAM = maximum(abs.(MAM)),
            SON = maximum(abs.(SON)),
        )

        best_model = RMSEs(;
            model_name = "Best",
            ANN = minimum(abs.(ANN)),
            DJF = minimum(abs.(DJF)),
            JJA = minimum(abs.(JJA)),
            MAM = minimum(abs.(MAM)),
            SON = minimum(abs.(SON)),
        )

        quantile25 = RMSEs(;
            model_name = "Quantile 0.25",
            ANN = quantile(ANN, 0.25),
            DJF = quantile(DJF, 0.25),
            JJA = quantile(JJA, 0.25),
            MAM = quantile(MAM, 0.25),
            SON = quantile(SON, 0.25),
        )

        quantile75 = RMSEs(;
            model_name = "Quantile 0.75",
            ANN = quantile(ANN, 0.75),
            DJF = quantile(DJF, 0.75),
            JJA = quantile(JJA, 0.75),
            MAM = quantile(MAM, 0.75),
            SON = quantile(SON, 0.75),
        )

        for rmse in vecRMSEs
            # Accumulate on top of what previous keys contributed. The lookup
            # must use the model name: an always-missing key would reset the
            # total to zero on every pass and keep only the last key's error.
            normalized_error = sum(values(rmse) ./ values(median_model))
            cumulative_error[rmse.model_name] =
                get(cumulative_error, rmse.model_name, 0.0) + normalized_error
        end

        stats[key] = (;
            best_single_model = best_single_model(vecRMSEs),
            worst_single_model = worst_single_model(vecRMSEs),
            median_model,
            worst_model,
            best_model,
            quantile25,
            quantile75,
        )
    end

    # On a dictionary, findmin/findmax return (value, key): the second element
    # is the model name.
    _, absolute_best_model = findmin(cumulative_error)
    _, absolute_worst_model = findmax(cumulative_error)

    return (; stats, absolute_best_model, absolute_worst_model)
end

# Statistics over the comparison models, computed once when this file is loaded.
# NOTE(review): this needs OTHER_MODELS_RMSEs to be already populated when this
# line runs — confirm the include order.
const COMPARISON_RMSEs_STATS = RMSE_stats(OTHER_MODELS_RMSEs)
14 changes: 11 additions & 3 deletions experiments/ClimaEarth/user_io/leaderboard/compare_with_obs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import CairoMakie
const OBS_DS = Dict()
const SIM_DS_KWARGS = Dict()
const OTHER_MODELS_RMSEs = Dict()
const COMPARISON_RMSEs = Dict()

function preprocess_pr_fn(data)
# -1 kg/m/s2 -> 1 mm/day
Expand Down Expand Up @@ -167,6 +166,8 @@ function plot_leaderboard(rmses; output_path)
# models compared, and there is one row per variable
squares = zeros(NUM_BOXES * NUM_MODELS, num_variables)

(; absolute_best_model, absolute_worst_model) = COMPARISON_RMSEs_STATS

for (var_num, rmse) in enumerate(rmses)
short_name = rmse.ANN.attributes["var_short_name"]
units = rmse.ANN.attributes["units"]
Expand All @@ -178,7 +179,11 @@ function plot_leaderboard(rmses; output_path)
)

# Against other models
(; best_single_model, median_model, worst_model, best_model) = COMPARISON_RMSEs[short_name]

(; median_model) = COMPARISON_RMSEs_STATS.stats[short_name]

best_single_model = first(filter(x -> x.model_name == absolute_best_model, OTHER_MODELS_RMSEs[short_name]))
worst_single_model = first(filter(x -> x.model_name == absolute_worst_model, OTHER_MODELS_RMSEs[short_name]))

squares[begin:NUM_BOXES, end - var_num + 1] .= values(rmse) ./ values(median_model)
squares[(NUM_BOXES + 1):end, end - var_num + 1] .= values(best_single_model) ./ values(median_model)
Expand All @@ -190,7 +195,7 @@ function plot_leaderboard(rmses; output_path)
label = median_model.model_name,
color = :black,
marker = :hline,
markersize = 15,
markersize = 10,
)

categories = vcat(map(_ -> collect(1:5), 1:length(OTHER_MODELS_RMSEs[short_name]))...)
Expand All @@ -206,6 +211,9 @@ function plot_leaderboard(rmses; output_path)
whiskerlinewidth = 1,
)

CairoMakie.scatter!(ax, 1:5, values(best_single_model), label = absolute_best_model)
CairoMakie.scatter!(ax, 1:5, values(worst_single_model), label = absolute_worst_model)

# If we want to plot other models
# for model in OTHER_MODELS_RMSEs[short_name]
# CairoMakie.scatter!(ax, 1:5, values(model), marker = :hline)
Expand Down

0 comments on commit 16c84d7

Please sign in to comment.