Biasmonkey example: Added constrained decoding, CommunityLM and detoxified models, few-shot prompting (#80)

* WIP

* WIP2

* include communityLM models, detoxified llama2-7b in list of models

* include original gpt2 as a baseline

* add a comment on how to disable constrained decoding

* constrained decoding: allowed options A-F

* check for missing files, plot fixes

* notebook to analyze/visualize gpt2/communityLM responses

* Add detoxified llama2 results as well

* rename file
zaidsheikh authored Nov 15, 2024
1 parent 6ae2773 commit 50aa5b5
Showing 3 changed files with 374 additions and 13 deletions.
47 changes: 40 additions & 7 deletions examples/bias_monkey/bias_monkey.ipynb
@@ -16,6 +16,16 @@
"```"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -55,22 +65,34 @@
"metadata": {},
"outputs": [],
"source": [
"# currently only huggingface models are supported. API based models will be supported soon.\n",
"model_paths = {\n",
"base_models = {\n",
" \"llama2-7b\": \"meta-llama/Llama-2-7b-hf\",\n",
" \"llama2-7b-chat\": \"meta-llama/Llama-2-7b-chat-hf\",\n",
" \"llama2-13b\": \"meta-llama/Llama-2-13b-hf\",\n",
" \"llama2-13b-chat\": \"meta-llama/Llama-2-13b-chat-hf\",\n",
" \"llama2-70b\": \"meta-llama/Llama-2-70b-hf\",\n",
" \"gpt2\": \"openai-community/gpt2\",\n",
" \"republican-twitter-gpt2\": \"CommunityLM/republican-twitter-gpt2\",\n",
" \"democrat-twitter-gpt2\": \"CommunityLM/democrat-twitter-gpt2\",\n",
" # \"detoxified_llama2\": \"../detoxification_bias/detoxified_llama2-7b_checkpoint-22500\", # TODO: update this path\n",
"}\n",
"chat_models = {\n",
" \"llama2-7b-chat\": \"meta-llama/Llama-2-7b-chat-hf\",\n",
" \"llama2-13b-chat\": \"meta-llama/Llama-2-13b-chat-hf\",\n",
" \"llama2-70b-chat\": \"meta-llama/Llama-2-70b-chat-hf\",\n",
" \"gpt-4o-mini-2024-07-18\": \"openai/neulab/gpt-4o-mini-2024-07-18\",\n",
"}\n",
"model_paths = {**base_models, **chat_models}\n",
"\n",
"\n",
"def is_chat_model(model):\n",
" return \"chat\" in model or \"gpt\" in model\n",
"\n",
"\n",
" return model in chat_models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for model in model_paths:\n",
" print(f\"Loading {model}\")\n",
" if model_paths[model].startswith(\"openai/neulab/\"):\n",
@@ -90,6 +112,16 @@
" bias_type, perturbation = filename.split(\"-\", 1)\n",
" if bias_type not in bias_types:\n",
" continue\n",
"\n",
" def prefix_allowed_tokens_fn(\n",
" batch_id: int, input_ids: torch.Tensor\n",
" ) -> list[int]:\n",
" allowed_tokens = [\"A\", \"B\", \"C\", \"D\", \"E\", \"F\"]\n",
" token_ids = []\n",
" for token in allowed_tokens:\n",
" token_ids.append(lm.tokenizer.convert_tokens_to_ids(token))\n",
" return token_ids\n",
"\n",
" df = generate_survey_responses(\n",
" model=lm,\n",
" prompts_file=csv_file,\n",
@@ -102,6 +134,7 @@
" num_samples=50,\n",
" batch_size=25,\n",
" overwrite=True,\n",
" prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, # change this to None to disable constrained decoding\n",
" )\n",
" del lm\n"
]
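The prefix_allowed_tokens_fn hook used in the cell above is the stock Hugging Face generate() mechanism for constrained decoding: at every step it returns the set of token IDs the decoder may pick from. A minimal standalone sketch of the same idea, assuming each option letter is a single token in the vocabulary (the gpt2 checkpoint and the prompt below are illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

# Token IDs of the allowed answer options A-F.
allowed_ids = [tokenizer.convert_tokens_to_ids(t) for t in "ABCDEF"]

def prefix_allowed_tokens_fn(batch_id: int, input_ids: torch.Tensor) -> list[int]:
    # Restrict every generation step to the option letters.
    return allowed_ids

prompt = "Please answer the following question with one of the alphabetical options provided.\nQuestion: Do you favor this proposal?\nA. Favor\nB. Oppose\nAnswer: "
inputs = tokenizer(prompt, return_tensors="pt")
out = model.generate(
    **inputs,
    do_sample=True,
    max_new_tokens=1,  # one constrained token is enough for a letter answer
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(out[0, -1]))  # prints one of A-F

As the inline comment in the notebook notes, passing prefix_allowed_tokens_fn=None disables constrained decoding.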
283 changes: 283 additions & 0 deletions examples/bias_monkey/bias_monkey_extra_results.ipynb

Large diffs are not rendered by default.

57 changes: 51 additions & 6 deletions examples/bias_monkey/bias_monkey_utils.py
@@ -10,12 +10,13 @@
from operator import itemgetter
from pathlib import Path
from statistics import mean
from typing import Any
from typing import Any, Callable

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib.gridspec import SubplotSpec
from scipy.stats import entropy, pearsonr, ttest_1samp, wasserstein_distance
from tqdm import tqdm
@@ -235,6 +236,31 @@ def format_df(
return df


few_shot_examples = """Please answer the following question with one of the alphabetical options provided.
How likely do you think it is that the following will happen in the next 30 years? There will be a cure for Alzheimer's disease
A. Will definitely happen
B. Will probably happen
C. Will probably not happen
D. Will definitely not happen
Answer: C
Please answer the following question with one of the alphabetical options provided.
Question: Please choose the statement that comes closer to your own views.
A. Business corporations make too much profit
B. Most corporations make a fair and reasonable amount of profit
Answer: A
Please answer the following question with one of the alphabetical options provided.
Do you think it is ever acceptable for unmarried couples to live together?
A. Always acceptable
B. Sometimes acceptable
C. Rarely acceptable
D. Never acceptable
Answer: B
"""


def generate_survey_responses(
model: LanguageModel,
prompts_file: str,
@@ -249,6 +275,8 @@ def generate_survey_responses(
max_attempts: int | None = None,
overwrite: bool = False,
prompt_template: str = "Please answer the following question with one of the alphabetical options provided.\nQuestion: ",
few_shot_examples: str = "",
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]] | None = None,
) -> pd.DataFrame:
"""Generate responses to survey questions in prompts_file.
@@ -266,7 +294,12 @@
max_attempts: The maximum number of attempts to generate valid responses.
overwrite: Whether to overwrite the output file if it exists.
prompt_template: The template for the prompt.
few_shot_examples: Few-shot examples to prepend to the prompt.
prefix_allowed_tokens_fn: constrains decoding to the allowed tokens at each step.
This function takes two arguments, the batch ID and input_ids, and returns a list of the token IDs allowed for the next generation step.
"""
prompt_template = few_shot_examples + prompt_template

if seed is not None:
model.set_seed(seed)

@@ -317,12 +350,13 @@
responses = model.generate(
prompt,
do_sample=True,
max_new_tokens=2,
max_new_tokens=1,
temperature=1.0,
num_return_sequences=batch_size,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
answers = [
r[len(prompt) :] if r.startswith(prompt) else r
r[len(prompt) :] if r.startswith(prompt.strip()) else r
for r in responses
]
num_attempts += len(answers)
@@ -413,8 +447,12 @@ def plot_heatmap(models: list[str], results_dir: str) -> pd.DataFrame:
lst = [model, clean_bias_labels[i], mean(values), p_value]

for perturbation in perturbations:
if bias_types[i] == "opinion_float": # questions are the same
# questions in odd_even and opinion_float are the same
csv_file = f"{results_dir}/{model}/csv/{bias_types[i]}{perturbation}.csv"
if bias_types[i] == "opinion_float" and not Path(csv_file).exists():
bias_type = "odd_even" + perturbation
elif bias_types[i] == "odd_even" and not Path(csv_file).exists():
bias_type = "opinion_float" + perturbation
else:
bias_type = bias_types[i] + perturbation

@@ -447,7 +485,10 @@ def plot_heatmap(models: list[str], results_dir: str) -> pd.DataFrame:
models = list(models) + ["ideal"]
clean_model_labels += ["Most Human-like"]

fig, axs = plt.subplots(2, len(models) // 2, figsize=(15, 6))
# fig, axs = plt.subplots(2, len(models) // 2, figsize=(15, 6))
nrows = (len(models) + 4) // 5
fig, axs = plt.subplots(nrows, 5, figsize=(15, 3 * nrows))
axs = np.atleast_2d(axs)

cmap_name = "tab20c"

@@ -680,8 +721,12 @@ def plot_uncertainity(models: list[str], results_dir: str) -> pd.DataFrame:
)
lst = [model, bias_type, orig_mean, orig_std, new_mean, new_std]
for perturbation in perturbations:
if bias_types[i] == "opinion_float": # questions are the same
# questions in odd_even and opinion_float are the same
pkl_file = f"{results_dir}/{model}/{bias_types[i]}{perturbation}.pickle"
if bias_types[i] == "opinion_float" and not Path(pkl_file).exists():
bias_type = "odd_even" + perturbation
elif bias_types[i] == "odd_even" and not Path(pkl_file).exists():
bias_type = "opinion_float" + perturbation
else:
bias_type = bias_types[i] + perturbation
orig_mean, orig_std, new_mean, new_std = get_entropies(
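The few-shot change itself is a single concatenation: generate_survey_responses prepends few_shot_examples to prompt_template before building each survey prompt. A small sketch of the resulting prompt assembly, with an illustrative survey question that is not taken from the real data:

# Shortened stand-in for the module-level few_shot_examples string above.
few_shot = (
    "Please answer the following question with one of the alphabetical options provided.\n"
    "Do you think it is ever acceptable for unmarried couples to live together?\n"
    "A. Always acceptable\n"
    "B. Sometimes acceptable\n"
    "C. Rarely acceptable\n"
    "D. Never acceptable\n"
    "Answer: B\n"
)

# Default prompt_template from generate_survey_responses.
template = (
    "Please answer the following question with one of the alphabetical "
    "options provided.\nQuestion: "
)

# Illustrative survey question in the same letter-option format.
question = "Should the city fund more public parks?\nA. Yes\nB. No\nAnswer:"

# Mirrors `prompt_template = few_shot_examples + prompt_template` in the function.
prompt = few_shot + template + question

Together with max_new_tokens=1 and the constrained-decoding hook, every sampled response is exactly one option letter.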
