Skip to content

Commit

Permalink
validate_configs: overrides, init model
Browse files Browse the repository at this point in the history
  • Loading branch information
jettjaniak committed May 25, 2024
1 parent 4de7ca2 commit abd579b
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 13 deletions.
8 changes: 8 additions & 0 deletions delphi/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from torch.optim import AdamW
from transformers import PreTrainedModel

from delphi.train.config import dot_notation_to_dict

from .config import TrainingConfig
from .run_context import RunContext

Expand Down Expand Up @@ -253,3 +255,9 @@ def init_model(model_config_dict: dict[str, Any], seed: int) -> PreTrainedModel:
model_params_dict = model_config_dict.copy()
model_params_dict.pop("model_class")
return model_class(config_class(**(model_params_dict)))


def overrides_to_dict(overrides: list[str]) -> dict[str, Any]:
# ["a.b.c=4", "foo=false"] to {"a": {"b": {"c": 4}}, "foo": False}
config_vars = {k: v for k, v in [x.split("=") for x in overrides if "=" in x]}
return dot_notation_to_dict(config_vars)
14 changes: 2 additions & 12 deletions scripts/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,10 @@
import logging
import sys
from pathlib import Path
from typing import Any

from delphi.train.config import (
build_config_from_files_and_overrides,
dot_notation_to_dict,
)
from delphi.train.config import build_config_from_files_and_overrides
from delphi.train.training import run_training
from delphi.train.utils import save_results
from delphi.train.utils import overrides_to_dict, save_results


def add_logging_args(parser: argparse.ArgumentParser):
Expand Down Expand Up @@ -77,12 +73,6 @@ def setup_parser() -> argparse.ArgumentParser:
return parser


def overrides_to_dict(overrides: list[str]) -> dict[str, Any]:
# ["a.b.c=4", "foo=false"] to {"a": {"b": {"c": 4}}, "foo": False}
config_vars = {k: v for k, v in [x.split("=") for x in overrides if "=" in x]}
return dot_notation_to_dict(config_vars)


def main():
parser = setup_parser()
args = parser.parse_args()
Expand Down
24 changes: 23 additions & 1 deletion scripts/validate_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pathlib

from delphi.train.config import build_config_from_files_and_overrides
from delphi.train.utils import init_model, overrides_to_dict


def get_config_path_with_base(config_path: pathlib.Path) -> list[pathlib.Path]:
Expand Down Expand Up @@ -33,15 +34,32 @@ def main():
type=str,
help="path to a training config json or directory of training config jsons",
)
parser.add_argument(
"--overrides",
help=(
"Override config values with space-separated declarations. "
"e.g. `--overrides model_config.hidden_size=42 run_name=foo`"
),
type=str,
required=False,
nargs="*",
default=[],
)
parser.add_argument("--init", help="initialize the model", action="store_true")
args = parser.parse_args()
config_paths = get_config_paths(args.config_path)
print(
f"validating configs: {' | '.join(str(config_path[-1]) for config_path in config_paths)}"
)
overrides = overrides_to_dict(args.overrides)
errors = []
sizes = []
for config_path in config_paths:
try:
build_config_from_files_and_overrides(config_path, {})
config = build_config_from_files_and_overrides(config_path, overrides)
if args.init:
model = init_model(config.model_config, seed=config.torch_seed)
sizes.append((config_path, model.num_parameters()))
except Exception as e:
errors.append((config_path, e))
continue
Expand All @@ -51,6 +69,10 @@ def main():
print(f" {config_path[-1]}: {e}")
else:
print("all configs loaded successfully")
if sizes:
print("model sizes:")
for config_path, size in sizes:
print(f" {config_path[-1]}: {size}")


if __name__ == "__main__":
Expand Down

0 comments on commit abd579b

Please sign in to comment.