Merge pull request #6 from brianfitzgerald/t5-safety
T5 safety
brianfitzgerald authored Mar 12, 2024
2 parents abd9658 + 386dc11 commit 6cbc0d6
Showing 5 changed files with 69 additions and 28 deletions.
31 changes: 29 additions & 2 deletions dataset/parti.py
@@ -1,7 +1,7 @@
 from datasets import load_dataset
 from typing import Optional
 from transformers.tokenization_utils import PreTrainedTokenizer
-import os
+from typing import Dict
 
 from model.utils import (
     PROMPT_EXPANSION_TASK_PREFIX,
@@ -49,6 +49,17 @@ def setup(self, stage: Optional[str] = None):
 
         ensure_directory(self.cache_dir, clear=False)
 
+        self.train_dataset = self.train_dataset.filter(
+            filter_row,
+            cache_file_name=f"{self.cache_dir}/training_filtered.parquet",
+            num_proc=self.cpu_count,
+        )
+        self.val_dataset = self.val_dataset.filter(
+            filter_row,
+            cache_file_name=f"{self.cache_dir}/validation_filtered.parquet",
+            num_proc=self.cpu_count,
+        )
+
         self.train_dataset = self.train_dataset.map(
             self.prepare_sample,
             batched=True,
@@ -61,7 +72,7 @@ def setup(self, stage: Optional[str] = None):
             self.prepare_sample,
             batched=True,
             load_from_cache_file=True,
-            cache_file_name=f"{self.cache_dir}/validation.parquewt",
+            cache_file_name=f"{self.cache_dir}/validation.parquet",
             num_proc=self.cpu_count,
         )
 
@@ -104,6 +115,19 @@ def prepare_sample(self, examples: dict):
         }
 
 
+def filter_row(row: Dict) -> bool:
+    prompt, upsampled = row["Prompt"], row["Upsampled"]
+    if len(prompt) == 0 or len(upsampled) == 0:
+        return False
+    if "\n" in prompt or "\n" in upsampled:
+        return False
+    if len(upsampled.split(" ")) > 10:
+        return False
+    if len(upsampled) > 128:
+        return False
+    return True
+
+
 class PromptSafetyDataModule(PromptUpsampleDataModule):
     def __init__(
         self,
@@ -117,3 +141,6 @@ def __init__(
         self.input_column, self.target_column = "Prompt", "Upsampled"
        self.cache_dir = "dataset_caches/safety_workflows"
        self.dataset_name = "roborovski/safety-workflows-upsampled"
+
+    def setup(self, stage: Optional[str] = None):
+        super().setup(stage)
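
Note: the new filter_row predicate keeps only rows where both Prompt and Upsampled are non-empty and single-line, and where the Upsampled text is at most 10 space-separated words and 128 characters. A minimal sketch of its behavior, assuming dataset/parti.py is importable as dataset.parti; the example rows below are invented for illustration, not taken from the dataset:

from dataset.parti import filter_row

# Hypothetical rows, purely illustrative.
print(filter_row({"Prompt": "a cat", "Upsampled": "a fluffy tabby cat on a sunlit windowsill"}))
# True: both fields non-empty and single-line, 8 words, 41 characters
print(filter_row({"Prompt": "a cat", "Upsampled": "line one\nline two"}))
# False: embedded newline
print(filter_row({"Prompt": "", "Upsampled": "anything"}))
# False: empty Prompt
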
31 changes: 23 additions & 8 deletions model/t5.py
@@ -1,8 +1,13 @@
-from transformers.optimization import Adafactor, get_inverse_sqrt_schedule, AdafactorSchedule
+from transformers.optimization import (
+    Adafactor,
+    get_inverse_sqrt_schedule,
+    AdafactorSchedule,
+)
 from model.utils import HyperParams
 from torch.optim import AdamW
 from torch import Tensor
 from torchmetrics.text.perplexity import Perplexity
+import bitsandbytes as bnb
 
 import lightning.pytorch as pl
 
@@ -108,16 +113,26 @@ def configure_optimizers(self):
                 "weight_decay": 0.0,
             },
         ]
-        if self.params.optimizer == "AdamW":
-            optimizer = AdamW(
-                optimizer_grouped_parameters,
-                lr=self.params.learning_rate,
-                eps=self.params.adam_epsilon,
-            )
+
+        optim_choice = self.params.optimizer
+
+        if optim_choice in ["AdamW", "AdamW8bit"]:
+            if optim_choice == "AdamW":
+                optimizer = AdamW(
+                    optimizer_grouped_parameters,
+                    lr=self.params.learning_rate,
+                    eps=self.params.adam_epsilon,
+                )
+            elif optim_choice == "AdamW8bit":
+                optimizer = bnb.optim.adamw.AdamW8bit(
+                    optimizer_grouped_parameters,
+                    lr=self.params.learning_rate,
+                    eps=self.params.adam_epsilon,
+                )
             scheduler = get_inverse_sqrt_schedule(
                 optimizer, num_warmup_steps=self.params.warmup_steps
             )
-        elif self.params.optimizer == "Adafactor":
+        elif optim_choice == "Adafactor":
             optimizer = Adafactor(
                 optimizer_grouped_parameters,
                 lr=self.params.learning_rate,
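
Note: bnb.optim.adamw.AdamW8bit (also exposed as bnb.optim.AdamW8bit) is a drop-in replacement for torch.optim.AdamW that stores optimizer state in 8-bit, roughly quartering optimizer-state memory. A standalone sketch of the same selection pattern, using a toy Linear layer as a stand-in rather than the repository's T5FineTuner:

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(512, 512)  # toy stand-in for the fine-tuned T5

optim_choice = "AdamW8bit"  # mirrors HyperParams.optimizer
if optim_choice == "AdamW8bit":
    # Same constructor arguments as torch.optim.AdamW; state tensors are kept in 8-bit.
    optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=3e-4, eps=1e-8)
else:
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, eps=1e-8)
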
8 changes: 4 additions & 4 deletions model/utils.py
@@ -14,7 +14,7 @@
 IGNORE_TOKEN_INDEX = -100
 PAD_TOKEN_ID = 0
 
-OptimizerChoice = Literal["AdamW", "Adafactor"]
+OptimizerChoice = Literal["AdamW", "Adafactor", "AdamW8bit"]
 
 @dataclass
 class HyperParams:
@@ -23,15 +23,15 @@ class HyperParams:
     learning_rate: float = 3e-4
     adam_epsilon: float = 1e-8
     warmup_steps: int = 50
-    train_batch_size: int = 2
+    train_batch_size: int = 4
     eval_batch_size: int = 2
     num_train_epochs: int = 25
-    gradient_accumulation_steps: int = 4
+    gradient_accumulation_steps: int = 2
     n_gpus: int = 1
     max_grad_norm: float = 1.0
     seed: int = 42
     weight_decay: float = 0.0
-    optimizer: OptimizerChoice = "AdamW"
+    optimizer: OptimizerChoice = "AdamW8bit"
 
 
 class FineTunerDataset(pl.LightningDataModule):
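
Note: the batch-size and accumulation changes cancel out, so the effective batch per optimizer step is unchanged; the larger per-device batch is presumably affordable because the new AdamW8bit default shrinks optimizer state. A quick check of the arithmetic:

# effective batch per optimizer step = train_batch_size * gradient_accumulation_steps
old_effective = 2 * 4  # before this commit
new_effective = 4 * 2  # after this commit
assert old_effective == new_effective == 8
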
3 changes: 2 additions & 1 deletion requirements.txt
@@ -19,4 +19,5 @@ nltk==3.8.1
 diffusers==0.26.3
 accelerate
 matplotlib
-openai
\ No newline at end of file
+openai
+bitsandbytes==0.43.0
24 changes: 11 additions & 13 deletions train.py
@@ -103,7 +103,7 @@ def log_prediction_samples(
         new_rows = tabulate(
             rows,
             headers=columns,
-            maxcolwidths=[10, 10, 100, 100, 100],
+            maxcolwidths=[10, 10, 50, 50, 50],
         )
         print(new_rows)
         with open(self.log_dir / f"{run_name}_samples.txt", "a") as f:
@@ -148,7 +148,8 @@ class ModelConfig:
     hyperparams: HyperParams = HyperParams()
 
 
-T5_WANDB_PROJECT = "t5-prompt-upsampling"
+PROMPT_UPSAMPLING_PROJECT = "t5-prompt-upsampling"
+PROMPT_SAFETY_PROJECT = "t5-prompt-safety"
 
 CONFIGS = {
     "fn_calling": ModelConfig(
@@ -160,28 +161,25 @@ class ModelConfig:
     "prompt_upsample_small": ModelConfig(
         T5FineTuner,
         PromptUpsampleDataModule,
-        T5_WANDB_PROJECT,
+        PROMPT_UPSAMPLING_PROJECT,
         HyperParams(base_model_checkpoint="google/flan-t5-small"),
     ),
     "prompt_upsample": ModelConfig(
         T5FineTuner,
         PromptUpsampleDataModule,
-        T5_WANDB_PROJECT,
+        PROMPT_UPSAMPLING_PROJECT,
         HyperParams(base_model_checkpoint="google/flan-t5-base"),
     ),
-    "prompt_upsample_adafactor": ModelConfig(
-        T5FineTuner,
-        PromptUpsampleDataModule,
-        T5_WANDB_PROJECT,
-        HyperParams(optimizer="Adafactor", learning_rate=1e-3),
-    ),
     "prompt_safety": ModelConfig(
-        T5FineTuner, PromptSafetyDataModule, "t5-prompt-safety"
+        T5FineTuner,
+        PromptSafetyDataModule,
+        PROMPT_SAFETY_PROJECT,
+        HyperParams(base_model_checkpoint="google/flan-t5-small"),
     ),
 }
 
 
-def main(wandb: bool = False, config: str = "prompt_upsample"):
+def main(wandb: bool = False, config: str = "prompt_safety"):
     loggers = []
 
     model_config = CONFIGS[config]
@@ -218,7 +216,7 @@ def main(wandb: bool = False, config: str = "prompt_upsample"):
         max_epochs=hparams.num_train_epochs,
         precision=precision,
         gradient_clip_val=hparams.max_grad_norm,
-        val_check_interval=0.25,
+        val_check_interval=0.01,
         callbacks=[sample_callback, checkpoint_callback, progress_bar_callback],
         logger=loggers,
         log_every_n_steps=1,
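
Note: with these defaults, calling train.py's main() with no arguments now runs the prompt_safety config (flan-t5-small on the roborovski/safety-workflows-upsampled dataset). A minimal usage sketch, assuming train.py is importable from the repository root (any CLI wrapper around main() is not visible in this diff):

from train import main

# Equivalent to plain main() after this commit; wandb=True presumably also attaches the W&B logger.
main(wandb=False, config="prompt_safety")
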
