From 38ae875db07f6135ef472b17bfa7fa7dfc620215 Mon Sep 17 00:00:00 2001
From: Qingyang Liu
Date: Thu, 10 Oct 2024 19:22:15 -0400
Subject: [PATCH] Add probability/perplexity calculation and complete
 detoxified LM notebook

---
 .../Detoxify_LM_demo.ipynb      |  68 ++------
 llments/lm/base/hugging_face.py | 147 +++++++++++++++++-
 2 files changed, 161 insertions(+), 54 deletions(-)

diff --git a/examples/detoxification_bias/Detoxify_LM_demo.ipynb b/examples/detoxification_bias/Detoxify_LM_demo.ipynb
index 198fa8a..e984ec0 100644
--- a/examples/detoxification_bias/Detoxify_LM_demo.ipynb
+++ b/examples/detoxification_bias/Detoxify_LM_demo.ipynb
@@ -11,7 +11,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -115,18 +115,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "The autoreload extension is already loaded. To reload it, use:\n",
-      "  %reload_ext autoreload\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from llments.eval.toxicity import ToxicityEvaluator\n",
     "# create a toxicity evaluator for text scoring\n",
@@ -366,48 +357,29 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "# Perplexity evaluation on WAE vs. AAE\n",
-    "\n",
-    "### This part still awaits replacement with llment code"
+    "# Perplexity evaluation on WAE vs. AAE"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 75,
    "metadata": {},
    "outputs": [],
    "source": [
-    "from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling\n",
-    "from datasets import load_dataset\n",
-    "import math\n",
-    "\n",
-    "# helper function to evaluate the perplexity of a fine-tuned model\n",
-    "def eval_ppl(data_dir, eval_files, model_path, model_name, training_args, output_dir):\n",
+    "# helper function to evaluate the perplexity of a detoxified model on a given dataset\n",
+    "def eval_ppl(data_dir, eval_files, model_path, model_name, output_dir):\n",
     "    \n",
-    "    checkpoint = HuggingFaceLM(model=model_path, tokenizer_path=model_name)\n",
+    "    checkpointLM = HuggingFaceLM(model=model_path, tokenizer_path=model_name)\n",
     "    res = []\n",
     "\n",
-    "    # TODO: replace the following with llments operators if available, \n",
-    "    # consider a HGLM.evaluate() operator\n",
     "    for eval_file in eval_files:\n",
    "        eval_file_path = os.path.join(data_dir, eval_file)\n",
-    "        eval_dataset = load_dataset(\"text\", data_files=eval_file_path, split=\"train\")\n",
     "        \n",
-    "        eval_dataset = eval_dataset.map(lambda examples: checkpoint.tokenizer(\n",
-    "            examples[\"text\"], truncation=True, padding=\"max_length\", max_length=128), batched=True)\n",
-    "\n",
-    "        data_collator = DataCollatorForLanguageModeling(tokenizer=checkpoint.tokenizer, mlm=False)\n",
-    "\n",
-    "        trainer = Trainer(\n",
-    "            model=checkpoint.model,\n",
-    "            args=training_args,\n",
-    "            data_collator=data_collator,\n",
-    "            eval_dataset=eval_dataset\n",
-    "        )\n",
-    "        eval_results = trainer.evaluate()\n",
-    "\n",
-    "        # calculate the perplexity\n",
-    "        ppl = math.exp(eval_results[\"eval_loss\"])\n",
+    "        with open(eval_file_path, \"r\") as f:\n",
+    "            sentences = f.readlines()\n",
+    "        \n",
+    "        # calculate the average perplexity of the samples\n",
+    "        ppl = checkpointLM.calculate_perplexity_batch(outputs=sentences, condition=None)\n",
     "        res.append((eval_file, ppl))\n",
     "\n",
     "    with open(f\"{output_dir}/{model_path.split('/')[-1]}.txt\", \"w\") as f:\n",
@@ -419,7 +391,7 @@
   },
  {
   "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 71,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -485,14 +457,6 @@
     "if not os.path.exists(\"trash\"):\n",
     "    os.makedirs(\"trash\")\n",
     "\n",
-    "training_args = TrainingArguments(\n",
-    "    output_dir=\"trash\",\n",
-    "    per_device_eval_batch_size=1,\n",
-    "    do_train=False, \n",
-    "    do_eval=True, \n",
-    "    fp16=False,\n",
-    ")\n",
-    "\n",
     "model_path_list = [\"gpt2\",\n",
     "                   \"checkpoints/gpt2/checkpoint-2500\",\n",
     "                   \"checkpoints/gpt2/checkpoint-5000\",\n",
@@ -505,7 +469,7 @@
     "                   \"checkpoints/gpt2/checkpoint-22500\"]\n",
     "\n",
     "for model_path in model_path_list:\n",
-    "    res = eval_ppl(eval_data_dir, eval_files, model_path, \"gpt2\", training_args, \"eval_results_gpt2\")\n",
+    "    res = eval_ppl(eval_data_dir, eval_files, model_path, \"gpt2\", \"eval_results_gpt2\")\n",
     "    print(res)"
    ]
   },
@@ -553,7 +517,7 @@
     "\n",
     "for model_path in model_path_list:\n",
     "    res = eval_ppl(eval_data_dir, eval_files, model_path, \n",
-    "                   \"NousResearch/Llama-2-7b-hf\", training_args, \"eval_results_Llama2-7b\")\n",
+    "                   \"NousResearch/Llama-2-7b-hf\", \"eval_results_Llama2-7b\")\n",
     "    print(res)"
    ]
   },
diff --git a/llments/lm/base/hugging_face.py b/llments/lm/base/hugging_face.py
index 77a9e77..1d4ef40 100644
--- a/llments/lm/base/hugging_face.py
+++ b/llments/lm/base/hugging_face.py
@@ -184,7 +184,7 @@ def set_seed(self, seed: int) -> None:
         set_seed(seed)
 
     def calculate_probability(self, condition: str | None, output: str) -> float:
-        """Calculate the probability of an output given the language model.
+        """Calculate the log probability of an output given the language model.
 
         Args:
             condition: The conditioning sequence for the output.
@@ -193,7 +193,150 @@ def calculate_probability(self, condition: str | None, output: str) -> float:
             output: The output sequence for which the probability is calculated.
 
         Returns:
-            float: The probability of output x given the language model.
+            float: The log probability of the output given the language model.
         """
-        raise NotImplementedError
+        try:
+            import numpy as np
+        except ImportError:
+            raise ImportError(
+                "You need to install the 'numpy' package to use this function."
+            )
+
+        if condition:
+            full_input = condition + " " + output
+        else:
+            full_input = output
+
+        # Tokenize the full input (condition + output, or just the output)
+        inputs = self.tokenizer(
+            full_input,
+            return_tensors="pt",
+            truncation=True,
+            padding=False,  # avoid padding unless needed
+        )
+
+        # Get the model logits for every position
+        full_outputs = self.model(**inputs, return_dict=True)
+        logits = (
+            full_outputs.logits.detach().cpu().numpy()
+        )  # convert the logits to a NumPy array
+        full_input_ids = inputs["input_ids"][0].cpu().numpy()
+
+        # define a numerically stable softmax over the vocabulary dimension
+        def softmax(logits: np.ndarray) -> np.ndarray:
+            exps = np.exp(
+                logits - np.max(logits, axis=-1, keepdims=True)
+            )  # subtract the row max to stabilize the softmax
+            return exps / np.sum(exps, axis=-1, keepdims=True)
+
+        # The logits at position i predict token i + 1, so drop the final position
+        probs = softmax(logits[0])  # only one sequence in the batch
+        probs = probs[:-1, :]
+
+        # Count the tokens that belong to the output; the first output token is
+        # dropped to allow for a leading special token when the output is
+        # tokenized on its own
+        output_ids = self.tokenizer(output)["input_ids"]
+        output_ids = output_ids[1:]
+        # Shift the targets by one so that probs[i] lines up with full_input_ids[i]
+        full_input_ids = full_input_ids[1:]
+        start_idx = len(full_input_ids) - len(output_ids)
+
+        # Gather the log probabilities of the output tokens
+        log_probs = np.log(probs[np.arange(start_idx, len(full_input_ids)), output_ids])
+
+        # Sum the per-token log probabilities and return as a float
+        return float(np.sum(log_probs))
+
+    def calculate_perplexity(self, condition: str | None, output: str) -> float:
+        """Calculate the perplexity of an output given the language model.
+
+        Args:
+            condition: The conditioning sequence for the output.
+            output: The output sequence for which the perplexity is calculated.
+
+        Returns:
+            float: The perplexity of the output given the language model.
+        """
+        try:
+            import numpy as np
+        except ImportError:
+            raise ImportError(
+                "You need to install the 'numpy' package to use this function."
+            )
+
+        log_prob = self.calculate_probability(condition, output)
+        num_tokens = len(self.tokenizer(output)["input_ids"]) - 1
+
+        # perplexity = exp(-(average per-token log probability))
+        return float(np.exp(-log_prob / num_tokens))
+
+    def calculate_perplexity_batch(
+        self, condition: list[str] | None, outputs: list[str]
+    ) -> float:
+        """Calculate the average perplexity of multiple outputs given the language model.
+
+        Args:
+            condition: The conditioning sequences for the outputs.
+            outputs: The output sequences for which the perplexity is calculated.
+
+        Returns:
+            float: The average perplexity of the outputs given the language model.
+        """
+        if condition:
+            full_inputs = [c + " " + o for c, o in zip(condition, outputs)]
+        else:
+            full_inputs = outputs
+
+        # Check whether 'datasets' and 'transformers' (Trainer, TrainingArguments,
+        # DataCollatorForLanguageModeling) are available; otherwise fall back to a
+        # naive per-sentence implementation
+        try:
+            from datasets import Dataset
+            from transformers import (
+                DataCollatorForLanguageModeling,
+                Trainer,
+                TrainingArguments,
+            )
+            import numpy as np
+        except ImportError:
+            print(
+                "Falling back to the naive per-sentence implementation; this may be slow."
+            )
+            try:
+                import numpy as np
+            except ImportError:
+                raise ImportError(
+                    "You need to install the 'numpy' package to use this function."
+                )
+            # Note: this averages per-sentence perplexities, which can differ
+            # slightly from the token-averaged estimate computed below
+            return float(
+                np.mean([self.calculate_perplexity(None, o) for o in full_inputs])
+            )
+
+        # prepare the dataset
+        inputs = self.tokenizer(
+            full_inputs,
+            return_tensors="pt",
+            truncation=True,
+            padding=True,
+        )
+
+        dataset = Dataset.from_dict(inputs)
+        data_collator = DataCollatorForLanguageModeling(
+            tokenizer=self.tokenizer, mlm=False
+        )
+
+        training_arguments = TrainingArguments(
+            output_dir="trash",
+            per_device_eval_batch_size=1,
+            do_train=False,
+            do_eval=True,
+            fp16=False,
+            prediction_loss_only=True,
+        )
+
+        trainer = Trainer(
+            model=self.model,
+            data_collator=data_collator,
+            args=training_arguments,
+            eval_dataset=dataset,
+        )
+
+        eval_result = trainer.evaluate()
+
+        # perplexity is the exponential of the average cross-entropy loss
+        return float(np.exp(eval_result["eval_loss"]))
 
 
 class HuggingFaceLMFitter:
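
Note for reviewers (not part of the patch): a minimal usage sketch of the new
API, assuming the base "gpt2" checkpoint. "aae_samples.txt" is a hypothetical
file name standing in for the WAE/AAE evaluation files used in the notebook.

    from llments.lm.base.hugging_face import HuggingFaceLM

    lm = HuggingFaceLM(model="gpt2", tokenizer_path="gpt2")

    # sum of per-token log probabilities of the output under the model
    log_prob = lm.calculate_probability(condition="The weather is", output="nice today")

    # exp(-(average per-token log probability)) for a single sequence
    ppl = lm.calculate_perplexity(condition=None, output="The weather is nice today")

    # average perplexity over a file of sentences, as in the notebook's eval_ppl
    with open("aae_samples.txt", "r") as f:  # hypothetical evaluation file
        sentences = f.readlines()
    batch_ppl = lm.calculate_perplexity_batch(condition=None, outputs=sentences)
    print(log_prob, ppl, batch_ppl)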