diff --git a/examples/detoxification_bias/Detoxify_LM_demo.ipynb b/examples/detoxification_bias/Detoxify_LM_demo.ipynb new file mode 100644 index 0000000..198fa8a --- /dev/null +++ b/examples/detoxification_bias/Detoxify_LM_demo.ipynb @@ -0,0 +1,603 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Detoxifying Language Models and Evaluating Social Bias\n", + "\n", + "This notebook serves as a demo of a subset of the experiments from the paper [Detoxifying Language Models Risks Marginalizing Minority Voices](https://aclanthology.org/2021.naacl-main.190) using the **llments** framework. We mainly perform domain-adaptive fine-tuning as the detoxification approach, and use perplexity to quantify the social bias of detoxified models with respect to language styles of different demographic groups." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import numpy as np\n", + "import pandas as pd\n", + "import os\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Preprocessing training data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Preprocess training data for fine-tuning" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done reading\n", + "FT Data Done\n" + ] + } + ], + "source": [ + "path = \"data/raw/civilcomments/train.csv\"\n", + "ft_output = \"data/train/ft\"\n", + "pt_output = \"data/train/pt\"\n", + "\n", + "input_df = pd.read_csv(path)\n", + "print(\"Done reading\")\n", + "\n", + "class_sample_df = input_df[[\"target\", \"comment_text\"]]\n", + "# a missing step in the original code to remove the null values\n", + "class_sample_df = class_sample_df[class_sample_df.comment_text.notnull()]\n", + "class_sample_df = 
class_sample_df[(class_sample_df.target >= 0.5) | (class_sample_df.target < 0.1)]\n", + "class_sample_df[\"target\"] = (class_sample_df[\"target\"] >= 0.1).astype(int)\n", + "class_sample_df[\"comment_text\"] = class_sample_df[\"comment_text\"].apply(lambda x: x.replace(\"\\n\", \"\").replace(\"\\r\", \"\").replace('\\t', \"\"))\n", + "\n", + "# save the finetuning data\n", + "finetuning_df = class_sample_df[class_sample_df.target == 0]\n", + "finetuning_df = finetuning_df[[\"comment_text\"]]\n", + "ft_train, ft_valid = np.split(finetuning_df, [int(0.9*len(finetuning_df))])\n", + "ft_train.to_csv(os.path.join(ft_output, \"train.tsv\"), sep=\"\\t\", header=False, index=False)\n", + "ft_valid.to_csv(os.path.join(ft_output, \"valid.tsv\"), sep=\"\\t\", header=False, index=False)\n", + "print(\"FT Data Done\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# read in the pretraining data and check the column names\n", + "pt_train = pd.read_csv(os.path.join(pt_output, \"train.tsv\"), sep=\"\\t\", header=None)\n", + "ft_train = pd.read_csv(os.path.join(ft_output, \"train.tsv\"), sep=\"\\t\", header=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1269504\n", + "1139603\n" + ] + } + ], + "source": [ + "print(len(pt_train))\n", + "print(len(ft_train))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Scoring evaluation data: WAE vs. AAE" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. 
To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "from llments.eval.toxicity import ToxicityEvaluator\n", + "# create a toxicity evaluator for text scoring\n", + "# NOTE: never commit a real API key; paste yours below or load it from an environment variable\n", + "api_key = \"PASTE_YOUR_API\"\n", + "toxicity_evaluator = ToxicityEvaluator(api_key)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "# extract content from raw text files\n", + "wae_samples = []\n", + "aave_samples = []\n", + "\n", + "with open(\"data/raw/sae_samples.txt\", \"r\") as f:\n", + " for line in f:\n", + " wae_samples.append(line.strip())\n", + "\n", + "with open(\"data/raw/aave_samples.txt\", \"r\") as f:\n", + " for line in f:\n", + " aave_samples.append(line.strip())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# evaluate the WAE and AAVE samples\n", + "wae_scores = toxicity_evaluator.evaluate_batch(wae_samples, show_progress=True)\n", + "aave_scores = toxicity_evaluator.evaluate_batch(aave_samples, show_progress=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# construct dataframes for the WAE and AAVE samples\n", + "wae_df = pd.DataFrame({\"text\": wae_samples, \"toxicity\": wae_scores})\n", + "aave_df = pd.DataFrame({\"text\": aave_samples, \"toxicity\": aave_scores})" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "# save them to jsonl files\n", + "target_dir = \"data/eval/translation_pairs/scored\"\n", + "wae_df.to_json(f\"{target_dir}/wae_samples_scores.jsonl\", orient=\"records\", force_ascii=False, lines=True)\n", + "aave_df.to_json(f\"{target_dir}/aave_samples_scores.jsonl\", orient=\"records\", force_ascii=False, lines=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Filtering evaluation data" + ] + 
}, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "wrote data/eval/translation_pairs/filtered/nontoxic_aae.txt\n", + "wrote data/eval/translation_pairs/filtered/nontoxic_wae.txt\n", + "wrote data/eval/translation_pairs/filtered/toxic_aae.txt\n", + "wrote data/eval/translation_pairs/filtered/toxic_wae.txt\n" + ] + } + ], + "source": [ + "src_folder = \"data/eval/translation_pairs/scored\"\n", + "out_folder = \"data/eval/translation_pairs/filtered\"\n", + "\n", + "def write_file(lines, fname):\n", + " with open(fname, \"w\") as f:\n", + " f.write(\"\\n\".join([l.replace(\"\\n\", \" \") for l in lines]))\n", + " print(\"wrote {}\".format(fname))\n", + "\n", + "aae_df = pd.read_json(os.path.join(src_folder, \"aave_samples_scores.jsonl\"), lines=True)\n", + "aae_df = aae_df.rename(columns={'text': 'AAE_text', 'toxicity': 'AAE_toxicity'})\n", + "muse_df = pd.read_json(os.path.join(src_folder, \"wae_samples_scores.jsonl\"), lines=True)\n", + "muse_df = muse_df.rename(columns={'text': 'WAE_text', 'toxicity': 'WAE_toxicity'})\n", + "cat = pd.concat([aae_df, muse_df], axis=1)\n", + "\n", + "nontoxic_df = cat[(cat['AAE_toxicity'] < 0.5) | (cat['WAE_toxicity'] < 0.5)]\n", + "toxic_df = cat[(cat['AAE_toxicity'] > 0.5) & (cat['WAE_toxicity'] > 0.5)]\n", + "\n", + "# Write the full sentences\n", + "write_file(nontoxic_df[\"AAE_text\"], os.path.join(out_folder, \"nontoxic_aae.txt\"))\n", + "write_file(nontoxic_df[\"WAE_text\"], os.path.join(out_folder, \"nontoxic_wae.txt\"))\n", + "write_file(toxic_df[\"AAE_text\"], os.path.join(out_folder, \"toxic_aae.txt\"))\n", + "write_file(toxic_df[\"WAE_text\"], os.path.join(out_folder, \"toxic_wae.txt\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Detoxification: Fine-Tuning w/ Non-Toxic Data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# 
load dataset\n", + "train_file = \"data/train/ft/train.tsv\"\n", + "eval_file = \"data/train/ft/valid.tsv\"\n", + "\n", + "# load the tsv as a pandas dataframe, each line is an entry under the column \"text\"\n", + "train_df = pd.read_csv(train_file, sep=\"\\t\", header=None, names=[\"text\"])\n", + "eval_df = pd.read_csv(eval_file, sep=\"\\t\", header=None, names=[\"text\"])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training samples: 1139603\n", + "Number of validation samples: 126623\n" + ] + } + ], + "source": [ + "print(\"Number of training samples: \", len(train_df))\n", + "print(\"Number of validation samples: \", len(eval_df))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Maybe those people should realize that this is 21st century America...not 1700's frontier days.Join civilization or move somewhere without any.\n", + "Great First Lady, I respect her for her statement. More Presidential than her husband.\n", + "Same thing in Calgary. This is the only sensible thing the government could do. It's a shame for the former residents to be uprooted, and stately neighbourhood homes moved or demolished. In retrospect (always 20/20) approval for building on vulnerable flood plains should never have been granted.\n", + "Now das da kine house I want, real old school. I grew up around Palolo, Ka'imuki, Kapahulu, Mo'ili'ili so it brings back planny memories. Can smell da mosquito punk driffin' out through da bedroom window...This one is kinda city-version anyway. Check da roof -- shingles, stedda totan. Nobody going shishi da pants when one big green mango fa'down BLAM! on top da tin roof middle of da night.\n", + "Thanks. You do realize that pattyjane doesn't agree. 
Thanks.\n" + ] + } + ], + "source": [ + "samples = train_df[\"text\"][:5].to_list()\n", + "for sample in samples:\n", + " print(sample)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llments.lm.base.hugging_face import HuggingFaceLM, HuggingFaceLMFitter\n", + "from llments.lm.base.dataset_lm import DatasetLM\n", + "\n", + "# load base models and datasets we use for finetuning\n", + "base_gpt2_lm = HuggingFaceLM(model='gpt2')\n", + "base_llama2_lm = HuggingFaceLM(model='NousResearch/Llama-2-7b-hf')\n", + "train_dataset_lm = DatasetLM(train_df[\"text\"].to_list())\n", + "eval_dataset_lm = DatasetLM(eval_df[\"text\"].to_list())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# fit a GPT2 model on the finetuning dataset\n", + "fitted_gpt2_lm = HuggingFaceLMFitter.fit(\n", + " base=base_gpt2_lm,\n", + " target=train_dataset_lm,\n", + " eval_target=eval_dataset_lm,\n", + " output_dir=\"checkpoints/gpt2\",\n", + " logging_dir=\"logs/gpt2\",\n", + " batch_size=4, \n", + " training_steps=22500,\n", + " eval_steps=500,\n", + " logging_steps=500,\n", + " save_steps=2500,\n", + " do_train=True, \n", + " do_eval=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# fit a Llama2 model using LoRA parameter-efficient fine-tuning\n", + "fitted_llama2_lm = HuggingFaceLMFitter.fit(\n", + " base=base_llama2_lm,\n", + " target=train_dataset_lm,\n", + " eval_target=eval_dataset_lm,\n", + " output_dir=\"checkpoints/Llama2-7b\",\n", + " logging_dir=\"logs/Llama2-7b\",\n", + " batch_size=4,\n", + " training_steps=22500,\n", + " eval_steps=500,\n", + " logging_steps=500,\n", + " save_steps=2500,\n", + " do_train=True,\n", + " do_eval=True, \n", + " lora_alpha=32,\n", + " lora_r=16)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Perplexity evaluation on WAE vs. 
AAE\n", + "\n", + "### This part still awaits replacement with llment code" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling\n", + "from datasets import load_dataset\n", + "import math\n", + "\n", + "# helper function to evaluate the perplexity of a fine-tuned model\n", + "def eval_ppl(data_dir, eval_files, model_path, model_name, training_args, output_dir):\n", + " \n", + " checkpoint = HuggingFaceLM(model=model_path, tokenizer_path=model_name)\n", + " res = []\n", + "\n", + " # TODO: replace the following with llments operators if available, \n", + " # consider a HGLM.evaluate() operator\n", + " for eval_file in eval_files:\n", + " eval_file_path = os.path.join(data_dir, eval_file)\n", + " eval_dataset = load_dataset(\"text\", data_files=eval_file_path, split=\"train\")\n", + " \n", + " eval_dataset = eval_dataset.map(lambda examples: checkpoint.tokenizer(\n", + " examples[\"text\"], truncation=True, padding=\"max_length\", max_length=128), batched=True)\n", + "\n", + " data_collator = DataCollatorForLanguageModeling(tokenizer=checkpoint.tokenizer, mlm=False)\n", + "\n", + " trainer = Trainer(\n", + " model=checkpoint.model,\n", + " args=training_args,\n", + " data_collator=data_collator,\n", + " eval_dataset=eval_dataset\n", + " )\n", + " eval_results = trainer.evaluate()\n", + "\n", + " # calculate the perplexity\n", + " ppl = math.exp(eval_results[\"eval_loss\"])\n", + " res.append((eval_file, ppl))\n", + "\n", + " with open(f\"{output_dir}/{model_path.split('/')[-1]}.txt\", \"w\") as f:\n", + " for r in res:\n", + " f.write(f\"{r[0]}: {r[1]}\\n\")\n", + "\n", + " return res" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + 
"# helper function to plot the perplexity scores\n", + "def plot_ppl(eval_file_dir, checkpoint_paths, model_name):\n", + "\n", + " data = []\n", + "\n", + " for checkpoint_path in checkpoint_paths:\n", + " checkpoint_name = checkpoint_path.split('/')[-1]\n", + "\n", + " # read in the perplexity scores\n", + " with open(f\"{eval_file_dir}/{checkpoint_name}.txt\", \"r\") as f:\n", + " lines = f.readlines()\n", + " for line in lines:\n", + " eval_file, ppl = line.split(\":\")\n", + " eval_file = eval_file.split(\".\")[0]\n", + " ppl = float(ppl)\n", + " train_step = float(checkpoint_name.split(\"-\")[-1]) * 3\n", + " data.append((train_step, eval_file, ppl))\n", + " \n", + " df = pd.DataFrame(data, columns=[\"Training Step\", \"Eval File\", \"Perplexity\"])\n", + "\n", + " # sort the df by a predefined order\n", + " eval_order = [\"nontoxic_aae\", \"toxic_aae\", \"nontoxic_wae\", \"toxic_wae\"]\n", + " df[\"Eval File\"] = pd.Categorical(df[\"Eval File\"], categories=eval_order, ordered=True)\n", + " \n", + " plt.figure(figsize=(12, 8)) # Increased figure size for better detail visibility\n", + " ax = sns.lineplot(x=\"Training Step\", y=\"Perplexity\", hue=\"Eval File\", data=df, palette=\"Paired\", marker=\"o\", linewidth=2.5, markersize=8)\n", + " plt.title(f\"Perplexity Scores for {model_name} DAPT Checkpoints on AAVE and WAE Samples\", fontsize=16, fontweight='bold', color='navy')\n", + " plt.xlabel(\"Training Step\", fontsize=14, fontweight='bold', color='darkgreen')\n", + " plt.ylabel(\"Perplexity\", fontsize=14, fontweight='bold', color='darkgreen')\n", + " plt.legend(title=\"Eval File\", title_fontsize='13', fontsize='11', frameon=True, shadow=True, borderpad=1)\n", + " plt.grid(True, which='both', linestyle='--', linewidth=0.5)\n", + " plt.tight_layout()\n", + "\n", + " for line in ax.lines:\n", + " for x, y in zip(line.get_xdata(), line.get_ydata()):\n", + " plt.text(x, y, f'{y:.2f}', color=line.get_color(), fontsize=12, verticalalignment='bottom')\n", + 
"\n", + " plt.savefig(f\"{model_name}_perplexity_scores_plot.png\", dpi=300) # Save as PNG image with high resolution\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "eval_data_dir = \"data/eval/translation_pairs/filtered\"\n", + "eval_files = [\"nontoxic_aae.txt\", \"toxic_aae.txt\", \"nontoxic_wae.txt\", \"toxic_wae.txt\"]\n", + "\n", + "# evaluate the GPT2 model\n", + "if not os.path.exists(\"eval_results_gpt2\"):\n", + " os.makedirs(\"eval_results_gpt2\")\n", + "\n", + "if not os.path.exists(\"trash\"):\n", + " os.makedirs(\"trash\")\n", + "\n", + "training_args = TrainingArguments(\n", + " output_dir=\"trash\",\n", + " per_device_eval_batch_size=1,\n", + " do_train=False, \n", + " do_eval=True, \n", + " fp16=False,\n", + ")\n", + "\n", + "model_path_list = [\"gpt2\",\n", + " \"checkpoints/gpt2/checkpoint-2500\",\n", + " \"checkpoints/gpt2/checkpoint-5000\",\n", + " \"checkpoints/gpt2/checkpoint-7500\",\n", + " \"checkpoints/gpt2/checkpoint-10000\",\n", + " \"checkpoints/gpt2/checkpoint-12500\",\n", + " \"checkpoints/gpt2/checkpoint-15000\",\n", + " \"checkpoints/gpt2/checkpoint-17500\",\n", + " \"checkpoints/gpt2/checkpoint-20000\",\n", + " \"checkpoints/gpt2/checkpoint-22500\"]\n", + "\n", + "for model_path in model_path_list:\n", + " res = eval_ppl(eval_data_dir, eval_files, model_path, \"gpt2\", training_args, \"eval_results_gpt2\")\n", + " print(res)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# plot the perplexity scores\n", + "plot_ppl(\"eval_results_gpt2\", model_path_list, \"gpt2\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# evaluate the Llama2 model\n", + "if not os.path.exists(\"eval_results_Llama2-7b\"):\n", + " os.makedirs(\"eval_results_Llama2-7b\")\n", + " \n", + "model_path_list = [\"NousResearch/Llama-2-7b-hf\",\n", + " \"checkpoints/Llama2-7b/checkpoint-2500\",\n", + " \"checkpoints/Llama2-7b/checkpoint-5000\",\n", + " \"checkpoints/Llama2-7b/checkpoint-7500\",\n", + " \"checkpoints/Llama2-7b/checkpoint-10000\",\n", + " \"checkpoints/Llama2-7b/checkpoint-12500\",\n", + " \"checkpoints/Llama2-7b/checkpoint-15000\",\n", + " \"checkpoints/Llama2-7b/checkpoint-17500\",\n", + " \"checkpoints/Llama2-7b/checkpoint-20000\",\n", + " \"checkpoints/Llama2-7b/checkpoint-22500\"]\n", + "\n", + "for model_path in model_path_list:\n", + " res = eval_ppl(eval_data_dir, eval_files, model_path, \n", + " \"NousResearch/Llama-2-7b-hf\", training_args, \"eval_results_Llama2-7b\")\n", + " print(res)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# plot the perplexity scores\n", + "plot_ppl(\"eval_results_Llama2-7b\", model_path_list, \"Llama2-7b\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "detox-rep", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/detoxification_bias/Llama2-7b_perplexity_scores_plot.png b/examples/detoxification_bias/Llama2-7b_perplexity_scores_plot.png new file mode 100644 index 0000000..ac10c39 Binary files /dev/null and b/examples/detoxification_bias/Llama2-7b_perplexity_scores_plot.png differ diff --git a/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-0.txt b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-0.txt new file mode 100644 index 0000000..ff5fcc4 --- /dev/null +++ b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-0.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 69.70890155366865 +nontoxic_wae.txt: 32.66969243018137 +toxic_aae.txt: 78.83975190870724 +toxic_wae.txt: 34.361008745995235 diff --git a/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-10000.txt b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-10000.txt new file mode 100644 index 0000000..6504a70 --- /dev/null +++ b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-10000.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 67.54182914696298 +nontoxic_wae.txt: 31.075359169048475 +toxic_aae.txt: 82.0036723039959 +toxic_wae.txt: 35.49998927643886 diff --git a/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-12500.txt b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-12500.txt new file mode 100644 
index 0000000..e2e9bae --- /dev/null +++ b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-12500.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 67.77815390265802 +nontoxic_wae.txt: 31.276970532556494 +toxic_aae.txt: 81.52509234050257 +toxic_wae.txt: 35.5350977767934 diff --git a/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-15000.txt b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-15000.txt new file mode 100644 index 0000000..26fa191 --- /dev/null +++ b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-15000.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 68.3272057102135 +nontoxic_wae.txt: 31.38248585097714 +toxic_aae.txt: 82.60140889429582 +toxic_wae.txt: 35.838112321343196 diff --git a/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-17500.txt b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-17500.txt new file mode 100644 index 0000000..f3b4e41 --- /dev/null +++ b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-17500.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 67.59234775016954 +nontoxic_wae.txt: 31.144420463872386 +toxic_aae.txt: 81.38755496908901 +toxic_wae.txt: 35.4503242815833 diff --git a/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-20000.txt b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-20000.txt new file mode 100644 index 0000000..c6b6cd1 --- /dev/null +++ b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-20000.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 67.86184747351325 +nontoxic_wae.txt: 31.27768641379288 +toxic_aae.txt: 81.90585738545141 +toxic_wae.txt: 35.637521834613224 diff --git a/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-22500.txt b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-22500.txt new file mode 100644 index 0000000..dc37090 --- /dev/null +++ b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-22500.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 67.74516408958367 
+nontoxic_wae.txt: 31.215354367901096 +toxic_aae.txt: 81.79559915314613 +toxic_wae.txt: 35.587613521207594 diff --git a/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-2500.txt b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-2500.txt new file mode 100644 index 0000000..6b3b938 --- /dev/null +++ b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-2500.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 66.6321874328598 +nontoxic_wae.txt: 30.936692524134642 +toxic_aae.txt: 79.52587340512012 +toxic_wae.txt: 34.702246590550345 diff --git a/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-5000.txt b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-5000.txt new file mode 100644 index 0000000..5e071f6 --- /dev/null +++ b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-5000.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 67.37626578149742 +nontoxic_wae.txt: 30.970795258713224 +toxic_aae.txt: 80.95880530005441 +toxic_wae.txt: 34.90532807304872 diff --git a/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-7500.txt b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-7500.txt new file mode 100644 index 0000000..0bc5cb8 --- /dev/null +++ b/examples/detoxification_bias/eval_results_Llama2-7b/checkpoint-7500.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 66.95000277502696 +nontoxic_wae.txt: 31.028843628780926 +toxic_aae.txt: 80.02690115281109 +toxic_wae.txt: 34.852323428523675 diff --git a/examples/detoxification_bias/eval_results_gpt2/checkpoint-0.txt b/examples/detoxification_bias/eval_results_gpt2/checkpoint-0.txt new file mode 100644 index 0000000..418682a --- /dev/null +++ b/examples/detoxification_bias/eval_results_gpt2/checkpoint-0.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 264.5613111827735 +toxic_aae.txt: 354.88802955429094 +nontoxic_wae.txt: 79.88636766808261 +toxic_wae.txt: 87.09986385246668 diff --git a/examples/detoxification_bias/eval_results_gpt2/checkpoint-10000.txt 
b/examples/detoxification_bias/eval_results_gpt2/checkpoint-10000.txt new file mode 100644 index 0000000..289eddd --- /dev/null +++ b/examples/detoxification_bias/eval_results_gpt2/checkpoint-10000.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 283.09457814161146 +toxic_aae.txt: 519.3137162452938 +nontoxic_wae.txt: 76.74805248230875 +toxic_wae.txt: 115.84195781580442 diff --git a/examples/detoxification_bias/eval_results_gpt2/checkpoint-12500.txt b/examples/detoxification_bias/eval_results_gpt2/checkpoint-12500.txt new file mode 100644 index 0000000..d5c7d42 --- /dev/null +++ b/examples/detoxification_bias/eval_results_gpt2/checkpoint-12500.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 287.60340638140406 +toxic_aae.txt: 533.0257307525487 +nontoxic_wae.txt: 77.95483700973249 +toxic_wae.txt: 118.03557713526929 diff --git a/examples/detoxification_bias/eval_results_gpt2/checkpoint-15000.txt b/examples/detoxification_bias/eval_results_gpt2/checkpoint-15000.txt new file mode 100644 index 0000000..fec9ac5 --- /dev/null +++ b/examples/detoxification_bias/eval_results_gpt2/checkpoint-15000.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 286.83726753582715 +toxic_aae.txt: 539.255382246365 +nontoxic_wae.txt: 77.31839234349853 +toxic_wae.txt: 117.18176009424498 diff --git a/examples/detoxification_bias/eval_results_gpt2/checkpoint-17500.txt b/examples/detoxification_bias/eval_results_gpt2/checkpoint-17500.txt new file mode 100644 index 0000000..6929e48 --- /dev/null +++ b/examples/detoxification_bias/eval_results_gpt2/checkpoint-17500.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 290.3258226022549 +toxic_aae.txt: 547.6406040700206 +nontoxic_wae.txt: 77.65640006849367 +toxic_wae.txt: 118.1359174010061 diff --git a/examples/detoxification_bias/eval_results_gpt2/checkpoint-20000.txt b/examples/detoxification_bias/eval_results_gpt2/checkpoint-20000.txt new file mode 100644 index 0000000..f79374f --- /dev/null +++ b/examples/detoxification_bias/eval_results_gpt2/checkpoint-20000.txt @@ -0,0 +1,4 @@ 
+nontoxic_aae.txt: 288.0890171791605 +toxic_aae.txt: 540.2364799077442 +nontoxic_wae.txt: 77.36334784409642 +toxic_wae.txt: 117.19003012544398 diff --git a/examples/detoxification_bias/eval_results_gpt2/checkpoint-22500.txt b/examples/detoxification_bias/eval_results_gpt2/checkpoint-22500.txt new file mode 100644 index 0000000..c871427 --- /dev/null +++ b/examples/detoxification_bias/eval_results_gpt2/checkpoint-22500.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 290.12460330926433 +toxic_aae.txt: 546.1822801277702 +nontoxic_wae.txt: 77.5023623843882 +toxic_wae.txt: 117.62872928935359 diff --git a/examples/detoxification_bias/eval_results_gpt2/checkpoint-2500.txt b/examples/detoxification_bias/eval_results_gpt2/checkpoint-2500.txt new file mode 100644 index 0000000..30e95bc --- /dev/null +++ b/examples/detoxification_bias/eval_results_gpt2/checkpoint-2500.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 259.00381827232985 +toxic_aae.txt: 410.7091123707273 +nontoxic_wae.txt: 72.72051108279676 +toxic_wae.txt: 92.81906367281663 diff --git a/examples/detoxification_bias/eval_results_gpt2/checkpoint-5000.txt b/examples/detoxification_bias/eval_results_gpt2/checkpoint-5000.txt new file mode 100644 index 0000000..05a7998 --- /dev/null +++ b/examples/detoxification_bias/eval_results_gpt2/checkpoint-5000.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 258.6714391036896 +toxic_aae.txt: 424.86183050956515 +nontoxic_wae.txt: 74.07619621989392 +toxic_wae.txt: 100.88206485566259 diff --git a/examples/detoxification_bias/eval_results_gpt2/checkpoint-7500.txt b/examples/detoxification_bias/eval_results_gpt2/checkpoint-7500.txt new file mode 100644 index 0000000..e14b1e9 --- /dev/null +++ b/examples/detoxification_bias/eval_results_gpt2/checkpoint-7500.txt @@ -0,0 +1,4 @@ +nontoxic_aae.txt: 268.05961816776534 +toxic_aae.txt: 473.6350588349334 +nontoxic_wae.txt: 75.44938313962443 +toxic_wae.txt: 111.1346968133227 diff --git a/examples/detoxification_bias/gpt2_perplexity_scores_plot.png 
b/examples/detoxification_bias/gpt2_perplexity_scores_plot.png new file mode 100644 index 0000000..2c96a27 Binary files /dev/null and b/examples/detoxification_bias/gpt2_perplexity_scores_plot.png differ diff --git a/examples/detoxification_bias/replication.ipynb b/examples/detoxification_bias/replication.ipynb deleted file mode 100644 index 1cd3e20..0000000 --- a/examples/detoxification_bias/replication.ipynb +++ /dev/null @@ -1,1821 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import numpy as np\n", - "import pandas as pd\n", - "import csv\n", - "import os\n", - "import time\n", - "import json\n", - "import requests\n", - "from tqdm import tqdm, trange" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Preprocessing training data" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Training data for detoxifying methods" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Done reading\n", - "FT Data Done\n" - ] - } - ], - "source": [ - "path = \"data/raw/civilcomments/train.csv\"\n", - "pplm_output = \"data/train/pplm\"\n", - "gedi_output = \"data/train/gedi\"\n", - "ft_output = \"data/train/ft\"\n", - "pt_output = \"data/train/pt\"\n", - "\n", - "input_df = pd.read_csv(path)\n", - "print(\"Done reading\")\n", - "\n", - "class_sample_df = input_df[[\"target\", \"comment_text\"]]\n", - "# a missing step in the orignal code to remove the null values\n", - "class_sample_df = class_sample_df[class_sample_df.comment_text.notnull()]\n", - "class_sample_df = class_sample_df[(class_sample_df.target >= 0.5) | (class_sample_df.target < 0.1)]\n", - "class_sample_df[\"target\"] = (class_sample_df[\"target\"] >= 0.1).astype(int)\n", - "class_sample_df[\"comment_text\"] = class_sample_df[\"comment_text\"].apply(lambda x: 
x.replace(\"\\n\", \"\").replace(\"\\r\", \"\").replace('\\t', \"\"))\n", - "\n", - "## save the pplm and gedi data\n", - "# class_sample_df.to_csv(os.path.join(pplm_output, \"train.tsv\"), sep=\"\\t\", header=False, index=False)\n", - "# print(\"PPLM Data Done\")\n", - "\n", - "# class_sample_df_swapped = class_sample_df[[\"comment_text\", \"target\"]]\n", - "# class_sample_df_swapped[\"target\"] = class_sample_df_swapped[\"target\"].apply(lambda x: 1 - x)\n", - "# gedi_train, gedi_valid = np.split(class_sample_df_swapped, [int(0.9*len(class_sample_df_swapped))])\n", - "# gedi_train.to_csv(os.path.join(gedi_output, \"train.tsv\"), sep=\"\\t\", header=False, index=False)\n", - "# gedi_valid.to_csv(os.path.join(gedi_output, \"valid.tsv\"), sep=\"\\t\", header=False, index=False)\n", - "# print(\"GeDi Data Done\")\n", - "\n", - "# save the finetuning data\n", - "finetuning_df = class_sample_df[class_sample_df.target == 0]\n", - "finetuning_df = finetuning_df[[\"comment_text\"]]\n", - "ft_train, ft_valid = np.split(finetuning_df, [int(0.9*len(finetuning_df))])\n", - "ft_train.to_csv(os.path.join(ft_output, \"train.tsv\"), sep=\"\\t\", header=False, index=False)\n", - "ft_valid.to_csv(os.path.join(ft_output, \"valid.tsv\"), sep=\"\\t\", header=False, index=False)\n", - "print(\"FT Data Done\")" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " 0\n", - "0 This is so cool. It's like, 'would you want yo...\n", - "1 Thank you!! 
This would make my life a lot less...\n", - "2 This is such an urgent design problem; kudos t...\n", - "3 Is this something I'll be able to install on m...\n", - "4 haha you guys are a bunch of losers.\n" - ] - } - ], - "source": [ - "# read in the pretraining data and check the column names\n", - "pt_train = pd.read_csv(os.path.join(pt_output, \"train.tsv\"), sep=\"\\t\", header=None)\n", - "print(pt_train.head())" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " 0\n", - "0 This is so cool. It's like, 'would you want yo...\n", - "1 Thank you!! This would make my life a lot less...\n", - "2 This is such an urgent design problem; kudos t...\n", - "3 Is this something I'll be able to install on m...\n", - "4 FFFFUUUUUUUUUUUUUUU\n" - ] - } - ], - "source": [ - "ft_train = pd.read_csv(os.path.join(ft_output, \"train.tsv\"), sep=\"\\t\", header=None)\n", - "print(ft_train.head())" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1269504\n", - "1139603\n" - ] - } - ], - "source": [ - "print(len(pt_train))\n", - "print(len(ft_train))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Scoring evaluation data: WAE vs. 
AAE" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "api_key = \"AIzaSyA0RUal_V-LnJ949JlxylIzLKv6l87zh5U\"\n", - "# api_key = \"PASTE_YOUR_API_KEY_HERE\"\n", - "url = (\"https://commentanalyzer.googleapis.com/v1alpha1/comments:analyze\" + \"?key=\" + api_key)\n", - "\n", - "def score_generations(generations_filename, scores_filename):\n", - " categories = [\"toxicity\", \"severe_toxicity\", \"insult\", \"flirtation\", \\\n", - " \"profanity\", \"sexually_explicit\", \"threat\", \"identity_attack\"]\n", - " generations_columns = [\"text\"] + categories\n", - "\n", - " if os.path.exists(scores_filename):\n", - " generations_df = pd.read_json(scores_filename, lines=True)\n", - " start_row = generations_df.shape[0]\n", - " else:\n", - " generations_df = pd.DataFrame(columns=generations_columns)\n", - " start_row = 0\n", - "\n", - " with open(generations_filename) as f:\n", - " for num_requests, line in tqdm(enumerate(f)):\n", - " # sleep to avoid exceeding quota limit\n", - " time.sleep(1.2)\n", - " \n", - " if num_requests == 5000:\n", - " break\n", - " if num_requests < start_row:\n", - " continue\n", - "\n", - " if num_requests % 500 == 0:\n", - " print(\"{} | Request #{}\".format(generations_filename, num_requests))\n", - " # to checkpoint because it randomly errors out sometimes, losing all of the progress\n", - " generations_df.to_json(scores_filename, orient=\"records\", force_ascii=False, lines=True)\n", - "\n", - " if line[-1] == \"\\n\":\n", - " line = line[:-1]\n", - "\n", - " data_dict = {\"comment\": {\"text\": line},\n", - " \"languages\": [\"en\"],\n", - " \"requestedAttributes\": {\"TOXICITY\":{}, \"SEVERE_TOXICITY\":{}, \"INSULT\":{}, \\\n", - " \"FLIRTATION\":{}, \"PROFANITY\":{}, \"SEXUALLY_EXPLICIT\":{}, \\\n", - " \"THREAT\":{}, \"IDENTITY_ATTACK\":{}}}\n", - "\n", - " if len(line) == 0:\n", - " row = [\"\"] + [0]*8\n", - " generations_df.loc[len(generations_df), :] = row\n", - " 
else:\n", - " response = requests.post(url=url, data=json.dumps(data_dict))\n", - " response_dict = json.loads(response.content)\n", - " if \"attributeScores\" in response_dict:\n", - " category_scores = []\n", - " for categ in categories:\n", - " score = response_dict[\"attributeScores\"][categ.upper()][\"spanScores\"][0][\"score\"][\"value\"]\n", - " category_scores.append(score)\n", - " row = [line] + category_scores\n", - " generations_df.loc[len(generations_df), :] = row\n", - " else:\n", - " print(response_dict)\n", - " break\n", - " \n", - " break\n", - " generations_df.to_json(scores_filename, orient=\"records\", force_ascii=False, lines=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "score_generations(\"data/raw/translation_pairs/aave_samples.txt\", \n", - " \"data/eval/translation_pairs/scored/aave_samples_scores.jsonl\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "score_generations(\"data/raw/translation_pairs/sae_samples.txt\", \n", - " \"data/eval/translation_pairs/scored/wae_samples_scores.jsonl\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Filtering evaluation data" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "wrote data/eval/translation_pairs/filtered/nontoxic_aae.txt\n", - "wrote data/eval/translation_pairs/filtered/nontoxic_wae.txt\n", - "wrote data/eval/translation_pairs/filtered/toxic_aae.txt\n", - "wrote data/eval/translation_pairs/filtered/toxic_wae.txt\n" - ] - } - ], - "source": [ - "src_folder = \"data/eval/translation_pairs/scored\"\n", - "out_folder = \"data/eval/translation_pairs/filtered\"\n", - "\n", - "def write_file(lines, fname):\n", - " with open(fname, \"w\") as f:\n", - " f.write(\"\\n\".join([l.replace(\"\\n\", \" \") for l in 
lines]))\n", - " print(\"wrote {}\".format(fname))\n", - "\n", - "aae_df = pd.read_json(os.path.join(src_folder, \"aave_samples_scores.jsonl\"), lines=True)\n", - "aae_df = aae_df.rename(columns={'text': 'AAE_text', 'toxicity': 'AAE_toxicity'})\n", - "muse_df = pd.read_json(os.path.join(src_folder, \"wae_samples_scores.jsonl\"), lines=True)\n", - "muse_df = muse_df.rename(columns={'text': 'WAE_text', 'toxicity': 'WAE_toxicity'})\n", - "cat = pd.concat([aae_df, muse_df], axis=1)\n", - "\n", - "nontoxic_df = cat[(cat['AAE_toxicity'] < 0.5) | (cat['WAE_toxicity'] < 0.5)]\n", - "toxic_df = cat[(cat['AAE_toxicity'] > 0.5) & (cat['WAE_toxicity'] > 0.5)]\n", - "\n", - "# Write the full sentences\n", - "write_file(nontoxic_df[\"AAE_text\"], os.path.join(out_folder, \"nontoxic_aae.txt\"))\n", - "write_file(nontoxic_df[\"WAE_text\"], os.path.join(out_folder, \"nontoxic_wae.txt\"))\n", - "write_file(toxic_df[\"AAE_text\"], os.path.join(out_folder, \"toxic_aae.txt\"))\n", - "write_file(toxic_df[\"WAE_text\"], os.path.join(out_folder, \"toxic_wae.txt\"))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Detoxification: Fine-Tuning" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/qliu3/anaconda3/envs/detox-nonpin/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "import argparse\n", - "import os\n", - "import math\n", - "from random import randint\n", - "from tqdm import tqdm\n", - "\n", - "import torch\n", - "import torch.optim as optim\n", - "from torch.utils.data import DataLoader\n", - "\n", - "import transformers\n", - "from transformers import (\n", - " AutoTokenizer, \n", - " AutoModelForCausalLM, \n", - " BitsAndBytesConfig, \n", - " DataCollatorForLanguageModeling, \n", - " Trainer,\n", - " TrainingArguments,\n", - " set_seed,\n", - ")\n", - "from datasets import Dataset, load_dataset\n", - "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n", - "\n", - "torch.set_float32_matmul_precision(\"high\")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "metadata": {}, - "outputs": [], - "source": [ - "# load dataset\n", - "train_file = \"data/train/ft/train.tsv\"\n", - "eval_file = \"data/train/ft/valid.tsv\"\n", - "\n", - "train_dataset = load_dataset(\"text\", data_files=train_file, split=\"train\")\n", - "eval_dataset = load_dataset(\"text\", data_files=eval_file, split=\"train\")\n", - "\n", - "# take 10% of the training data\n", - "train_subset = train_dataset.train_test_split(test_size=0.1, seed=221)[\"test\"]\n", - "eval_subset = eval_dataset.train_test_split(test_size=0.1, seed=221)[\"test\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Number of training samples: 113961\n", - "Number of validation samples: 12663\n" - ] - } - ], - "source": [ - "# print the number of samples in the training and validation sets\n", - "print(\"Number of training 
samples: \", len(train_subset))\n", - "print(\"Number of validation samples: \", len(eval_subset))" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "metadata": {}, - "outputs": [], - "source": [ - "def tokenize_function(examples, tokenizer):\n", - " if not tokenizer.pad_token_id:\n", - " tokenizer.pad_token = tokenizer.eos_token\n", - " return tokenizer(examples[\"text\"], truncation=True, padding=\"max_length\", max_length=128)" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "metadata": {}, - "outputs": [], - "source": [ - "def print_trainable_params(model):\n", - " trainable_params = 0\n", - " all_param = 0\n", - " for _, param in model.named_parameters():\n", - " all_param += param.numel()\n", - " if param.requires_grad:\n", - " trainable_params += param.numel()\n", - " print(\n", - " f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "metadata": {}, - "outputs": [], - "source": [ - "# model_path = \"gpt2\"\n", - "model_path = \"NousResearch/Llama-2-7b-hf\"\n", - "\n", - "checkpoint_dir = f\"checkpoints/ft-{model_path}\"\n", - "cache_dir = f\"cache/ft-{model_path}\"" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "metadata": {}, - "outputs": [], - "source": [ - "# For QLoRA finetuning\n", - "# bnb config\n", - "bnb_config = BitsAndBytesConfig(\n", - " load_in_4bit=True,\n", - " bnb_4bit_compute_dtype=torch.bfloat16,\n", - " bnb_4bit_use_double_quant=True,\n", - " bnb_4bit_quant_type='nf4',\n", - ")\n", - "\n", - "# lora config\n", - "lora_config = LoraConfig(\n", - " r=8,\n", - " lora_alpha=32,\n", - " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\"],\n", - " lora_dropout=0.05,\n", - " bias=\"none\",\n", - " task_type='CAUSAL_LM',\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - 
"output_type": "stream", - "text": [ - "`low_cpu_mem_usage` was None, now set to True since model is quantized.\n", - "Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00, 1.56s/it]\n", - "/home/qliu3/anaconda3/envs/detox-nonpin/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n", - " warnings.warn(\n", - "/home/qliu3/anaconda3/envs/detox-nonpin/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n", - " warnings.warn(\n", - "/home/qliu3/anaconda3/envs/detox-nonpin/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", - " warnings.warn(\n", - "/home/qliu3/anaconda3/envs/detox-nonpin/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. 
You should set `do_sample=True` or unset `top_p`.\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "# for gpt2 model loading\n", - "# model = AutoModelForCausalLM.from_pretrained(model_path)\n", - "\n", - "# for llama-2-7b model loading\n", - "model = AutoModelForCausalLM.from_pretrained(\n", - " model_path,\n", - " quantization_config=bnb_config,\n", - " do_sample=True,\n", - " use_cache=True,\n", - " cache_dir=cache_dir,\n", - ")\n", - "\n", - "tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n", - "checkpoint_dir = f\"checkpoints/ft-{model_path}\"" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "trainable params: 6291456 || all params: 3506704384 || trainable%: 0.1794122147480111\n" - ] - } - ], - "source": [ - "model = prepare_model_for_kbit_training(model)\n", - "model = get_peft_model(model, lora_config)\n", - "print_trainable_params(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Map: 100%|██████████| 113961/113961 [00:21<00:00, 5321.13 examples/s]\n", - "Map: 100%|██████████| 12663/12663 [00:02<00:00, 5651.27 examples/s]\n" - ] - } - ], - "source": [ - "train_dataset = train_subset.map(lambda examples: tokenize_function(examples, tokenizer), batched=True)\n", - "eval_dataset = eval_subset.map(lambda examples: tokenize_function(examples, tokenizer), batched=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "number of training samples: 113961\n", - "number of validation samples: 12663\n" - ] - } - ], - "source": [ - "print(\"number of training samples: \", len(train_dataset))\n", - "print(\"number of validation samples: \", len(eval_dataset))" - ] - }, - { - "cell_type": "code", - 
"execution_count": 55, - "metadata": {}, - "outputs": [], - "source": [ - "# for GPT2\n", - "# training_args = TrainingArguments(\n", - "# output_dir=checkpoint_dir,\n", - "# do_train=True,\n", - "# do_eval=True,\n", - "# per_device_train_batch_size=16,\n", - "# per_device_eval_batch_size=16,\n", - "# learning_rate=5e-5,\n", - "# weight_decay=0.01,\n", - "# adam_beta2=0.98,\n", - "# save_total_limit=3,\n", - "# save_steps=1000,\n", - "# fp16=True,\n", - "# warmup_steps=5000,\n", - "# max_grad_norm=1e10,\n", - "# max_steps=10000,\n", - "# overwrite_output_dir=True,\n", - "# evaluation_strategy=\"steps\",\n", - "# eval_steps=1000,\n", - "# prediction_loss_only=True\n", - "# )\n", - "\n", - "# for llama-2-7b\n", - "training_args = TrainingArguments(\n", - " output_dir=checkpoint_dir,\n", - " do_train=True,\n", - " do_eval=True,\n", - " per_device_train_batch_size=16,\n", - " per_device_eval_batch_size=16,\n", - " gradient_accumulation_steps=1,\n", - " learning_rate=1e-4,\n", - " weight_decay=0.01,\n", - " optim=\"adamw_8bit\",\n", - " save_total_limit=5,\n", - " fp16=False,\n", - " num_train_epochs=1,\n", - " lr_scheduler_type=\"linear\",\n", - " warmup_ratio=0.05,\n", - " max_grad_norm=0.3,\n", - " overwrite_output_dir=True,\n", - " evaluation_strategy=\"steps\",\n", - " eval_steps=100,\n", - " save_steps=1000,\n", - " prediction_loss_only=True,\n", - " load_best_model_at_end=True,\n", - " metric_for_best_model=\"eval_loss\"\n", - ")\n", - "\n", - "if not os.path.exists(training_args.output_dir):\n", - " os.makedirs(training_args.output_dir)" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "metadata": {}, - "outputs": [], - "source": [ - "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - 
"/home/qliu3/anaconda3/envs/detox-nonpin/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n", - "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n", - " warnings.warn(\n", - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - } - ], - "source": [ - "trainer = Trainer(\n", - " model=model,\n", - " args=training_args,\n", - " data_collator=data_collator,\n", - " train_dataset=train_dataset,\n", - " eval_dataset=eval_dataset\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [10000/10000 52:15, Epoch 0/1]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
StepTraining LossValidation Loss
10004.0853003.925719
20004.0212003.898861
30003.9912003.883732
40003.9803003.875050
50003.9968003.880772
60003.9945003.881470
70003.9713003.874589
80003.9767003.867913
90003.9531003.863400
100003.9575003.859973

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Training for GPT2\n", - "trainer.train()\n", - "trainer.save_model()" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "('checkpoints/ft-gpt2/tokenizer_config.json',\n", - " 'checkpoints/ft-gpt2/special_tokens_map.json',\n", - " 'checkpoints/ft-gpt2/vocab.json',\n", - " 'checkpoints/ft-gpt2/merges.txt',\n", - " 'checkpoints/ft-gpt2/added_tokens.json',\n", - " 'checkpoints/ft-gpt2/tokenizer.json')" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# re-save the tokenizer to the same directory\n", - "tokenizer.save_pretrained(training_args.output_dir)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Training for llama-2-7b\n", - "trainer.train()\n", - "trainer.save_model()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Perplexity evaluation on WAE vs. 
AAE" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "# set seed\n", - "set_seed(221)\n", - "\n", - "# load the pt and ft models\n", - "eval_data_dir = \"data/eval/translation_pairs/filtered\"\n", - "eval_files = [\"nontoxic_aae.txt\", \"nontoxic_wae.txt\", \"toxic_aae.txt\", \"toxic_wae.txt\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# base_model_path = \"gpt2\"\n", - "# ft_model_path = \"checkpoints/ft-gpt2/checkpoint-10000\"\n", - "# ft_tokenizer_dir = \"checkpoints/ft-gpt2\"\n", - "\n", - "base_model_path = \"NousResearch/Llama-2-7b-hf\"\n", - "ft_model_path = \"checkpoints/ft-Llama-2-7b-hf/checkpoint-3000\"\n", - "ft_tokenizer_dir = \"checkpoints/ft-Llama-2-7b-hf\"\n", - "\n", - "# base_model_path = \"roberta-base\"\n", - "# ft_model_path = \"checkpoints/ft-RoBERTa/checkpoint-3000\"\n", - "# ft_tokenizer_dir = \"checkpoints/ft-RoBERTa\"\n", - "\n", - "# base_model_path = \"bert-base-uncased\"\n", - "# ft_model_path = \"checkpoints/ft-bert/checkpoint-5000\"\n", - "# ft_tokenizer_dir = \"checkpoints/ft-bert\"\n", - "\n", - "# base_model_path = \"xlm-mlm-en-2048\"\n", - "# ft_model_path = \"checkpoints/ft-XLM/checkpoint-3000\"\n", - "# ft_tokenizer_dir = \"checkpoints/ft-XLM\"" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "# training arguments for solely evaluation\n", - "training_args = TrainingArguments(\n", - " output_dir=\"trash\",\n", - " per_device_eval_batch_size=16,\n", - " do_train=False, \n", - " do_eval=True, \n", - " fp16=False,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "# mkdir for the eval results\n", - "if not os.path.exists(\"eval_results\"):\n", - " os.makedirs(\"eval_results\")\n", - "\n", - "# mkdir for the trash folder to avoid errors\n", - "if not 
os.path.exists(\"trash\"):\n", - " os.makedirs(\"trash\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "def load_checkpoint(model_path, tokenizer_dir):\n", - " model = AutoModelForCausalLM.from_pretrained(model_path, from_tf=bool(\".ckpt\" in model_path))\n", - " tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)\n", - "\n", - " if not tokenizer.pad_token:\n", - " tokenizer.pad_token = tokenizer.eos_token\n", - "\n", - " return model, tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "def eval_ppl(data_dir, eval_files, model, tokenizer, training_args):\n", - " res = []\n", - " for eval_file in eval_files:\n", - " eval_file_path = os.path.join(data_dir, eval_file)\n", - " eval_dataset = load_dataset(\"text\", data_files=eval_file_path, split=\"train\")\n", - "\n", - " eval_dataset = eval_dataset.map(lambda examples: tokenizer(\n", - " examples[\"text\"], truncation=True, padding=\"max_length\", max_length=128), batched=True)\n", - "\n", - " data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n", - "\n", - " trainer = Trainer(\n", - " model=model,\n", - " args=training_args,\n", - " data_collator=data_collator,\n", - " eval_dataset=eval_dataset\n", - " )\n", - "\n", - " eval_results = trainer.evaluate()\n", - "\n", - " # calculate the perplexity\n", - " ppl = math.exp(eval_results[\"eval_loss\"])\n", - "\n", - " print(f\"Perplexity for {eval_file}: {ppl}\")\n", - " res.append((eval_file, ppl))\n", - "\n", - " with open(f\"eval_results/{eval_file}.txt\", \"w\") as f:\n", - " f.write(f\"Perplexity: {ppl}\")\n", - "\n", - " return res" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "import seaborn as sns\n", - "import numpy as np\n", - "\n", - "# seaborn plot the results as grouped bar chart, 
results for each evaluation file are grouped together\n", - "def plot_perplexity(eval_files, eval_results, ft_direct_eval_results, model_name):\n", - " eval_files = [f.split(\".\")[0] for f in eval_files]\n", - " ppl_scores = [r[1] for r in eval_results]\n", - " ft_ppl_scores = [r[1] for r in ft_direct_eval_results]\n", - "\n", - " gpt2_model_name = [model_name] * len(eval_files)\n", - " ft_model_name = [\"FT-Non-Toxic\"] * len(eval_files)\n", - "\n", - " df = pd.DataFrame({\n", - " \"Model\": gpt2_model_name + ft_model_name,\n", - " \"Eval File\": eval_files*2,\n", - " \"Perplexity\": ppl_scores + ft_ppl_scores\n", - " })\n", - "\n", - " plt.figure(figsize=(10, 6))\n", - " sns.barplot(x=\"Eval File\", y=\"Perplexity\", hue=\"Model\", data=df, palette=\"viridis\")\n", - " # add the perplexity scores on top of the bars\n", - " for i in range(len(eval_files)):\n", - " plt.text(i-0.3, ppl_scores[i] + 5, f\"{ppl_scores[i]:.2f}\", fontsize=10, color=\"red\")\n", - " plt.text(i+0.1, ft_ppl_scores[i] + 5, f\"{ft_ppl_scores[i]:.2f}\", fontsize=10, color=\"red\")\n", - "\n", - " plt.title(f\"Perplexity Scores for {model_name} before and after DAPT detoxification\")\n", - "\n", - " # save the plot as a png\n", - " plt.savefig(f\"eval_results/{model_name}.png\")\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Map: 100%|██████████| 1661/1661 [00:00<00:00, 7140.30 examples/s]\n", - "/home/qliu3/anaconda3/envs/detox-nonpin/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). 
Please pass an `accelerate.DataLoaderConfiguration` instead: \n", - "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n", - " warnings.warn(\n", - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n", - "/home/qliu3/anaconda3/envs/detox-nonpin/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", - " warnings.warn('Was asked to gather along dimension 0, but all '\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "

\n", - " \n", - " \n", - " [104/104 09:41]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for nontoxic_aae.txt: 111.01849063255918\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Map: 100%|██████████| 1661/1661 [00:00<00:00, 7324.90 examples/s]\n", - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [104/104 09:41]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for nontoxic_wae.txt: 51.169197962653165\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Map: 100%|██████████| 358/358 [00:00<00:00, 6334.79 examples/s]\n", - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [23/23 02:04]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for toxic_aae.txt: 120.5358675681509\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Map: 100%|██████████| 358/358 [00:00<00:00, 6846.75 examples/s]\n", - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [23/23 02:04]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for toxic_wae.txt: 52.33408085572579\n" - ] - } - ], - "source": [ - "# evaluate on the pretraining llama-2-7b model\n", - "model, tokenizer = load_checkpoint(base_model_path, base_model_path)\n", - "base_ppls = eval_ppl(eval_data_dir, eval_files, model, tokenizer, training_args)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00, 1.68s/it]\n", - "/home/qliu3/anaconda3/envs/detox-nonpin/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n", - " warnings.warn(\n", - "/home/qliu3/anaconda3/envs/detox-nonpin/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n", - " warnings.warn(\n", - "/home/qliu3/anaconda3/envs/detox-nonpin/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. 
You should set `do_sample=True` or unset `temperature`.\n", - " warnings.warn(\n", - "/home/qliu3/anaconda3/envs/detox-nonpin/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", - " warnings.warn(\n", - "/home/qliu3/anaconda3/envs/detox-nonpin/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n", - "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n", - " warnings.warn(\n", - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n", - "/home/qliu3/anaconda3/envs/detox-nonpin/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", - " warnings.warn('Was asked to gather along dimension 0, but all '\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [104/104 09:44]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for nontoxic_aae.txt: 102.4070789988836\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Map: 100%|██████████| 1661/1661 [00:00<00:00, 7395.93 examples/s]\n", - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [104/104 09:44]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for nontoxic_wae.txt: 50.16341941456162\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [23/23 02:04]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for toxic_aae.txt: 115.93280428028879\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [23/23 02:04]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for toxic_wae.txt: 54.914365163352855\n" - ] - } - ], - "source": [ - "# evaluate on the fine-tuned llama-2-7b model\n", - "model, tokenizer = load_checkpoint(ft_model_path, base_model_path)\n", - "ft_ppts = eval_ppl(eval_data_dir, eval_files, model, tokenizer, training_args)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# plot the results\n", - "plot_perplexity(eval_files, base_ppls, ft_ppts, \"Llama-2-7b\")" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/qliu3/anaconda3/envs/detox-nonpin/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n", - "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n", - " warnings.warn(\n", - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [208/208 00:03]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for nontoxic_aae.txt: 259.0174039195381\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Map: 100%|██████████| 1661/1661 [00:00<00:00, 8555.76 examples/s]\n", - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [208/208 00:03]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for nontoxic_wae.txt: 76.50222695423253\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [45/45 00:00]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for toxic_aae.txt: 343.32790761921865\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [45/45 00:00]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for toxic_wae.txt: 83.61518859896252\n" - ] - } - ], - "source": [ - "# evaluate on the pretraining gpt2 model\n", - "model, tokenizer = load_checkpoint(base_model_path, base_model_path)\n", - "base_ppls = eval_ppl(eval_data_dir, eval_files, model, tokenizer, training_args)" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [208/208 00:03]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for nontoxic_aae.txt: 267.6113424507791\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [208/208 00:03]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for nontoxic_wae.txt: 71.9859627354237\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [45/45 00:00]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for toxic_aae.txt: 462.59801706473075\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [45/45 00:00]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Perplexity for toxic_wae.txt: 104.06197441403663\n" - ] - } - ], - "source": [ - "model, tokenizer = load_checkpoint(ft_model_path, ft_tokenizer_dir)\n", - "ft_ppts = eval_ppl(eval_data_dir, eval_files, model, tokenizer, training_args)" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# plot the perplexity scores for the gpt2 model\n", - "plot_perplexity(eval_files, base_ppls, ft_ppts, base_model_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# plot the perplexity scores for the xlm-mlm-en-2048 model\n", - "plot_perplexity(eval_files, eval_results, ft_direct_eval_results, \"XLM-MLM-EN-2048\")" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# plot the perplexity scores for the bert-base model and the fine-tuned model\n", - "plot_perplexity(eval_files, eval_results, ft_direct_eval_results, \"BERT-Base\")" - ] - }, - { - "cell_type": "code", - "execution_count": 137, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# plot the perplexity scores for the roberta-base model and the fine-tuned model\n", - "plot_perplexity(eval_files, eval_results, ft_direct_eval_results, base_model_path)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# (Deprecated) Utilities for fine-tuning" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "11.0\n", - "Total memory: 51041271808\n", - "Allocated memory: 0\n", - "Cached memory: 0\n" - ] - } - ], - "source": [ - "import torch\n", - "print(torch.cuda.is_available())\n", - "print(torch.version.cuda)\n", - "\n", - "# Assuming your GPU is device 0\n", - "device = torch.device('cuda:0')\n", - "\n", - "print(f'Total memory: {torch.cuda.get_device_properties(device).total_memory}')\n", - "print(f'Allocated memory: {torch.cuda.memory_allocated(device)}')\n", - "print(f'Cached memory: {torch.cuda.memory_reserved(device)}')" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "@dataclass\n", - "class ModelArguments:\n", - " \"\"\"\n", - " Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.\n", - " \"\"\"\n", - "\n", - " model_name_or_path: Optional[str] = field(\n", - " default=None,\n", - " metadata={\n", - " \"help\": \"The model checkpoint for weights initialization. 
Leave None if you want to train a model from scratch.\"\n", - " },\n", - " )\n", - " model_type: Optional[str] = field(\n", - " default=None,\n", - " metadata={\"help\": \"If training from scratch, pass a model type from the list: \" + \", \".join(MODEL_TYPES)},\n", - " )\n", - " config_name: Optional[str] = field(\n", - " default=None, metadata={\"help\": \"Pretrained config name or path if not the same as model_name\"}\n", - " )\n", - " tokenizer_name: Optional[str] = field(\n", - " default=None, metadata={\"help\": \"Pretrained tokenizer name or path if not the same as model_name\"}\n", - " )\n", - " cache_dir: Optional[str] = field(\n", - " default=None, metadata={\"help\": \"Where do you want to store the pretrained models downloaded from s3\"}\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "@dataclass\n", - "class DataTrainingArguments:\n", - " \"\"\"\n", - " Arguments pertaining to what data we are going to input our model for training and eval.\n", - " \"\"\"\n", - "\n", - " train_data_file: Optional[str] = field(\n", - " default=None, metadata={\"help\": \"The input training data file (a text file).\"}\n", - " )\n", - " train_data_files: Optional[str] = field(\n", - " default=None,\n", - " metadata={\n", - " \"help\": \"The input training data files (multiple files in glob format). 
\"\n", - " \"Very often splitting large files to smaller files can prevent tokenizer going out of memory\"\n", - " },\n", - " )\n", - " eval_data_file: Optional[str] = field(\n", - " default=None,\n", - " metadata={\"help\": \"An optional input evaluation data file to evaluate the perplexity on (a text file).\"},\n", - " )\n", - " line_by_line: bool = field(\n", - " default=False,\n", - " metadata={\"help\": \"Whether distinct lines of text in the dataset are to be handled as distinct sequences.\"},\n", - " )\n", - "\n", - " mlm: bool = field(\n", - " default=False, metadata={\"help\": \"Train with masked-language modeling loss instead of language modeling.\"}\n", - " )\n", - " mlm_probability: float = field(\n", - " default=0.15, metadata={\"help\": \"Ratio of tokens to mask for masked language modeling loss\"}\n", - " )\n", - " plm_probability: float = field(\n", - " default=1 / 6,\n", - " metadata={\n", - " \"help\": \"Ratio of length of a span of masked tokens to surrounding context length for permutation language modeling.\"\n", - " },\n", - " )\n", - " max_span_length: int = field(\n", - " default=5, metadata={\"help\": \"Maximum length of a span of masked tokens for permutation language modeling.\"}\n", - " )\n", - "\n", - " block_size: int = field(\n", - " default=-1,\n", - " metadata={\n", - " \"help\": \"Optional input sequence length after tokenization.\"\n", - " \"The training dataset will be truncated in block of this size for training.\"\n", - " \"Default to the model max input length for single sentence inputs (take into account special tokens).\"\n", - " },\n", - " )\n", - " overwrite_cache: bool = field(\n", - " default=False, metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "def get_dataset(\n", - " args: DataTrainingArguments,\n", - " tokenizer: PreTrainedTokenizer,\n", - " evaluate: bool = False,\n", 
- " cache_dir: Optional[str] = None,\n", - "):\n", - " def _dataset(file_path):\n", - " if args.line_by_line:\n", - " return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)\n", - " else:\n", - " return TextDataset(\n", - " tokenizer=tokenizer,\n", - " file_path=file_path,\n", - " block_size=args.block_size,\n", - " overwrite_cache=args.overwrite_cache,\n", - " cache_dir=cache_dir,\n", - " )\n", - "\n", - " if evaluate:\n", - " return _dataset(args.eval_data_file)\n", - " elif args.train_data_files:\n", - " return ConcatDataset([_dataset(f) for f in glob(args.train_data_files)])\n", - " else:\n", - " return _dataset(args.train_data_file)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "detox-rep", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/llments/lm/base/hugging_face.py b/llments/lm/base/hugging_face.py index a9d9737..77a9e77 100644 --- a/llments/lm/base/hugging_face.py +++ b/llments/lm/base/hugging_face.py @@ -33,29 +33,36 @@ def __init__( "You need to install the `transformers` package to use this class." ) - if not ".ckpt" in model: # use the same tokenizer as the model + # load model + self.model = AutoModelForCausalLM.from_pretrained( + model, + do_sample=True, + use_cache=True, + cache_dir=cache_dir, + from_tf=bool(".ckpt" in model), + ) + + # load tokenizer + if tokenizer_path is not None: + try: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + except: + raise ValueError( + "You must create model from one of the following ways: \n" + + "1. Input a pretrained HF model name, and optionally the compatible tokenizer path. \n" + + "2. 
Load model from a checkpoint file, include tokenizer path as well. \n" + ) + else: self.tokenizer = AutoTokenizer.from_pretrained( model, trust_remote_code=True ) - self.model = AutoModelForCausalLM.from_pretrained( - model, do_sample=True, use_cache=True, cache_dir=cache_dir - ) - self.device = device or "cpu" - self.model.to(self.device) - elif ".ckpt" in model and tokenizer_path is not None: # load from checkpoint - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - if not tokenizer.pad_token: - tokenizer.pad_token = tokenizer.eos_token - self.tokenizer = tokenizer - self.model = AutoModelForCausalLM.from_pretrained( - model, from_tf=bool(".ckpt" in model) - ) - else: - raise ValueError( - "You must create model from one of the following ways: \n" - + "1. Input HF model name.\n" - + "2. Load model from a checkpoint file, include tokenizer path as well." - ) + + if not self.tokenizer.pad_token: + self.tokenizer.pad_token = self.tokenizer.eos_token + + # set device + self.device = device or "cpu" + self.model.to(self.device) def generate( self, @@ -216,6 +223,7 @@ def fit( prediction_loss_only: bool = False, optim: str = "adamw_torch", logging_steps: int = 500, + save_steps: int = 500, lora_r: int | None = None, lora_alpha: int | None = None, ) -> LanguageModel: @@ -241,6 +249,7 @@ def fit( prediction_loss_only: When performing evaluation and generating predictions, only returns the loss. optim: The optimizer to use. Can only choose from a list of names. logging_steps: Number of update steps between two logs if logging_strategy="steps". + save_steps: Number of updates steps between two checkpoints. lora_r: Lora attention dimension (the “rank”). lora_alpha: The alpha parameter for Lora scaling. 
@@ -276,6 +285,8 @@ def fit( # convert tokenized text into a Dataset object dataset = Dataset.from_dict(inputs) + print("Dataset LM for training prepared!") + if eval_target: eval_samples = eval_target.generate( condition=None, @@ -288,9 +299,11 @@ def fit( eval_samples, padding=True, truncation=True, return_tensors="pt" ) eval_dataset = Dataset.from_dict(eval_inputs) + print("Dataset LM for evaluation prepared!") # wrap the base model with peft if lora_r and lora_alpha: + print("Using LORA attention for fitting.") try: from peft import ( LoraConfig, @@ -329,6 +342,7 @@ def fit( prediction_loss_only=prediction_loss_only, logging_dir=logging_dir, logging_steps=logging_steps, + save_steps=save_steps, ) # Make output_dir and logging_dir @@ -337,6 +351,7 @@ def fit( if not os.path.exists(logging_dir): os.makedirs(logging_dir) + print("Start fitting...") if not do_eval: trainer = Trainer( model=base.model, @@ -360,6 +375,7 @@ def fit( trainer.train() base.tokenizer.save_pretrained(output_dir) trainer.save_model(output_dir) + print("fitted modes saved to", output_dir) return base