Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

load config files in order, later overrides earlier, lots more testing #108

Merged
merged 6 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .vscode/launch.json
jaidhyani marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"request": "launch",
"program": "scripts/run_training.py",
"console": "integratedTerminal",
"args": "--debug --loglevel 20"
"args": "--config_files src/delphi/static/configs/debug.json"
},
{
"name": "run_training custom",
Expand All @@ -36,7 +36,7 @@
"request": "launch",
"program": "scripts/run_training.py",
"console": "integratedTerminal",
"args": "--debug ${command:pickArgs}"
"args": "--config_files src/delphi/static/configs/debug.json ${command:pickArgs}"
}
]
}
22 changes: 7 additions & 15 deletions scripts/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from pathlib import Path
from typing import Any

from delphi.train.config import build_config_from_files_and_overrides
from delphi.train.config import (
build_config_from_files_and_overrides,
dot_notation_to_dict,
)
from delphi.train.training import run_training
from delphi.train.utils import save_results

Expand Down Expand Up @@ -52,6 +55,8 @@ def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Train a delphi model")
parser.add_argument(
"--config_files",
"--config_file",
"-c",
Comment on lines 57 to +59
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"--config_files",
"--config_file",
"-c",
"--config",
"-c",

I don't care about backward compatibility, no one is using this yet

help=(
"Path to json file(s) containing config values. Specific values can be overridden with --overrides. "
"e.g. `--config_files primary_config.json secondary_config.json"
Expand All @@ -78,20 +83,7 @@ def setup_parser() -> argparse.ArgumentParser:
def overrides_to_dict(overrides: list[str]) -> dict[str, Any]:
    """Parse `--overrides` CLI tokens into a nested config dict.

    e.g. ["a.b.c=4", "foo=false"] -> {"a": {"b": {"c": "4"}}, "foo": "false"}

    Tokens without "=" are silently ignored. Values are kept as strings here;
    type casting happens later against the TrainingConfig dataclass.
    """
    # Split on the FIRST "=" only, so values may themselves contain "="
    # (e.g. "run_name=lr=0.1-test"). A bare split("=") would raise
    # ValueError on such tokens when unpacking.
    config_vars = dict(x.split("=", 1) for x in overrides if "=" in x)
    return dot_notation_to_dict(config_vars)


def main():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
{
"priority": -1,
"vocab_size": 4096,
"max_seq_len": 512,
"max_epochs": 10,
Expand Down
4 changes: 3 additions & 1 deletion src/delphi/constants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from importlib.resources import files
from pathlib import Path
from typing import cast

STATIC_ASSETS_DIR = files("delphi.static")
CONFIG_PRESETS_DIR = STATIC_ASSETS_DIR / "configs"
CONFIG_PRESETS_DIR = cast(Path, STATIC_ASSETS_DIR / "configs")
jaidhyani marked this conversation as resolved.
Show resolved Hide resolved

CORPUS_DATASET = "delphi-suite/stories"
TINYSTORIES_TOKENIZED_HF_DATASET = "delphi-suite/v0-tinystories-v2-clean-tokenized"
1 change: 0 additions & 1 deletion src/delphi/static/configs/debug.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
{
"priority": -1,
"vocab_size": 4096,
"max_seq_len": 512,
"max_epochs": 2,
Expand Down
1 change: 0 additions & 1 deletion src/delphi/static/configs/debug_mamba.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
{
"priority": -1,
"vocab_size": 4096,
"max_seq_len": 512,
"max_epochs": 2,
Expand Down
1 change: 0 additions & 1 deletion src/delphi/static/configs/debug_transformers_bloom.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
{
"priority": -1,
"vocab_size": 4096,
"max_seq_len": 512,
"max_epochs": 2,
Expand Down
2 changes: 1 addition & 1 deletion src/delphi/train/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
build_config_dict_from_files,
build_config_from_files,
build_config_from_files_and_overrides,
get_config_dicts_from_files,
dot_notation_to_dict,
get_preset_paths,
get_presets_by_name,
get_user_config_path,
Expand Down
106 changes: 77 additions & 29 deletions src/delphi/train/config/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import ast
import json
import logging
import os
from dataclasses import fields, is_dataclass
from datetime import datetime
from pathlib import Path
from typing import Type
from typing import _GenericAlias # type: ignore
from typing import Any, Type, TypeVar, Union

import platformdirs
from beartype.typing import Any, Iterable
jaidhyani marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -14,22 +16,34 @@

from .training_config import TrainingConfig

T = TypeVar("T")

def merge_two_dicts(merge_into: dict[str, Any], merge_from: dict[str, Any]):
    """recursively merge two dicts, with values in merge_from taking precedence

    Mutates `merge_into` in place. When both sides hold a dict under the same
    key, the two sub-dicts are merged key-by-key; otherwise the value from
    `merge_from` simply replaces whatever was there.
    """
    for key, val in merge_from.items():
        existing = merge_into.get(key)
        if isinstance(existing, dict) and isinstance(val, dict):
            merge_two_dicts(existing, val)
        else:
            merge_into[key] = val


def merge_dicts(*dicts: dict[str, Any]) -> dict[str, Any]:
    """
    Recursively merge any number of dicts into a fresh dict.

    Later arguments take precedence on conflicting keys; nested dicts are
    merged key-by-key rather than replaced wholesale. Inputs are not mutated.
    """
    accumulated: dict[str, Any] = {}
    for layer in dicts:
        merge_two_dicts(accumulated, layer)
    return accumulated


def get_preset_paths() -> Iterable[Path]:
return Path(CONFIG_PRESETS_DIR).glob("*.json") # type: ignore
return CONFIG_PRESETS_DIR.glob("*.json")


def get_user_config_path() -> Path:
Expand All @@ -45,28 +59,13 @@ def get_presets_by_name() -> dict[str, TrainingConfig]:
}


def get_config_dicts_from_files(config_files: list[Path]) -> list[dict[str, Any]]:
"""loads config files in ascending priority order"""
def build_config_dict_from_files(config_files: list[Path]) -> dict[str, Any]:
config_dicts = []
for config_file in config_files:
logging.info(f"Loading {config_file}")
logging.debug(f"Loading {config_file}")
with open(config_file, "r") as f:
config_dicts.append(json.load(f))
return config_dicts


def combine_configs(configs: list[dict[str, Any]]) -> dict[str, Any]:
# combine configs dicts, with key "priority" setting precendence (higher priority overrides lower priority)
sorted_configs = sorted(configs, key=lambda c: c.get("priority", -999))
combined_config = dict()
for config in sorted_configs:
_merge_dicts(merge_into=combined_config, merge_from=config)
return combined_config


def build_config_dict_from_files(config_files: list[Path]) -> dict[str, Any]:
configs_in_order = get_config_dicts_from_files(config_files)
combined_config = combine_configs(configs_in_order)
combined_config = merge_dicts(*config_dicts)
return combined_config


Expand Down Expand Up @@ -104,9 +103,7 @@ def set_backup_vals(config: dict[str, Any], config_files: list[Path]):
logging.info(f"Setting output_dir to {config['output_dir']}")


def log_config_recursively(
config: dict, logging_fn=logging.info, indent=" ", prefix=""
):
def log_config_recursively(config: dict, logging_fn, indent=" ", prefix=""):
for k, v in config.items():
if isinstance(v, dict):
logging_fn(f"{prefix}{k}")
Expand All @@ -115,17 +112,37 @@ def log_config_recursively(
logging_fn(f"{prefix}{k}: {v}")


def cast_types(config: dict[str, Any], target_dataclass: Type):
    """
    Cast values in `config` to the types declared on `target_dataclass`.

    User overrides are passed in as strings, so we need to cast them to the
    correct type. Mutates `config` in place; keys with no matching dataclass
    field are left untouched.
    """
    # NOTE(review): assumes field annotations are real type objects, not
    # strings (i.e. the dataclass module does not use
    # `from __future__ import annotations`) — confirm.
    dc_fields = {f.name: f for f in fields(target_dataclass)}
    for k, v in config.items():
        if k not in dc_fields:
            continue
        field_type = _unoptionalize(dc_fields[k].type)
        if is_dataclass(field_type):
            # nested dataclass: recurse into the sub-dict
            cast_types(v, field_type)
        elif field_type is dict or getattr(field_type, "__origin__", None) is dict:
            # BUG FIX: was `isinstance(field_type, dict)`, which is always
            # False for a type annotation (a type object is never a dict
            # instance), so dict-valued overrides were never converted.
            # Match the annotation itself: bare `dict` or parameterized
            # `dict[...]` (whose __origin__ is dict).
            # For dictionaries, make best effort to cast values to the
            # correct type via literal evaluation.
            for _k, _v in v.items():
                v[_k] = ast.literal_eval(_v) if isinstance(_v, str) else _v
        else:
            config[k] = field_type(v)


def build_config_from_files_and_overrides(
config_files: list[Path],
overrides: dict[str, Any],
) -> TrainingConfig:
combined_config = build_config_dict_from_files(config_files)
_merge_dicts(merge_into=combined_config, merge_from=overrides)
cast_types(overrides, TrainingConfig)
merge_two_dicts(merge_into=combined_config, merge_from=overrides)
set_backup_vals(combined_config, config_files)
filter_config_to_actual_config_values(TrainingConfig, combined_config)
logging.info("User-set config values:")
logging.debug("User-set config values:")
log_config_recursively(
combined_config, logging_fn=logging.info, prefix=" ", indent=" "
combined_config, logging_fn=logging.debug, prefix=" ", indent=" "
)
return from_dict(TrainingConfig, combined_config)

Expand All @@ -135,5 +152,36 @@ def build_config_from_files(config_files: list[Path]) -> TrainingConfig:


def load_preset(preset_name: str) -> TrainingConfig:
preset_path = Path(CONFIG_PRESETS_DIR) / f"{preset_name}.json" # type: ignore
preset_path = CONFIG_PRESETS_DIR / f"{preset_name}.json"
return build_config_from_files([preset_path])


def dot_notation_to_dict(vars: dict[str, Any]) -> dict[str, Any]:
    """
    Convert {"a.b.c": 4, "foo": false} to {"a": {"b": {"c": 4}}, "foo": False}

    Entries whose value is None are dropped (treated as "not set").
    """
    result: dict[str, Any] = {}
    for dotted_key, value in vars.items():
        if value is None:
            continue
        *parents, leaf = dotted_key.split(".")
        node = result
        for part in parents:
            # descend, creating intermediate dicts as needed
            node = node.setdefault(part, {})
        node[leaf] = value
    return result


def _unoptionalize(t: Type | _GenericAlias) -> Type:
"""unwrap `Optional[T]` to T"""
# Under the hood, `Optional` is really `Union[T, None]`. So we
# just check if this is a Union over two types including None, and
# return the other
if hasattr(t, "__origin__") and t.__origin__ is Union:
args = t.__args__
# Check if one of the Union arguments is type None
if len(args) == 2 and type(None) in args:
return args[0] if args[1] is type(None) else args[1]
return t
50 changes: 50 additions & 0 deletions tests/train/config/test_config_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Optional

import pytest

from delphi.constants import CONFIG_PRESETS_DIR
from delphi.train.config.utils import (
_unoptionalize,
build_config_from_files_and_overrides,
dot_notation_to_dict,
merge_dicts,
merge_two_dicts,
)


def test_merge_two_dicts():
    # in-place merge: scalar conflict takes the incoming value,
    # nested dicts merge key-by-key
    target = {"a": 1, "b": 2, "c": {"d": 3, "e": 4}}
    incoming = {"a": 5, "c": {"d": 6}}
    merge_two_dicts(target, incoming)
    assert target == {"a": 5, "b": 2, "c": {"d": 6, "e": 4}}


def test_merge_dicts():
    # later dicts win everywhere, including inside nested dicts
    lowest = {"a": 1, "b": 2, "c": {"d": 3, "e": 4}}
    middle = {"a": 5, "c": {"d": 6}}
    highest = {"a": 7, "b": 8, "c": {"d": 9, "e": 10}}
    expected = {"a": 7, "b": 8, "c": {"d": 9, "e": 10}}
    assert merge_dicts(lowest, middle, highest) == expected


def test_dot_notation_to_dict():
    # dotted keys expand into nested dicts; plain keys stay flat
    flat = {"a.b.c": 4, "foo": False}
    assert dot_notation_to_dict(flat) == {"a": {"b": {"c": 4}}, "foo": False}


def test_build_config_from_files_and_overrides():
    preset = CONFIG_PRESETS_DIR / "debug.json"
    overrides = {"model_config": {"hidden_size": 128}, "eval_iters": 5}
    config = build_config_from_files_and_overrides([preset], overrides)
    # overridden values took effect
    assert config.model_config["hidden_size"] == 128
    assert config.eval_iters == 5
    # untouched values still come from the preset file
    assert config.max_epochs == 2
    assert config.data_config.train_sample_limit == 256


def test_unoptionalize():
    # non-Optional types pass through untouched; Optional[T] unwraps to T
    assert _unoptionalize(int) is int
    assert _unoptionalize(Optional[str]) is str
Loading