Commit

cosmetic changes in configs & args

jettjaniak committed May 24, 2024
1 parent e8545af commit f154b81
Showing 13 changed files with 80 additions and 130 deletions.
3 changes: 2 additions & 1 deletion delphi/test_configs/debug.json
@@ -17,5 +17,6 @@
"dataset": {
"name": "delphi-suite/v0-tinystories-v2-clean-tokenized"
},
"out_repo_id": ""
"out_repo": "",
"wandb": ""
}
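For context, not part of the diff: under the new schema both fields are plain strings, where an empty string disables the feature. A quick sketch in Python of a config dict using hypothetical non-empty values (other required TrainingConfig fields omitted):

import json

debug_config = {
    "dataset": {"name": "delphi-suite/v0-tinystories-v2-clean-tokenized"},
    "wandb": "my-entity/my-project",  # "" disables wandb logging
    "out_repo": "my-org/my-model",    # "" disables pushing checkpoints to the HF Hub
}
print(json.dumps(debug_config, indent=2))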
1 change: 0 additions & 1 deletion delphi/train/config/__init__.py
@@ -6,4 +6,3 @@
dot_notation_to_dict,
get_user_config_path,
)
from .wandb_config import WandbConfig
13 changes: 8 additions & 5 deletions delphi/train/config/training_config.py
@@ -9,7 +9,6 @@
from .adam_config import AdamConfig
from .dataset_config import DatasetConfig
from .debug_config import DebugConfig
from .wandb_config import WandbConfig


@beartype
@@ -23,7 +22,7 @@ class TrainingConfig:
max_seq_len: int = field(metadata={"help": "max sequence length"})
# meta
run_name: str = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
output_dir: str = field(
out_dir: str = field(
default=os.path.join(platformdirs.user_data_dir(appname="delphi"), run_name),
metadata={"help": "output directory"},
)
@@ -97,9 +96,13 @@ class TrainingConfig:
)

# third party
wandb: Optional[WandbConfig] = None
out_repo_id: str = field(
metadata={"help": "set to empty string to not push to repo"},
wandb: str = field(
metadata={
"help": "wandb config in 'entity/project' form. Set to empty string to not use wandb."
},
)
out_repo: str = field(
metadata={"help": "HF repo id. Set to empty string to not push to repo."},
)

# debug
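Aside, not from this commit: the help strings attached via field(metadata=...) above can be read back at runtime with dataclasses.fields, which is one way such metadata can be surfaced as CLI help. A minimal, self-contained sketch:

from dataclasses import dataclass, field, fields

# Hypothetical stand-in class; only demonstrates the metadata={"help": ...} pattern.
@dataclass
class Example:
    out_repo: str = field(
        default="",
        metadata={"help": "HF repo id. Set to empty string to not push to repo."},
    )

for f in fields(Example):
    print(f"{f.name}: {f.metadata.get('help', '')}")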
26 changes: 1 addition & 25 deletions delphi/train/config/utils.py
@@ -65,28 +65,6 @@ def build_config_dict_from_files(config_files: list[Path]) -> dict[str, Any]:
return combined_config


def set_backup_vals(config: dict[str, Any], config_files: list[Path]):
"""
Convenience default values for run_name and output_dir based on config file (if exactly one passed)
If the user is using 1 config file and has not set a run_name, we set it to the filename.
Likewise for output_dir, we set it to a user-specific directory based on the run_name.
"""
if len(config_files) == 1:
prefix = f"{config_files[0].stem}__"
else:
prefix = ""
if "run_name" not in config:
run_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
config["run_name"] = f"{prefix}{run_time}"
logging.info(f"Setting run_name to {config['run_name']}")
if "output_dir" not in config:
config["output_dir"] = os.path.join(
platformdirs.user_data_dir(appname="delphi"), config["run_name"]
)
logging.info(f"Setting output_dir to {config['output_dir']}")


def cast_types(config: dict[str, Any], target_dataclass: Type):
"""
user overrides are passed in as strings, so we need to cast them to the correct type
@@ -118,13 +96,11 @@ def build_config_from_files_and_overrides(
(we expect this to be passed as strings w/o type hints from a script argument:
e.g. `--overrides model_config.hidden_size=42 run_name=foo`)
3. Merge in overrides to config_dict, taking precedence over all config_files values.
4. Set backup values (for run_name and output_dir) if they are not already set.
5. Build the TrainingConfig object from the final config dict and return it.
4. Build the TrainingConfig object from the final config dict and return it.
"""
combined_config = build_config_dict_from_files(config_files)
cast_types(overrides, TrainingConfig)
merge_two_dicts(merge_into=combined_config, merge_from=overrides)
set_backup_vals(combined_config, config_files)
return from_dict(TrainingConfig, combined_config, config=dacite_config(strict=True))


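For illustration, not part of the diff: the overrides mentioned in step 2 arrive as strings like model_config.hidden_size=42 and are expanded into a nested dict before merging. The real helper is dot_notation_to_dict (imported in config/__init__.py); a rough sketch of the idea:

# Sketch only: assumes '=' separates key and value and '.' nests keys.
def overrides_to_dict(overrides: list[str]) -> dict:
    result: dict = {}
    for item in overrides:
        dotted_key, value = item.split("=", 1)
        node = result
        *parents, leaf = dotted_key.split(".")
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value  # still a string; cast_types() converts it later
    return result

print(overrides_to_dict(["model_config.hidden_size=42", "run_name=foo"]))
# {'model_config': {'hidden_size': '42'}, 'run_name': 'foo'}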
11 changes: 0 additions & 11 deletions delphi/train/config/wandb_config.py

This file was deleted.

6 changes: 3 additions & 3 deletions delphi/train/training.py
@@ -25,7 +25,7 @@

def setup_training(config: TrainingConfig):
logging.info("Setting up training...")
os.makedirs(config.output_dir, exist_ok=True)
os.makedirs(config.out_dir, exist_ok=True)

# torch misc - TODO: check if this is actually needed
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
@@ -36,11 +36,11 @@ def setup_training(config: TrainingConfig):

# wandb setup
if config.wandb:
init_wandb(config=config)
init_wandb(config)

if config.tokenizer:
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer)
tokenizer.save_pretrained(Path(config.output_dir) / "tokenizer")
tokenizer.save_pretrained(Path(config.out_dir) / "tokenizer")


def run_training(config: TrainingConfig) -> tuple[ModelTrainingState, RunContext]:
14 changes: 7 additions & 7 deletions delphi/train/utils.py
@@ -198,8 +198,8 @@ def save_results(
config, context (e.g. hardware), training step, etc
"""
iter_name = "main" if final else f"iter{train_results.iter_num}"
output_dir = Path(config.output_dir)
results_path = output_dir / iter_name
out_dir = Path(config.out_dir)
results_path = out_dir / iter_name
logging.info(f"saving checkpoint to {results_path}")
results_path.mkdir(parents=True, exist_ok=True)
with open(results_path / "training_config.json", "w") as file:
@@ -220,19 +220,19 @@
json.dump(training_state_dict, file, indent=2)
with open(results_path / "run_context.json", "w") as file:
json.dump(run_context.asdict(), file, indent=2)
if (tokenizer_dir := output_dir / "tokenizer").exists():
if (tokenizer_dir := out_dir / "tokenizer").exists():
for src_file in tokenizer_dir.iterdir():
if src_file.is_file():
dest_file = results_path / src_file.name
shutil.copy2(src_file, dest_file)
if config.out_repo_id:
if config.out_repo:
try:
api = HfApi()
api.create_repo(config.out_repo_id, exist_ok=True)
api.create_branch(config.out_repo_id, branch=iter_name, exist_ok=True)
api.create_repo(config.out_repo, exist_ok=True)
api.create_branch(config.out_repo, branch=iter_name, exist_ok=True)
api.upload_folder(
folder_path=results_path,
repo_id=config.out_repo_id,
repo_id=config.out_repo,
revision=iter_name,
)
except Exception as e:
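Side note, not part of the commit: since save_results pushes each checkpoint to a branch named after the iteration ("main" for the final one), a given checkpoint can be pulled back by revision. The repo id and model class below are assumptions:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "my-org/my-delphi-model",  # whatever out_repo was set to (hypothetical)
    revision="iter5000",       # branch created by save_results; "main" for the final checkpoint
)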
16 changes: 4 additions & 12 deletions delphi/train/wandb_utils.py
@@ -1,5 +1,4 @@
import logging
import os
from dataclasses import asdict

import wandb
@@ -8,19 +7,12 @@
from .utils import ModelTrainingState


def silence_wandb():
logging.info("silencing wandb output")
os.environ["WANDB_SILENT"] = "true"


def init_wandb(config: TrainingConfig):
# if log level < debug, silence wandb
assert config.wandb is not None
if logging.getLogger().level > logging.INFO or config.wandb.silence:
silence_wandb()
assert "/" in config.wandb, "wandb should be in the 'entity/project' form"
wandb_entity, wandb_project = config.wandb.split("/")
wandb.init(
entity=config.wandb.entity,
project=config.wandb.project,
entity=wandb_entity,
project=wandb_project,
name=config.run_name,
config=asdict(config),
)
16 changes: 8 additions & 8 deletions delphi/utils.py
@@ -11,12 +11,12 @@ def hf_split_to_split_name(split: str) -> str:


def load_dataset_split_features(
repo_id: str,
path: str,
split: str,
features: Features,
) -> Dataset:
dataset = load_dataset(
repo_id,
path,
split=split,
features=features,
)
@@ -25,28 +25,28 @@


def load_dataset_split_string_feature(
repo_id: str,
path: str,
split: str,
feature_name: str,
) -> Dataset:
print("Loading string dataset")
print(f"{repo_id=}, {split=}, {feature_name=}")
print(f"{path=}, {split=}, {feature_name=}")
return load_dataset_split_features(
repo_id,
path,
split,
Features({feature_name: Value("string")}),
)


def load_dataset_split_sequence_int32_feature(
repo_id: str,
path: str,
split: str,
feature_name: str,
) -> Dataset:
print("Loading sequence int32 dataset")
print(f"{repo_id=}, {split=}, {feature_name=}")
print(f"{path=}, {split=}, {feature_name=}")
return load_dataset_split_features(
repo_id,
path,
split,
Features({feature_name: Sequence(Value("int32"))}),
)
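For context, not from the diff: the rename from repo_id to path mirrors datasets.load_dataset, whose first argument is also called path and accepts either a Hub repo id or a local dataset path. A hypothetical call (dataset name and feature name are assumptions):

from delphi.utils import load_dataset_split_string_feature

ds = load_dataset_split_string_feature(
    "delphi-suite/stories",  # HF repo id, or a local path
    "train[:1%]",
    "story",
)
print(len(ds))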
14 changes: 6 additions & 8 deletions scripts/run_training.py
@@ -51,23 +51,21 @@ def set_logging(args: argparse.Namespace):

def setup_parser() -> argparse.ArgumentParser:
# Setup argparse
parser = argparse.ArgumentParser(description="Train a delphi model")
parser = argparse.ArgumentParser(
description="Train a delphi model", allow_abbrev=False
)
parser.add_argument(
"--config_files",
"--config_file",
"-c",
"config_files",
help=(
"Path to json file(s) containing config values. Specific values can be overridden with --overrides. "
"e.g. `--config_files primary_config.json secondary_config.json"
"Path to json file(s) containing config values, e.g. 'primary_config.json secondary_config.json'."
),
type=str,
required=False,
nargs="*",
)
parser.add_argument(
"--overrides",
help=(
"Override config values with comma-separated declarations. "
"Override config values with space-separated declarations. "
"e.g. `--overrides model_config.hidden_size=42 run_name=foo`"
),
type=str,
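Sketch, not from the diff: how the reshaped parser might be exercised, assuming config files are now passed positionally (as the bare "config_files" line above suggests) and that --overrides takes space-separated key=value strings:

import argparse

parser = argparse.ArgumentParser(description="Train a delphi model", allow_abbrev=False)
parser.add_argument("config_files", type=str, nargs="*", help="json config file(s)")
parser.add_argument("--overrides", type=str, nargs="*", default=[])

args = parser.parse_args(["base.json", "--overrides", "run_name=foo"])
print(args.config_files)  # ['base.json']
print(args.overrides)     # ['run_name=foo']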
55 changes: 28 additions & 27 deletions scripts/tokenize_dataset.py
@@ -12,21 +12,24 @@
from delphi.tokenization import get_tokenized_chunks

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="", allow_abbrev=False)
parser = argparse.ArgumentParser(
description="Tokenize a text dataset using a specific tokenizer",
allow_abbrev=False,
)

parser.add_argument(
"--in-repo-id",
"--in-dataset",
"-i",
type=str,
required=True,
help="Text dataset from huggingface to tokenize",
help="Dataset you want to tokenize. Local path or HF repo id",
)
parser.add_argument(
"--feature",
"-f",
type=str,
required=True,
help="Name of the column containing text documents in the input dataset",
help="Name of the feature (column) containing text documents in the input dataset",
)
parser.add_argument(
"--split",
@@ -35,18 +38,6 @@
required=True,
help="Split of the dataset to be tokenized, supports slicing like 'train[:10%%]'",
)
parser.add_argument(
"--out-dir",
type=str,
required=False,
help="Local directory to save the resulting dataset",
)
parser.add_argument(
"--out-repo-id",
type=str,
required=False,
help="HF repo id to upload the resulting dataset",
)
parser.add_argument(
"--tokenizer",
"-t",
@@ -59,29 +50,39 @@
"-l",
type=int,
required=True,
help="Context size of the tokenized dataset as input of the model",
help="Length of the tokenized sequences",
)
parser.add_argument(
"--batch-size",
"-b",
type=int,
default=50,
help="Size of input into batched tokenization",
help="How many text documents to tokenize at once (default: 50)",
)
parser.add_argument(
"--chunk-size",
"-c",
type=int,
default=200_000,
help="Size of the parquet chunks uploaded to HuggingFace",
help="Maximum number of tokenized sequences in a single parquet file (default: 200_000)",
)
parser.add_argument(
"--out-dir",
type=str,
required=False,
help="Local directory to save the resulting dataset",
)
parser.add_argument(
"--out-repo",
type=str,
required=False,
help="HF repo id to upload the resulting dataset",
)
args = parser.parse_args()
assert (
args.out_repo_id or args.out_dir
), "You need to provide --out-repo-id or --out-dir"
assert args.out_repo or args.out_dir, "You need to provide --out-repo or --out-dir"

in_dataset_split = utils.load_dataset_split_string_feature(
args.in_repo_id, args.split, args.feature
args.in_dataset, args.split, args.feature
)
assert isinstance(in_dataset_split, Dataset)
print(f"Loading tokenizer from '{args.tokenizer}'...")
@@ -90,9 +91,9 @@
assert tokenizer.eos_token_id is not None, "Tokenizer must have a eos_token_id"

api = None
if args.out_repo_id:
if args.out_repo:
api = HfApi()
api.create_repo(repo_id=args.out_repo_id, repo_type="dataset", exist_ok=True)
api.create_repo(repo_id=args.out_repo, repo_type="dataset", exist_ok=True)
if args.out_dir:
os.makedirs(args.out_dir, exist_ok=True)

@@ -115,11 +116,11 @@
ds_parquet_chunk = io.BytesIO()
ds_chunk.to_parquet(ds_parquet_chunk)
if api:
print(f"Uploading '{chunk_name}' to '{args.out_repo_id}'...")
print(f"Uploading '{chunk_name}' to '{args.out_repo}'...")
api.upload_file(
path_or_fileobj=ds_parquet_chunk,
path_in_repo=f"data/{chunk_name}",
repo_id=args.out_repo_id,
repo_id=args.out_repo,
repo_type="dataset",
)
print(f"Done saving/uploading '{chunk_name}'")