From 5dae0eb8c05aa3730d4b7dcff06c2dcc839d47d7 Mon Sep 17 00:00:00 2001
From: brianf
Date: Tue, 12 Mar 2024 03:33:03 +0000
Subject: [PATCH 1/2] add bnb optim

---
 dataset/parti.py | 31 +++++++++++++++++++++++++++++--
 model/t5.py      | 31 +++++++++++++++++++++++--------
 model/utils.py   |  8 ++++----
 requirements.txt |  3 ++-
 train.py         | 21 ++++++++-------------
 5 files changed, 66 insertions(+), 28 deletions(-)

diff --git a/dataset/parti.py b/dataset/parti.py
index 7720cee..3b9472c 100644
--- a/dataset/parti.py
+++ b/dataset/parti.py
@@ -1,7 +1,7 @@
 from datasets import load_dataset
 from typing import Optional
 from transformers.tokenization_utils import PreTrainedTokenizer
-import os
+from typing import Dict
 
 from model.utils import (
     PROMPT_EXPANSION_TASK_PREFIX,
@@ -49,6 +49,17 @@ def setup(self, stage: Optional[str] = None):
 
         ensure_directory(self.cache_dir, clear=False)
 
+        self.train_dataset = self.train_dataset.filter(
+            filter_row,
+            cache_file_name=f"{self.cache_dir}/training_filtered.parquet",
+            num_proc=self.cpu_count,
+        )
+        self.val_dataset = self.val_dataset.filter(
+            filter_row,
+            cache_file_name=f"{self.cache_dir}/validation_filtered.parquet",
+            num_proc=self.cpu_count,
+        )
+
         self.train_dataset = self.train_dataset.map(
             self.prepare_sample,
             batched=True,
@@ -61,7 +72,7 @@ def setup(self, stage: Optional[str] = None):
             self.prepare_sample,
             batched=True,
             load_from_cache_file=True,
-            cache_file_name=f"{self.cache_dir}/validation.parquewt",
+            cache_file_name=f"{self.cache_dir}/validation.parquet",
             num_proc=self.cpu_count,
         )
 
@@ -104,6 +115,19 @@ def prepare_sample(self, examples: dict):
         }
 
 
+def filter_row(row: Dict) -> bool:
+    prompt, upsampled = row["Prompt"], row["Upsampled"]
+    if len(prompt) == 0 or len(upsampled) == 0:
+        return False
+    if "\n" in prompt or "\n" in upsampled:
+        return False
+    if len(upsampled.split(" ")) > 10:
+        return False
+    if len(upsampled) > 128:
+        return False
+    return True
+
+
 class PromptSafetyDataModule(PromptUpsampleDataModule):
     def __init__(
         self,
@@ -117,3 +141,6 @@ def __init__(
         self.input_column, self.target_column = "Prompt", "Upsampled"
         self.cache_dir = "dataset_caches/safety_workflows"
         self.dataset_name = "roborovski/safety-workflows-upsampled"
+
+    def setup(self, stage: Optional[str] = None):
+        super().setup(stage)
diff --git a/model/t5.py b/model/t5.py
index e8aeea0..4f761ce 100644
--- a/model/t5.py
+++ b/model/t5.py
@@ -1,8 +1,13 @@
-from transformers.optimization import Adafactor, get_inverse_sqrt_schedule, AdafactorSchedule
+from transformers.optimization import (
+    Adafactor,
+    get_inverse_sqrt_schedule,
+    AdafactorSchedule,
+)
 from model.utils import HyperParams
 from torch.optim import AdamW
 from torch import Tensor
 from torchmetrics.text.perplexity import Perplexity
+import bitsandbytes as bnb
 
 import lightning.pytorch as pl
 
@@ -108,16 +113,26 @@ def configure_optimizers(self):
                 "weight_decay": 0.0,
             },
         ]
-        if self.params.optimizer == "AdamW":
-            optimizer = AdamW(
-                optimizer_grouped_parameters,
-                lr=self.params.learning_rate,
-                eps=self.params.adam_epsilon,
-            )
+
+        optim_choice = self.params.optimizer
+
+        if optim_choice in ["AdamW", "AdamW8bit"]:
+            if optim_choice == "AdamW":
+                optimizer = AdamW(
+                    optimizer_grouped_parameters,
+                    lr=self.params.learning_rate,
+                    eps=self.params.adam_epsilon,
+                )
+            elif optim_choice == "AdamW8bit":
+                optimizer = bnb.optim.adamw.AdamW8bit(
+                    optimizer_grouped_parameters,
+                    lr=self.params.learning_rate,
+                    eps=self.params.adam_epsilon,
+                )
             scheduler = get_inverse_sqrt_schedule(
                 optimizer, num_warmup_steps=self.params.warmup_steps
             )
-        elif self.params.optimizer == "Adafactor":
+        elif optim_choice == "Adafactor":
            optimizer = Adafactor(
                optimizer_grouped_parameters,
                lr=self.params.learning_rate,
diff --git a/model/utils.py b/model/utils.py
index 0cc7a78..8a07db2 100644
--- a/model/utils.py
+++ b/model/utils.py
@@ -14,7 +14,7 @@
 IGNORE_TOKEN_INDEX = -100
 PAD_TOKEN_ID = 0
 
-OptimizerChoice = Literal["AdamW", "Adafactor"]
+OptimizerChoice = Literal["AdamW", "Adafactor", "AdamW8bit"]
 
 @dataclass
 class HyperParams:
@@ -23,15 +23,15 @@ class HyperParams:
     learning_rate: float = 3e-4
     adam_epsilon: float = 1e-8
     warmup_steps: int = 50
-    train_batch_size: int = 2
+    train_batch_size: int = 4
     eval_batch_size: int = 2
     num_train_epochs: int = 25
-    gradient_accumulation_steps: int = 4
+    gradient_accumulation_steps: int = 2
     n_gpus: int = 1
     max_grad_norm: float = 1.0
     seed: int = 42
     weight_decay: float = 0.0
-    optimizer: OptimizerChoice = "AdamW"
+    optimizer: OptimizerChoice = "Adafactor"
 
 
 class FineTunerDataset(pl.LightningDataModule):
diff --git a/requirements.txt b/requirements.txt
index 256ef51..d1d24a5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,4 +19,5 @@ nltk==3.8.1
 diffusers==0.26.3
 accelerate
 matplotlib
-openai
\ No newline at end of file
+openai
+bitsandbytes==0.43.0
\ No newline at end of file
diff --git a/train.py b/train.py
index 8ae82ea..5897e0f 100644
--- a/train.py
+++ b/train.py
@@ -103,7 +103,7 @@ def log_prediction_samples(
         new_rows = tabulate(
             rows,
             headers=columns,
-            maxcolwidths=[10, 10, 100, 100, 100],
+            maxcolwidths=[10, 10, 50, 50, 50],
         )
         print(new_rows)
         with open(self.log_dir / f"{run_name}_samples.txt", "a") as f:
@@ -148,7 +148,8 @@ class ModelConfig:
     hyperparams: HyperParams = HyperParams()
 
 
-T5_WANDB_PROJECT = "t5-prompt-upsampling"
+PROMPT_UPSAMPLING_PROJECT = "t5-prompt-upsampling"
+PROMPT_SAFETY_PROJECT = "t5-prompt-safety"
 
 CONFIGS = {
     "fn_calling": ModelConfig(
@@ -160,28 +161,22 @@ class ModelConfig:
     "prompt_upsample_small": ModelConfig(
         T5FineTuner,
         PromptUpsampleDataModule,
-        T5_WANDB_PROJECT,
+        PROMPT_UPSAMPLING_PROJECT,
         HyperParams(base_model_checkpoint="google/flan-t5-small"),
     ),
     "prompt_upsample": ModelConfig(
         T5FineTuner,
         PromptUpsampleDataModule,
-        T5_WANDB_PROJECT,
+        PROMPT_UPSAMPLING_PROJECT,
         HyperParams(base_model_checkpoint="google/flan-t5-base"),
     ),
-    "prompt_upsample_adafactor": ModelConfig(
-        T5FineTuner,
-        PromptUpsampleDataModule,
-        T5_WANDB_PROJECT,
-        HyperParams(optimizer="Adafactor", learning_rate=1e-3),
-    ),
     "prompt_safety": ModelConfig(
-        T5FineTuner, PromptSafetyDataModule, "t5-prompt-safety"
+        T5FineTuner, PromptSafetyDataModule, PROMPT_SAFETY_PROJECT, HyperParams(base_model_checkpoint="google/flan-t5-base")
     ),
 }
 
 
-def main(wandb: bool = False, config: str = "prompt_upsample"):
+def main(wandb: bool = False, config: str = "prompt_safety"):
     loggers = []
 
     model_config = CONFIGS[config]
@@ -218,7 +213,7 @@ def main(wandb: bool = False, config: str = "prompt_upsample"):
         max_epochs=hparams.num_train_epochs,
         precision=precision,
         gradient_clip_val=hparams.max_grad_norm,
-        val_check_interval=0.25,
+        val_check_interval=0.1,
         callbacks=[sample_callback, checkpoint_callback, progress_bar_callback],
         logger=loggers,
         log_every_n_steps=1,

From 386dc11535e36e5e12ebdfbdeefc9736307e4c20 Mon Sep 17 00:00:00 2001
From: brianf
Date: Tue, 12 Mar 2024 03:55:16 +0000
Subject: [PATCH 2/2] formatting and use adam8bit as default

---
 model/utils.py | 2 +-
 train.py       | 7 +++++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/model/utils.py b/model/utils.py
index 8a07db2..c2c8baa 100644
--- a/model/utils.py
+++ b/model/utils.py
@@ -31,7 +31,7 @@ class HyperParams:
     max_grad_norm: float = 1.0
     seed: int = 42
     weight_decay: float = 0.0
-    optimizer: OptimizerChoice = "Adafactor"
+    optimizer: OptimizerChoice = "AdamW8bit"
 
 
 class FineTunerDataset(pl.LightningDataModule):
diff --git a/train.py b/train.py
index 5897e0f..34e83f2 100644
--- a/train.py
+++ b/train.py
@@ -171,7 +171,10 @@ class ModelConfig:
         HyperParams(base_model_checkpoint="google/flan-t5-base"),
     ),
     "prompt_safety": ModelConfig(
-        T5FineTuner, PromptSafetyDataModule, PROMPT_SAFETY_PROJECT, HyperParams(base_model_checkpoint="google/flan-t5-base")
+        T5FineTuner,
+        PromptSafetyDataModule,
+        PROMPT_SAFETY_PROJECT,
+        HyperParams(base_model_checkpoint="google/flan-t5-small"),
     ),
 }
 
@@ -213,7 +216,7 @@ def main(wandb: bool = False, config: str = "prompt_safety"):
         max_epochs=hparams.num_train_epochs,
         precision=precision,
         gradient_clip_val=hparams.max_grad_norm,
-        val_check_interval=0.1,
+        val_check_interval=0.01,
         callbacks=[sample_callback, checkpoint_callback, progress_bar_callback],
         logger=loggers,
         log_every_n_steps=1,
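
Not part of the patch series: a minimal standalone sketch of what the new AdamW8bit branch in model/t5.py ends up constructing. It assumes bitsandbytes==0.43.0 (as pinned in requirements.txt) and a CUDA-capable GPU, and uses a stand-in torch.nn.Linear plus the HyperParams defaults (learning rate 3e-4, epsilon 1e-8, 50 warmup steps) in place of the real T5 fine-tuner's grouped parameters.

import bitsandbytes as bnb
import torch
from transformers.optimization import get_inverse_sqrt_schedule

# Stand-in module; the patch passes the T5 fine-tuner's grouped parameters instead.
model = torch.nn.Linear(512, 512).cuda()

# 8-bit AdamW from bitsandbytes, mirroring the AdamW8bit branch in configure_optimizers.
optimizer = bnb.optim.AdamW8bit(
    model.parameters(),
    lr=3e-4,   # HyperParams.learning_rate
    eps=1e-8,  # HyperParams.adam_epsilon
)
scheduler = get_inverse_sqrt_schedule(optimizer, num_warmup_steps=50)  # HyperParams.warmup_steps

# One toy optimization step to confirm the optimizer/scheduler pair is wired correctly.
loss = model(torch.randn(4, 512, device="cuda")).sum()
loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()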