diff --git a/src/train.jl b/src/train.jl index 21ddf70..7a3c1fb 100644 --- a/src/train.jl +++ b/src/train.jl @@ -859,15 +859,24 @@ saves the weights of the best network (measured by validation loss) as 'best_network.bson'. """ function _savebestmodel(path::String) + # Load the risk as a function of epoch loss_per_epoch = load(joinpath(path, "loss_per_epoch.bson"), @__MODULE__)[:loss_per_epoch] - # The first row is the risk evaluated for the initial neural network, that - # is, the network at epoch 0. Since Julia starts indexing from 1, we - # subtract 1 from argmin(). - best_epoch = argmin(loss_per_epoch[:, 2]) -1 + # Replace NaN with Inf so they won't interfere with finding the minimum risk + loss_per_epoch .= ifelse.(isnan.(loss_per_epoch), Inf, loss_per_epoch) + + # Find the epoch in which the validation risk was minimised + best_epoch = argmin(loss_per_epoch[:, 2]) + + # Subtract 1 since the first row is the risk evaluated for the initial neural network, that + # is, the network at epoch 0 + best_epoch -= 1 + + # Save the best network load_path = joinpath(path, "network_epoch$(best_epoch).bson") save_path = joinpath(path, "best_network.bson") cp(load_path, save_path, force = true) + return nothing end