Commit e41567f
Added a fallback to change torch_dtype if CUDA isn't available.
config.py:
- Added TORCH_DTYPE_SAFETY.

model.py:
- Updated _load_model to force torch_dtype to float32 when CUDA isn't
  available (if config.TORCH_DTYPE_SAFETY is True); otherwise generation
  fails with an error. See #31.
Lyaaaaaaaaaaaaaaa committed May 8, 2024
1 parent 92e52a9 commit e41567f
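A minimal sketch of the behavior this commit adds, using plain torch and
transformers. The model id "sshleifer/tiny-gpt2" and the local variable
names are illustrative stand-ins, not the project's code:

import torch
from transformers import AutoModelForCausalLM

# What config.TORCH_DTYPE might request, e.g. half precision.
args = {"torch_dtype": torch.float16}

# The new safeguard: without CUDA, fall back to float32 so generation works.
if not torch.cuda.is_available():
    args["torch_dtype"] = torch.float32

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2", **args)
print(model.dtype)  # torch.float32 on a CPU-only machine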
Showing 2 changed files with 16 additions and 0 deletions.
5 changes: 5 additions & 0 deletions server/config.py
@@ -47,6 +47,9 @@
#--
#-- 23/01/2024 Lyaaaaa
#-- - Removed TOKENIZERS_PATH.
#--
#-- 08/05/2024 Lyaaaaa
#-- - Added TORCH_DTYPE_SAFETY.
#---------------------------------------------------------------------------

import logging
@@ -79,3 +82,5 @@
DEVICE_MAP = None # None/see documentation
TORCH_DTYPE = None # "Auto"/None/torch.dtype/See torch_dtype.py for more info.

# Safeguards
TORCH_DTYPE_SAFETY = True # True/False. If CUDA isn't available, enforce torch_dtype = float32 to avoid errors. See https://github.com/LyaaaaaGames/AIdventure_Server/issues/31
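To make "avoid errors" concrete: on CPU-only machines some half-precision
operations are unimplemented and raise at generation time. The exact op and
message depend on the PyTorch version; the snippet below is an illustration,
not a reproduction from issue #31:

import torch

x = torch.randn(4, 4, dtype=torch.float16)
try:
    # Historically raised: "addmm_impl_cpu_" not implemented for 'Half'.
    torch.addmm(x, x, x)
    print("This PyTorch build supports half-precision addmm on CPU.")
except RuntimeError as err:
    print(f"CPU half precision failed: {err}")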
11 changes: 11 additions & 0 deletions server/model.py
@@ -202,6 +202,12 @@
#-- - p_model_path is now the second parameter of __init__. p_parameters the third.
#-- - Added a log message to display the model's name and its path.
#-- - Added a log message to display if cuda is supported.
#--
#-- 08/05/2024 Lyaaaaa
#-- - Updated _load_model to force torch_dtype to float32 when CUDA isn't
#--   available (if config.TORCH_DTYPE_SAFETY is True); otherwise generation
#--   fails with an error.
#--   See https://github.com/LyaaaaaGames/AIdventure_Server/issues/31
#------------------------------------------------------------------------------

from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
@@ -320,6 +326,11 @@ def _load_model(self):
            logger.log.debug("Model settings:")
            logger.log.debug(args)

            if not self.is_cuda_available and config.TORCH_DTYPE_SAFETY:
                # Fallback: without CUDA, force float32, since other dtypes
                # can raise errors during generation (see issue #31).
                logger.log.warn("CUDA isn't available.")
                logger.log.warn("Setting torch_dtype to float32 to avoid errors.")
                args["torch_dtype"] = Torch_Dtypes.dtypes.value[Torch_Dtypes.FLOAT_32.value]

            self._Model = AutoModelForCausalLM.from_pretrained(self._model_path,
                                                               **args)
        except Exception as e:
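The double .value lookup above reads naturally if Torch_Dtypes is an Enum
whose dtypes member holds a list indexed by the other members' values. The
project's torch_dtype.py isn't part of this diff, so the reconstruction
below is an assumption that merely makes the lookup self-consistent:

from enum import Enum
import torch

class Torch_Dtypes(Enum):
    # Hypothetical reconstruction of torch_dtype.py (not shown in this diff).
    FLOAT_32  = 0
    FLOAT_16  = 1
    BFLOAT_16 = 2
    dtypes    = [torch.float32, torch.float16, torch.bfloat16]

# Torch_Dtypes.dtypes.value is the list; FLOAT_32.value indexes into it.
assert Torch_Dtypes.dtypes.value[Torch_Dtypes.FLOAT_32.value] is torch.float32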
