Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

half and mixed precision inference #442

Merged
merged 9 commits into from
Jan 16, 2025

Conversation

ArneBinder
Copy link
Owner

@ArneBinder ArneBinder commented Jan 13, 2025

This PR implements half and mixed precision inference by adding the following boolean parameters to the Pipeline:

  • half_precision_model: Whether or not to use half precision model. If set to True, the model will be cast to half precision on supported devices (torch.float16 on cuda and torch.bfloat16 on cpu for now, following torch.get_autocast_dtype()). This can reduce the memory consumption and improve the inference speed, but may lead to numerical instability.
  • half_precision_ops: Whether or not to use half precision operations. If set to True, the model will be run with half precision operations via torch.autocast.

Since torch.float16 can not be natively converted to numpy, this PR adds float() casting to several model outputs (probabilities) before calling .numpy(). in unbatch_output() of several taskmodules. Also, this fixes some spelling mistakes in the Pipeline documentation.

TODO:

  • test in downstream setup, e.g. inference on drugprot.
    • Outcome: effective for larger batch sizes (>1)
batch_size half_precision_model half_precision_ops inference_time speed_up
32 True True 32.8 2.4
32 True False 26.4 3.0
32 False True 28.8 2.7
32 False False 79.0 1.0
256 True True 23.6 3.5
256 True False 21.2 3.9
256 False True 23.7 3.5
256 False False 83.6 1.0
512 True True 24.4 3.7
512 True False 22.7 4.0
512 False True 25.0 3.6
512 False False 90.7 1.0
1024 True True 26.9 3.7
1024 True False 23.7 4.1
1024 False True 26.7 3.7
1024 False False 98.3 1.0

@ArneBinder ArneBinder force-pushed the pipeline/half_and_mixed_precision branch from c9ad301 to 7e11234 Compare January 13, 2025 21:25
@ArneBinder ArneBinder merged commit 9ab36d8 into main Jan 16, 2025
6 checks passed
@ArneBinder ArneBinder deleted the pipeline/half_and_mixed_precision branch January 16, 2025 15:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant