
Commit

Add probability/perplexity calculation and complete detoxified LM notebook
qingyangliu0065 committed Oct 10, 2024
1 parent 0cf2c9a commit 38ae875
Showing 2 changed files with 161 additions and 54 deletions.
68 changes: 16 additions & 52 deletions examples/detoxification_bias/Detoxify_LM_demo.ipynb
@@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -115,18 +115,9 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"outputs": [],
"source": [
"from llments.eval.toxicity import ToxicityEvaluator\n",
"# create a toxicity evaluator for text scoring\n",
@@ -366,48 +357,29 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Perplexity evaluation on WAE vs. AAE\n",
"\n",
"### This part still awaits replacement with llment code"
"# Perplexity evaluation on WAE vs. AAE"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
"from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling\n",
"from datasets import load_dataset\n",
"import math\n",
"\n",
"# helper function to evaluate the perplexity of a fine-tuned model\n",
"def eval_ppl(data_dir, eval_files, model_path, model_name, training_args, output_dir):\n",
"# helper function to evaluate the perplexity of a detoxified model on a given dataset\n",
"def eval_ppl(data_dir, eval_files, model_path, model_name, output_dir):\n",
" \n",
" checkpoint = HuggingFaceLM(model=model_path, tokenizer_path=model_name)\n",
" checkpointLM = HuggingFaceLM(model=model_path, tokenizer_path=model_name)\n",
" res = []\n",
"\n",
" # TODO: replace the following with llments operators if available, \n",
" # consider a HGLM.evaluate() operator\n",
" for eval_file in eval_files:\n",
" eval_file_path = os.path.join(data_dir, eval_file)\n",
" eval_dataset = load_dataset(\"text\", data_files=eval_file_path, split=\"train\")\n",
" \n",
" eval_dataset = eval_dataset.map(lambda examples: checkpoint.tokenizer(\n",
" examples[\"text\"], truncation=True, padding=\"max_length\", max_length=128), batched=True)\n",
"\n",
" data_collator = DataCollatorForLanguageModeling(tokenizer=checkpoint.tokenizer, mlm=False)\n",
"\n",
" trainer = Trainer(\n",
" model=checkpoint.model,\n",
" args=training_args,\n",
" data_collator=data_collator,\n",
" eval_dataset=eval_dataset\n",
" )\n",
" eval_results = trainer.evaluate()\n",
"\n",
" # calculate the perplexity\n",
" ppl = math.exp(eval_results[\"eval_loss\"])\n",
" with open(eval_file_path, \"r\") as f:\n",
" sentences = f.readlines()\n",
" \n",
" # calculate the average perplexity of the samples\n",
" ppl = checkpointLM.calculate_perplexity_batch(outputs=sentences, condition=None)\n",
" res.append((eval_file, ppl))\n",
"\n",
" with open(f\"{output_dir}/{model_path.split('/')[-1]}.txt\", \"w\") as f:\n",
@@ -419,7 +391,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
@@ -485,14 +457,6 @@
"if not os.path.exists(\"trash\"):\n",
" os.makedirs(\"trash\")\n",
"\n",
"training_args = TrainingArguments(\n",
" output_dir=\"trash\",\n",
" per_device_eval_batch_size=1,\n",
" do_train=False, \n",
" do_eval=True, \n",
" fp16=False,\n",
")\n",
"\n",
"model_path_list = [\"gpt2\",\n",
" \"checkpoints/gpt2/checkpoint-2500\",\n",
" \"checkpoints/gpt2/checkpoint-5000\",\n",
Expand All @@ -505,7 +469,7 @@
" \"checkpoints/gpt2/checkpoint-22500\"]\n",
"\n",
"for model_path in model_path_list:\n",
" res = eval_ppl(eval_data_dir, eval_files, model_path, \"gpt2\", training_args, \"eval_results_gpt2\")\n",
" res = eval_ppl(eval_data_dir, eval_files, model_path, \"gpt2\", \"eval_results_gpt2\")\n",
" print(res)"
]
},
@@ -553,7 +517,7 @@
"\n",
"for model_path in model_path_list:\n",
" res = eval_ppl(eval_data_dir, eval_files, model_path, \n",
" \"NousResearch/Llama-2-7b-hf\", training_args, \"eval_results_Llama2-7b\")\n",
" \"NousResearch/Llama-2-7b-hf\", \"eval_results_Llama2-7b\")\n",
" print(res)"
]
},
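For reference, the evaluation loop above can be reproduced outside the notebook in a few lines. The sketch below is only illustrative: the `HuggingFaceLM` import path, the data directory, and the file names are assumptions not shown in this diff.

```python
# Minimal sketch of the WAE vs. AAE perplexity comparison using the new batch API.
# Import path, data directory, and file names are assumptions for illustration.
import os
from llments.lm.base.hugging_face import HuggingFaceLM  # assumed import path

eval_data_dir = "data/eval"                    # hypothetical directory
eval_files = ["wae_test.txt", "aae_test.txt"]  # hypothetical one-sentence-per-line files

lm = HuggingFaceLM(model="gpt2", tokenizer_path="gpt2")
for eval_file in eval_files:
    with open(os.path.join(eval_data_dir, eval_file)) as f:
        sentences = f.readlines()
    # average perplexity over the file, computed via the Trainer-based eval loop
    ppl = lm.calculate_perplexity_batch(condition=None, outputs=sentences)
    print(eval_file, ppl)
```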
147 changes: 145 additions & 2 deletions llments/lm/base/hugging_face.py
@@ -184,7 +184,7 @@ def set_seed(self, seed: int) -> None:
set_seed(seed)

def calculate_probability(self, condition: str | None, output: str) -> float:
"""Calculate the probability of an output given the language model.
"""Calculate the log probability of an output given the language model.
Args:
condition: The conditioning sequence for the output.
@@ -193,7 +193,150 @@ def calculate_probability(self, condition: str | None, output: str) -> float:
Returns:
float: The log probability of the output given the language model.
"""
raise NotImplementedError
try:
import numpy as np
except ImportError:
raise ImportError(
"You need to install 'numpy' package to use this function."
)

if condition:
full_input = condition + " " + output
else:
full_input = output

# Tokenize the full input (condition + output or just output)
inputs = self.tokenizer(
full_input,
return_tensors="pt",
truncation=True,
padding=False, # Avoid padding unless needed
)

# Get model outputs (logits)
full_outputs = self.model(**inputs, return_dict=True)
logits = (
full_outputs.logits.detach().cpu().numpy()
) # Convert logits to NumPy array
full_input_ids = inputs["input_ids"][0].cpu().numpy()

# define a softmax function
def softmax(logits: np.ndarray) -> np.ndarray:
exps = np.exp(
logits - np.max(logits, axis=-1, keepdims=True)
) # Stabilize softmax
return exps / np.sum(exps, axis=-1, keepdims=True)

# Calculate the probability of the output
probs = softmax(logits[0]) # Only one sequence in the batch
probs = probs[:-1, :]

# calculate the num of tokens corresponding to the output
output_ids = self.tokenizer(output)["input_ids"]
output_ids = output_ids[1:]
full_input_ids = full_input_ids[1:]
start_idx = len(full_input_ids) - len(output_ids)

# take the last # of output_tokens from the log_probs
log_probs = np.log(probs[np.arange(start_idx, len(full_input_ids)), output_ids])

# convert the log_probs to a float
return float(np.sum(log_probs))

def calculate_perplexity(self, condition: str | None, output: str) -> float:
"""Calculate the perplexity of an output given the language model.
Args:
condition: The conditioning sequence for the output.
output: The output sequence for which the probability is calculated.
Returns:
float: The perplexity of output x given the language model.
"""
try:
import numpy as np
except ImportError:
raise ImportError(
"You need to install 'numpy' package to use this function."
)

log_prob = self.calculate_probability(condition, output)
num_tokens = len(self.tokenizer(output)["input_ids"]) - 1

return float(np.exp(-log_prob / num_tokens))

def calculate_perplexity_batch(
self, condition: list[str] | None, outputs: list[str]
) -> float:
"""Calculate the perplexity of multiple outputs given the language model.
Args:
condition: The conditioning sequences for the outputs, one per output, or None.
outputs: The output sequences for which the probability is calculated.
Returns:
float: The average perplexity of the outputs given the language model.
"""
if condition:
full_inputs = [c + " " + o for c, o in zip(condition, outputs)]
else:
full_inputs = outputs

# check whether datasets and transformers (Trainer, TrainingArguments, DataCollatorForLanguageModeling) are available; otherwise fall back to a per-sentence loop
try:
from datasets import Dataset
from transformers import (
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
import numpy as np
except ImportError:
print(
"Naive implementation is used. This may harm the efficiency of the calculation."
)
try:
import numpy as np
except ImportError:
raise ImportError(
"You need to install 'numpy' package to use this function."
)
return float(
np.mean([self.calculate_perplexity(None, o) for o in full_inputs])
)

# prepare the dataset
inputs = self.tokenizer(
full_inputs,
return_tensors="pt",
truncation=True,
padding=True,
)

dataset = Dataset.from_dict(inputs)
data_collator = DataCollatorForLanguageModeling(
tokenizer=self.tokenizer, mlm=False
)

training_arguments = TrainingArguments(
output_dir="trash",
per_device_eval_batch_size=1,
do_train=False,
do_eval=True,
fp16=False,
prediction_loss_only=True,
)

trainer = Trainer(
model=self.model,
data_collator=data_collator,
args=training_arguments,
eval_dataset=dataset,
)

eval_result = trainer.evaluate()

return float(np.exp(eval_result["eval_loss"]))


class HuggingFaceLMFitter:
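Taken together, the three new methods fit as follows: `calculate_probability` returns the summed token log-probabilities of the output, `calculate_perplexity` exponentiates the negative per-token mean of that log probability, and `calculate_perplexity_batch` computes a corpus-level perplexity from the Trainer's evaluation loss. A small usage sketch follows; the import path is an assumption, while the calls match the signatures shown in this diff.

```python
# Usage sketch for the methods added in this commit (import path assumed).
import math
from llments.lm.base.hugging_face import HuggingFaceLM

lm = HuggingFaceLM(model="gpt2", tokenizer_path="gpt2")

text = "The quick brown fox jumps over the lazy dog."
log_prob = lm.calculate_probability(condition=None, output=text)  # sum of token log-probs
ppl = lm.calculate_perplexity(condition=None, output=text)        # exp(-log_prob / num_tokens)

# Recompute the per-token perplexity to show how the two values relate.
num_tokens = len(lm.tokenizer(text)["input_ids"]) - 1
print(ppl, math.exp(-log_prob / num_tokens))  # should match

# Corpus-level perplexity over several sentences via the Trainer-based eval loop.
batch_ppl = lm.calculate_perplexity_batch(
    condition=None, outputs=[text, "Another short sentence."]
)
print(batch_ppl)
```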
