
Adds warning message when deterministic training loss stagnates too quickly in partial BNNs #24

Merged: 6 commits, Nov 19, 2024

Conversation

@sarah-allec sarah-allec (Contributor) commented Nov 5, 2024

Context

Certain deterministic NN hyperparameters may cause overfitting that manifests as the training loss decreasing very rapidly at first and then stagnating early. A warning message to the user with suggested remedies would be helpful. Closes #11

Description

After the first epoch, the change in training loss is monitored; any time the loss drops by more than 25% between epochs (see the figure below for justification of the 25% threshold), a warning message is shown to the user:

UserWarning: The deterministic training loss is decreasing rapidly - learning and accuracy may be improved by increasing the batch size, adjusting MAP sigma, or modifying the learning rate.

[Figure: dnn_loss — training-loss curves justifying the 25% threshold]

Changes in the codebase

  1. Added a function called monitor_dnn_loss in neurobayes/utils/utils.py that prints a warning whenever the loss has decreased by more than 25% at any epoch.
  2. Added a call to monitor_dnn_loss in the training loop (DeterministicNN.train()) in flax_nets/deterministic_nn.py.
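The changes above can be sketched as follows. This is a hypothetical reconstruction based only on the PR description, not the actual code in neurobayes/utils/utils.py; the real signature and threshold handling may differ. The function and threshold names follow the text, and the warning string is the one quoted in the PR:

```python
import warnings

import numpy as np


def monitor_dnn_loss(loss: np.ndarray, threshold: float = 0.25) -> None:
    """Warn if the latest epoch's training loss dropped by more than
    `threshold` (as a fraction) relative to the previous epoch.

    Intentionally returns None; emitting the warning is its only effect.
    """
    if len(loss) < 2:
        return None  # not enough history yet to compute a change
    change = np.diff(loss)[-1]  # loss[-1] - loss[-2]
    if change < 0 and abs(change) / loss[-2] > threshold:
        warnings.warn(
            "The deterministic training loss is decreasing rapidly - "
            "learning and accuracy may be improved by increasing the "
            "batch size, adjusting MAP sigma, or modifying the learning rate.",
            UserWarning,
        )
    return None
```

In the training loop, this would be called once per epoch with the accumulated loss history, e.g. `monitor_dnn_loss(np.array(loss_history))`.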

@ziatdinovmax ziatdinovmax (Owner) left a comment

Looks good overall. A few suggestions:

  • Please clarify the monitor_dnn_loss function's return behavior: either add an explicit return value or document that it intentionally returns None.
  • Please include a length check (if len(loss) > 2:) to avoid a potential IndexError with np.diff for the edge case where loss has fewer than two elements.
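The second point can be illustrated with a minimal sketch of the edge case (the array contents here are made up for illustration; the exact guard used in the merged code may differ):

```python
import numpy as np

# Why a length guard is needed: np.diff on a one-element loss history
# returns an empty array, so indexing its last element would raise an
# IndexError during the very first epoch.
loss = np.array([0.83])  # only one epoch recorded so far
assert np.diff(loss).size == 0  # np.diff(loss)[-1] would raise IndexError

if len(loss) >= 2:  # minimal guard in the spirit of the review comment
    latest_change = np.diff(loss)[-1]  # safe: at least one difference exists
```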

@sarah-allec sarah-allec (Contributor, Author)

I implemented the suggestions - thank you! If everything looks good, I will submit the PR.

@ziatdinovmax ziatdinovmax marked this pull request as ready for review November 19, 2024 17:00
@ziatdinovmax ziatdinovmax (Owner) left a comment

Looks good! There is a seemingly unrelated issue with python-3.10 tests failing, which I will need to figure out later, but this one is ready to be merged.

@ziatdinovmax ziatdinovmax merged commit e735fea into ziatdinovmax:main Nov 19, 2024
2 of 3 checks passed