From 2e08bf208c0091b912486e4070be715d16094b57 Mon Sep 17 00:00:00 2001
From: Jett
Date: Wed, 15 May 2024 13:46:09 +0200
Subject: [PATCH] handle CUBLAS_WORKSPACE_CONFIG env var

---
Review notes (ignored by `git am`):
- ":16:8" no longer triggers the reproducibility warning; the cuBLAS
  reproducibility docs list it alongside ":4096:8".
- The post-image blob id on the `index` line below is from the original
  revision of this patch and was not recomputed.

 src/delphi/train/training.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py
index a20132c3..179f656c 100644
--- a/src/delphi/train/training.py
+++ b/src/delphi/train/training.py
@@ -40,7 +40,32 @@ def setup_training(config: TrainingConfig):
     init_wandb(config=config)
 
 
+def check_set_env_cublas_workspace_config():
+    """Make sure CUBLAS_WORKSPACE_CONFIG holds a reproducible cuBLAS setting.
+
+    If the variable is unset, it is set to ":4096:8". If it is set to a value
+    that is not a documented reproducible setting, a warning is logged and the
+    value is left untouched.
+    """
+    # https://docs.nvidia.com/cuda/archive/12.4.0/cublas/index.html#results-reproducibility
+    # The cuBLAS docs list both ":4096:8" and ":16:8" as reproducible settings.
+    expected_val = ":4096:8"
+    reproducible_vals = (expected_val, ":16:8")
+    actual_val = os.getenv("CUBLAS_WORKSPACE_CONFIG")
+    if actual_val is None:
+        logging.info(
+            f"Environment variable CUBLAS_WORKSPACE_CONFIG not set. Setting to '{expected_val}' to ensure reproducibility."
+        )
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = expected_val
+    elif actual_val not in reproducible_vals:
+        logging.warning(
+            f"Environment variable CUBLAS_WORKSPACE_CONFIG is set to {actual_val}, which may affect reproducibility. "
+            f"We recommend setting it to '{expected_val}' to ensure reproducibility."
+        )
+
+
 def run_training(config: TrainingConfig) -> tuple[ModelTrainingState, RunContext]:
+    if torch.cuda.is_available():
+        check_set_env_cublas_workspace_config()
     setup_training(config)
     logging.info("Starting training...")
     logging.info("Config:")