Skip to content

Commit

Permalink
formatting and use adam8bit as default
Browse files Browse the repository at this point in the history
  • Loading branch information
brianfitzgerald committed Mar 12, 2024
1 parent 5dae0eb commit 386dc11
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class HyperParams:
max_grad_norm: float = 1.0
seed: int = 42
weight_decay: float = 0.0
optimizer: OptimizerChoice = "Adafactor"
optimizer: OptimizerChoice = "AdamW8bit"


class FineTunerDataset(pl.LightningDataModule):
Expand Down
7 changes: 5 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,10 @@ class ModelConfig:
HyperParams(base_model_checkpoint="google/flan-t5-base"),
),
"prompt_safety": ModelConfig(
T5FineTuner, PromptSafetyDataModule, PROMPT_SAFETY_PROJECT, HyperParams(base_model_checkpoint="google/flan-t5-base")
T5FineTuner,
PromptSafetyDataModule,
PROMPT_SAFETY_PROJECT,
HyperParams(base_model_checkpoint="google/flan-t5-small"),
),
}

Expand Down Expand Up @@ -213,7 +216,7 @@ def main(wandb: bool = False, config: str = "prompt_safety"):
max_epochs=hparams.num_train_epochs,
precision=precision,
gradient_clip_val=hparams.max_grad_norm,
val_check_interval=0.1,
val_check_interval=0.01,
callbacks=[sample_callback, checkpoint_callback, progress_bar_callback],
logger=loggers,
log_every_n_steps=1,
Expand Down

0 comments on commit 386dc11

Please sign in to comment.