Skip to content

Commit

Permalink
warning -> assert
Browse files (browse the repository at this point in the history)
  • Loading branch information
jettjaniak committed May 20, 2024
1 parent 3ba9362 commit 0cd5e29
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/delphi/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ def check_set_env_cublas_workspace_config():
expected_val = ":4096:8"
actual_val = os.getenv("CUBLAS_WORKSPACE_CONFIG")
if actual_val is None:
# https://docs.nvidia.com/cuda/archive/12.4.0/cublas/index.html#results-reproducibility
logging.info(
f"Environment variable CUBLAS_WORKSPACE_CONFIG not set. Setting to '{expected_val}' to ensure reproducibility."
)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = expected_val
elif actual_val != expected_val:
logging.warning(
f"Environment variable CUBLAS_WORKSPACE_CONFIG is set to {actual_val}, which may affect reproducibility. "
f"We recommend setting it to '{expected_val}' to ensure reproducibility."
)
correct_values = [expected_val, ":16:8"]
assert actual_val in correct_values, (
f"Environment variable CUBLAS_WORKSPACE_CONFIG is set to {actual_val}, which is incompatibe with reproducible training. "
f"Please set it to one of the following values: {correct_values}. "
f"See https://docs.nvidia.com/cuda/archive/12.4.0/cublas/index.html#results-reproducibility for more information."
)


def run_training(config: TrainingConfig) -> tuple[ModelTrainingState, RunContext]:
Expand Down

0 comments on commit 0cd5e29

Please sign in to comment.