Skip to content

Commit

Permalink
Made training function more robust to NaNs
Browse files Browse the repository at this point in the history
  • Loading branch information
msainsburydale committed Oct 31, 2024
1 parent c4eab60 commit f788541
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -859,15 +859,24 @@ saves the weights of the best network (measured by validation loss) as
'best_network.bson'.
"""
function _savebestmodel(path::String)
    # Load the risk as a function of epoch. The second column is the one used
    # to select the best network (the validation risk).
    loss_per_epoch = load(
        joinpath(path, "loss_per_epoch.bson"), @__MODULE__
    )[:loss_per_epoch]

    # Replace NaN with Inf so that diverged epochs can never be selected as the
    # minimiser (argmin treats NaN as smaller than every other value). Only the
    # validation column is needed, so build a sanitised copy of that column
    # rather than overwriting the loaded matrix; oftype keeps the element type.
    validation_risk = map(r -> isnan(r) ? oftype(r, Inf) : r, loss_per_epoch[:, 2])

    # Find the epoch in which the validation risk was minimised. The first row
    # is the risk evaluated for the initial neural network, that is, the
    # network at epoch 0, so subtract 1 from the (1-based) row index.
    best_epoch = argmin(validation_risk) - 1

    # Save the best network
    load_path = joinpath(path, "network_epoch$(best_epoch).bson")
    save_path = joinpath(path, "best_network.bson")
    cp(load_path, save_path; force=true)

    return nothing
end

Expand Down

0 comments on commit f788541

Please sign in to comment.