revamp RunContext and add gpu_name (#143)
jettjaniak authored May 20, 2024
1 parent 4cef656 commit 38e65a5
Showing 3 changed files with 33 additions and 40 deletions.
39 changes: 30 additions & 9 deletions src/delphi/train/run_context.py
@@ -1,14 +1,35 @@
# get contextual information about a training run

from dataclasses import dataclass
import os

import torch
import transformers

import delphi


def get_auto_device_str() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


@dataclass
class RunContext:
    device: torch.device
    torch_version: str
    delphi_version: str
    transformers_version: str
    os: str
    def __init__(self, device_str: str):
        if device_str == "auto":
            device_str = get_auto_device_str()
        self.device = torch.device(device_str)
        if self.device.type == "cuda":
            assert torch.cuda.is_available()
            self.gpu_name = torch.cuda.get_device_name(self.device)
        elif self.device.type == "mps":
            assert torch.backends.mps.is_available()
        self.torch_version = torch.__version__
        self.delphi_version = delphi.__version__
        self.transformers_version = transformers.__version__
        self.os = os.uname().version

    def asdict(self) -> dict:
        asdict = self.__dict__.copy()
        asdict["device"] = str(self.device)
        return asdict
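
For reference, a minimal usage sketch of the revamped RunContext (illustrative only, not part of the commit; the "auto" argument and the run_context.json filename mirror how training.py and save_results use the class below):

# Hypothetical example: build a RunContext with automatic device selection
# and serialize it the same way save_results now does.
import json

from delphi.train.run_context import RunContext

run_context = RunContext("auto")  # resolves to "cuda", "mps", or "cpu"
if run_context.device.type == "cuda":
    print(run_context.gpu_name)  # gpu_name is only set on CUDA devices
with open("run_context.json", "w") as file:
    json.dump(run_context.asdict(), file, indent=2)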
15 changes: 2 additions & 13 deletions src/delphi/train/training.py
@@ -5,17 +5,13 @@

import torch
from tqdm import tqdm
from transformers import __version__ as transformers_version

from delphi import __version__ as delphi_version

from .checkpoint_step import log_and_save_checkpoint, should_save_checkpoint
from .config import TrainingConfig
from .run_context import RunContext
from .train_step import train_step
from .utils import (
    ModelTrainingState,
    get_device,
    get_indices_for_epoch,
    initialize_model_training_state,
    set_lr,
@@ -46,15 +42,8 @@ def run_training(config: TrainingConfig) -> tuple[ModelTrainingState, RunContext
logging.info("Config:")
for field in fields(config):
logging.info(f" {field.name}: {getattr(config, field.name)}")
# system
run_context = RunContext(
device=get_device(config.device),
torch_version=torch.__version__,
delphi_version=delphi_version,
transformers_version=transformers_version,
os=os.uname().version,
)
logging.debug(f"Run context: {run_context}")
run_context = RunContext(config.device)
logging.debug(f"Run context: {run_context.asdict()}")

# load data
logging.info("Loading data...")
19 changes: 1 addition & 18 deletions src/delphi/train/utils.py
@@ -48,21 +48,6 @@ def setup_determinism(seed: int):
    torch.manual_seed(seed)


def get_device(device_str: str = "auto") -> torch.device:
    """
    Get torch device specified by device_str. May pass "auto" to set torch device automatically.
    """
    # cuda if available; else mps if apple silicon; else cpu
    if device_str == "auto":
        if torch.cuda.is_available():
            device_str = "cuda"
        elif torch.backends.mps.is_available():
            device_str = "mps"
        else:
            device_str = "cpu"
    return torch.device(device_str)


def get_lr(
    iter_num: int,
    warmup_iters: int,
@@ -272,9 +257,7 @@ def save_results(
        }
        json.dump(training_state_dict, file, indent=2)
    with open(os.path.join(results_path, "run_context.json"), "w") as file:
        run_context_dict = asdict(run_context)
        run_context_dict["device"] = str(run_context.device)
        json.dump(run_context_dict, file, indent=2)
        json.dump(run_context.asdict(), file, indent=2)
    if config.out_repo_id:
        api = HfApi()
        api.create_repo(config.out_repo_id, exist_ok=True)
