Commit

config help as comments
jettjaniak committed May 25, 2024
1 parent c180ca6 commit 16f8704
Showing 4 changed files with 70 additions and 105 deletions.
30 changes: 12 additions & 18 deletions delphi/train/config/dataset_config.py
@@ -9,27 +9,21 @@
@beartype
@dataclass(frozen=True)
class DatasetConfig:
name: str = field(
metadata={"help": "tokenized dataset on huggingface to use for train"},
)
feature: str = field(
default="tokens",
metadata={
"help": "feature in the train dataset to use for train; should be a list of max_seq_len+1 token ints"
},
)
train_split: str = field(
default="train",
metadata={"help": "split of the dataset to use for training"},
)
validation_split: str = field(
default="validation",
metadata={"help": "split of the dataset to use for validation"},
)
# tokenized dataset; HF repo id or local directory
path: str

# feature in the dataset; should be a list of <= max_seq_len token ints
feature: str = "tokens"

# split of the dataset to use for training
train_split: str = "train"

# split of the dataset to use for validation
validation_split: str = "validation"

def _load(self, split) -> Dataset:
ds = utils.load_dataset_split_sequence_int32_feature(
self.name, split, self.feature
self.path, split, self.feature
)
ds.set_format("torch")
return ds
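(Not part of the diff.) A minimal usage sketch of the new DatasetConfig, assuming it is importable from the module path shown in the file header above; the dataset id is the one used in this repo's tests (see the test_config_utils.py change below), and the remaining keyword values simply restate the defaults:

    from delphi.train.config.dataset_config import DatasetConfig  # import path assumed from the file location

    ds_config = DatasetConfig(
        path="delphi-suite/v0-tinystories-v2-clean-tokenized",  # HF repo id or local directory
        feature="tokens",                                       # column holding the token ints (default)
        train_split="train",
        validation_split="validation",
    )
    # DatasetConfig._load(split) then returns that split as a torch-formatted datasets.Dataset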
143 changes: 57 additions & 86 deletions delphi/train/config/training_config.py
@@ -14,96 +14,67 @@
@beartype
@dataclass(frozen=True, kw_only=True)
class TrainingConfig:
model_config: dict[str, Any] = field(
metadata={
"help": "model config; class_name=name of model class in transformers, everything else is kwargs for the corresponding model config"
},
)
max_seq_len: int = field(metadata={"help": "max sequence length"})
# meta
# model config; class_name=name of model class in transformers, everything else is kwargs for the corresponding model config
model_config: dict[str, Any]

max_seq_len: int
run_name: str = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
out_dir: str = field(
default=os.path.join(platformdirs.user_data_dir(appname="delphi"), run_name),
metadata={"help": "output directory"},
)

# device
device: str = field(
default="auto", metadata={"help": "device to use (cuda, mps, cpu)"}
)

# checkpoints, logging, eval
checkpoint_interval: int = field(
default=2000, metadata={"help": "checkpoint every N iters"}
)
extra_checkpoint_iters: list[int] = field(
default_factory=list,
metadata={"help": "manually list iterations to save checkpoints on"},
)
log_interval: int = field(default=1, metadata={"help": "log every N iters"})
eval_iters: int = field(default=100, metadata={"help": "use N iters for each eval"})

# resume from checkpoint
resume_from_path: Optional[str] = field(
default=None,
metadata={
"help": "path to a checkpoint to resume from (if init_from=='resume')"
},
)

# data
batch_size: int = field(
default=64,
metadata={
"help": "number of samples used to compute the gradient for a single optimizer step"
},
)

# training
max_epochs: int = field(
default=10, metadata={"help": "total number of training epochs"}
)
grad_clip: float = field(
default=1.0,
metadata={"help": "clip gradients at this value, or disable if == 0.0"},
)
gradient_accumulation_steps: int = field(
default=1,
metadata={
"help": "if > 1 reduces memory usage by computing gradient in microbatches"
},
)
out_dir: str = os.path.join(platformdirs.user_data_dir(appname="delphi"), run_name)

# device to use (cuda, mps, cpu)
device: str = "auto"

# checkpoint every N iters
checkpoint_interval: int = 2000

# manually list iterations to save checkpoints on
extra_checkpoint_iters: list[int] = field(default_factory=list)

# log every N iters
log_interval: int = 1

# use N iters for each eval
eval_iters: int = 100

# path to a checkpoint to resume from (if init_from=='resume')
resume_from_path: Optional[str] = None

# number of samples used to compute the gradient for a single optimizer step
batch_size: int = 64

# total number of training epochs
max_epochs: int = 10

# clip gradients at this value, or disable if == 0.0
grad_clip: float = 1.0

# if > 1 reduces memory usage by computing gradient in microbatches
gradient_accumulation_steps: int = 1

# (adamw) optimizer
adam: AdamConfig = field(default_factory=AdamConfig)

# reproducibility
batch_ordering_seed: int = field(
metadata={"help": "seed used for pseudorandomly sampling data during training"},
)
torch_seed: int = field(metadata={"help": "seed used for torch"})
# seed used for pseudorandomly sampling data during training
batch_ordering_seed: int

# seed used for torch
torch_seed: int

# whether to save the optimizer state with each checkpoint
# this is twice as large as the model, but allows to resume training in a reproducible way
save_optimizer: bool = True

# data
dataset: DatasetConfig = field(
metadata={"help": "specify training and validation data"},
)

tokenizer: str = field(
default="",
metadata={
"help": "HF repo id or local directory containing the tokenizer. Used only to upload it to HF with the model, not for training"
},
)

# third party
wandb: str = field(
metadata={
"help": "wandb config in 'entity/project' form. Set to empty string to not use wandb."
},
)
out_repo: str = field(
metadata={"help": "HF repo id. Set to empty string to not push to repo."},
)

# debug
# specify training and validation data
dataset: DatasetConfig

# HF repo id or local directory containing the tokenizer. Used only to upload it to HF with the model, not for training
tokenizer: str = ""

# wandb config in 'entity/project' form. Set to empty string to not use wandb.
wandb: str

# HF repo id. Set to empty string to not push to repo.
out_repo: str

# debug config
debug_config: DebugConfig = field(default_factory=DebugConfig)
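
(Not part of the diff.) After this change, the fields without defaults — model_config, max_seq_len, batch_ordering_seed, torch_seed, dataset, wandb, and out_repo — must be supplied explicitly. A minimal construction sketch; the import paths follow the file locations above, and the model class name and sizes are illustrative placeholders, not values from this repo:

    from delphi.train.config.dataset_config import DatasetConfig
    from delphi.train.config.training_config import TrainingConfig

    config = TrainingConfig(
        # class_name picks the transformers model class; the other keys are kwargs for its config
        model_config={"class_name": "LlamaForCausalLM", "hidden_size": 288, "num_hidden_layers": 6},
        max_seq_len=512,
        batch_ordering_seed=1337,
        torch_seed=42,
        dataset=DatasetConfig(path="delphi-suite/v0-tinystories-v2-clean-tokenized"),
        wandb="",     # empty string: don't log to wandb
        out_repo="",  # empty string: don't push to the HF Hub
    )

Because the dataclass is declared frozen and kw_only, every field must be passed by keyword and the resulting config is immutable.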
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/train/config/test_config_utils.py
@@ -48,7 +48,7 @@ def test_build_config_from_files_and_overrides():
assert config.eval_iters == 5
# check base values
assert config.max_epochs == 2
assert config.dataset.name == "delphi-suite/v0-tinystories-v2-clean-tokenized"
assert config.dataset.path == "delphi-suite/v0-tinystories-v2-clean-tokenized"


def test_unoptionalize():