Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyTorch model extractor #298

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open

PyTorch model extractor #298

wants to merge 6 commits into from

Conversation

mastoffel
Copy link
Collaborator

@mastoffel mastoffel commented Feb 12, 2025

  • adds a function to extract PyTorch models from AutoEmulate emulators

  • emulators can be pipelines, MultiOutputRegressors etc., this function checks all the options and extracts the underlying PyTorch model where possible and throws an error for other models

  • it also gives a message saying that datapreprocessing is better turned off when doing this and has to be done manually (as it can't be attached to the PyTorch model like it can be to a sci-kit learn model using a pipeline

  • it returns the model in eval mode

  • does not yet include other objects as discussed in Add "fit only pytorch models" flag #291 . Maybe we leave that to the next PR?

Copy link
Contributor

Coverage report

Click to see where and how coverage changed

FileStatementsMissingCoverageCoverage
(new stmts)
Lines missing
  autoemulate
  utils.py 426, 435, 449
  autoemulate/emulators
  conditional_neural_process.py
  gaussian_process.py
  gaussian_process_mt.py
  tests
  test_pytorch_utils.py
Project Total  

This report was generated by python-coverage-comment-action

@codecov-commenter
Copy link

Codecov Report

Attention: Patch coverage is 96.47059% with 3 lines in your changes missing coverage. Please review.

Project coverage is 94.22%. Comparing base (7a4dc72) to head (40dbed2).
Report is 5 commits behind head on main.

Files with missing lines Patch % Lines
autoemulate/utils.py 85.71% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #298      +/-   ##
==========================================
+ Coverage   94.20%   94.22%   +0.02%     
==========================================
  Files          62       63       +1     
  Lines        3606     3691      +85     
==========================================
+ Hits         3397     3478      +81     
- Misses        209      213       +4     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants