From 52779cf0acbd10e1c1805ab32d6e0cd43b27fc6c Mon Sep 17 00:00:00 2001 From: Jett Date: Sun, 19 May 2024 16:25:51 +0200 Subject: [PATCH] revamp RunContext and add gpu_name --- src/delphi/train/run_context.py | 39 +++++++++++++++++++++++++-------- src/delphi/train/training.py | 15 ++----------- src/delphi/train/utils.py | 19 +--------------- 3 files changed, 33 insertions(+), 40 deletions(-) diff --git a/src/delphi/train/run_context.py b/src/delphi/train/run_context.py index ced25908..548a3360 100644 --- a/src/delphi/train/run_context.py +++ b/src/delphi/train/run_context.py @@ -1,14 +1,35 @@ -# get contextual information about a training run - -from dataclasses import dataclass +import os import torch +import transformers + +import delphi + + +def get_auto_device_str() -> str: + if torch.cuda.is_available(): + return "cuda" + if torch.backends.mps.is_available(): + return "mps" + return "cpu" -@dataclass class RunContext: - device: torch.device - torch_version: str - delphi_version: str - transformers_version: str - os: str + def __init__(self, device_str: str): + if device_str == "auto": + device_str = get_auto_device_str() + self.device = torch.device(device_str) + if self.device.type == "cuda": + assert torch.cuda.is_available() + self.gpu_name = torch.cuda.get_device_name(self.device) + elif self.device.type == "mps": + assert torch.backends.mps.is_available() + self.torch_version = torch.__version__ + self.delphi_version = delphi.__version__ + self.transformers_version = transformers.__version__ + self.os = os.uname().version + + def asdict(self) -> dict: + asdict = self.__dict__.copy() + asdict["device"] = str(self.device) + return asdict diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py index a20132c3..85fda188 100644 --- a/src/delphi/train/training.py +++ b/src/delphi/train/training.py @@ -5,9 +5,6 @@ import torch from tqdm import tqdm -from transformers import __version__ as transformers_version - -from delphi import __version__ as delphi_version from .checkpoint_step import log_and_save_checkpoint, should_save_checkpoint from .config import TrainingConfig @@ -15,7 +12,6 @@ from .train_step import train_step from .utils import ( ModelTrainingState, - get_device, get_indices_for_epoch, initialize_model_training_state, set_lr, @@ -46,15 +42,8 @@ def run_training(config: TrainingConfig) -> tuple[ModelTrainingState, RunContext logging.info("Config:") for field in fields(config): logging.info(f" {field.name}: {getattr(config, field.name)}") - # system - run_context = RunContext( - device=get_device(config.device), - torch_version=torch.__version__, - delphi_version=delphi_version, - transformers_version=transformers_version, - os=os.uname().version, - ) - logging.debug(f"Run context: {run_context}") + run_context = RunContext(config.device) + logging.debug(f"Run context: {run_context.asdict()}") # load data logging.info("Loading data...") diff --git a/src/delphi/train/utils.py b/src/delphi/train/utils.py index 93b467ef..3bcdb958 100644 --- a/src/delphi/train/utils.py +++ b/src/delphi/train/utils.py @@ -48,21 +48,6 @@ def setup_determinism(seed: int): torch.manual_seed(seed) -def get_device(device_str: str = "auto") -> torch.device: - """ - Get torch device specified by device_str. May pass "auto" to set torch device automatically. - """ - # cuda if available; else mps if apple silicon; else cpu - if device_str == "auto": - if torch.cuda.is_available(): - device_str = "cuda" - elif torch.backends.mps.is_available(): - device_str = "mps" - else: - device_str = "cpu" - return torch.device(device_str) - - def get_lr( iter_num: int, warmup_iters: int, @@ -272,9 +257,7 @@ def save_results( } json.dump(training_state_dict, file, indent=2) with open(os.path.join(results_path, "run_context.json"), "w") as file: - run_context_dict = asdict(run_context) - run_context_dict["device"] = str(run_context.device) - json.dump(run_context_dict, file, indent=2) + json.dump(run_context.asdict(), file, indent=2) if config.out_repo_id: api = HfApi() api.create_repo(config.out_repo_id, exist_ok=True)