
Added a fallback to change torch_dtype if cuda isn't available. #32

Merged 1 commit on May 10, 2024
5 changes: 5 additions & 0 deletions server/config.py
@@ -47,6 +47,9 @@
#--
#-- 23/01/2024 Lyaaaaa
#-- - Removed TOKENIZERS_PATH.
#--
#-- 08/05/2024 Lyaaaaa
#-- - Added TORCH_DTYPE_SAFETY.
#---------------------------------------------------------------------------

import logging
@@ -79,3 +82,5 @@
DEVICE_MAP = None # None/see documentation
TORCH_DTYPE = None # "Auto"/None/torch.dtype/See torch_dtype.py for more info.

# Safeguards
TORCH_DTYPE_SAFETY = True # True/False. If CUDA isn't available, forces torch_dtype to float32 to avoid errors. See issue https://github.com/LyaaaaaGames/AIdventure_Server/issues/31
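The intent of this flag can be sketched as a small dtype-resolution helper. The function name and string dtype values below are illustrative, not part of the repository:

```python
def resolve_torch_dtype(requested: str, cuda_available: bool, safety: bool = True) -> str:
    """Pick the dtype actually passed to the model loader.

    Mirrors the TORCH_DTYPE_SAFETY behavior: half precision (float16/bfloat16)
    is not reliably supported on CPU, so without CUDA we fall back to float32.
    """
    if safety and not cuda_available:
        return "float32"
    return requested

# A float16 request on a CPU-only machine is downgraded to float32.
print(resolve_torch_dtype("float16", cuda_available=False))  # float32
```

With `safety=False` the requested dtype passes through unchanged, which restores the pre-PR behavior.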
11 changes: 11 additions & 0 deletions server/model.py
@@ -202,6 +202,12 @@
#-- - p_model_path is now the second parameter of __init__. p_parameters the third.
#-- - Added a log message to display the model's name and its path.
#-- - Added a log message to display if cuda is supported.
#--
#-- 08/05/2024 Lyaaaaa
#-- - Updated _load_model to force torch_dtype to float32 when CUDA isn't
#-- available and config.TORCH_DTYPE_SAFETY is True. Otherwise generation
#-- fails with an error.
#-- See https://github.com/LyaaaaaGames/AIdventure_Server/issues/31
#------------------------------------------------------------------------------

from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
@@ -320,6 +326,11 @@ def _load_model(self):
logger.log.debug("Model settings:")
logger.log.debug(args)

if not self.is_cuda_available and config.TORCH_DTYPE_SAFETY:
logger.log.warn("CUDA isn't available.")
logger.log.warn("Setting torch_dtype to float32 to avoid errors.")
args["torch_dtype"] = Torch_Dtypes.dtypes.value[Torch_Dtypes.FLOAT_32.value]

self._Model = AutoModelForCausalLM.from_pretrained(self._model_path,
**args)
except Exception as e:
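The branch added to `_load_model` can be exercised in isolation. The standalone helper below is a hypothetical sketch of that guard: `args` mirrors the kwargs dict built in `_load_model`, and the string dtype stands in for the `Torch_Dtypes` enum value:

```python
def apply_dtype_safety(args: dict, cuda_available: bool, safety_enabled: bool) -> dict:
    """Return a copy of the model kwargs with torch_dtype forced to float32
    when CUDA is unavailable and the safeguard is enabled.

    Sketch of the new branch in _load_model: loading a half-precision model
    on CPU leads to errors during generation (see issue #31), so the dtype
    is overridden before from_pretrained is called.
    """
    patched = dict(args)
    if not cuda_available and safety_enabled:
        patched["torch_dtype"] = "float32"  # stand-in for torch.float32
    return patched
```

Because the helper copies `args`, the original settings dict stays untouched, which keeps the override easy to log and to test.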