diff --git a/examples/community_lm/community_lm_rag.ipynb b/examples/community_lm/community_lm_rag.ipynb
new file mode 100644
index 0000000..b485f2b
--- /dev/null
+++ b/examples/community_lm/community_lm_rag.ipynb
@@ -0,0 +1,465 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "8ca8370c",
+   "metadata": {},
+   "source": [
+    "# CommunityLM RAG\n",
+    "\n",
+    "This is a replication of the experiments from [CommunityLM](https://arxiv.org/abs/2209.07065) (Jiang et al., 2022), which probes partisan worldviews from language models, based on the [original repo](https://github.com/hjian42/communitylm).\n",
+    "This notebook uses retrieval-augmented generation (RAG) to generate the responses.\n",
+    "\n",
+    "Running all the experiments on a single GPU takes about 3-4 hours.\n",
+    "\n",
+    "Before running the notebook, please install the requirements and download the data:\n",
+    "```bash\n",
+    "pip install -r requirements.txt\n",
+    "bash download_data.sh\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "29563e5d-41b0-4f89-8d8b-a54b40f8dfb7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llments.lm.base.hugging_face import HuggingFaceLM\n",
+    "from llments.eval.sentiment import HuggingFaceSentimentEvaluator\n",
+    "from llments.lm.rag import RAGLanguageModel\n",
+    "from llments.datastore.pyserini_datastore import PyseriniDatastore\n",
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "from examples.community_lm.community_lm_constants import politician_feelings, groups_feelings, anes_df\n",
+    "from examples.community_lm.community_lm_utils import generate_community_opinion, compute_group_stance, generate_community_opinion_rag\n",
+    "\n",
+    "device = 'cuda'  # use 'cuda' for an NVIDIA GPU, 'mps' for Apple silicon, or 'cpu' otherwise"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e205d12d",
+   "metadata": {},
+   "source": [
+    "## Create the Search Index\n",
+    "\n",
+    "The following code creates a dense search index over each community's tweets."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "60a1ebf4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dem_datastore = PyseriniDatastore(index_path='examples/community_lm/data/dem_index', document_path='examples/community_lm/data/dem_4.7M_tweets_proc.jsonl', index_encoder='facebook/contriever', fields=['contents'], to_faiss=True, device=device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "ff914f72",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "repub_datastore = PyseriniDatastore(index_path='examples/community_lm/data/repub_index', document_path='examples/community_lm/data/repub_4.7M_tweets_proc.jsonl', index_encoder='facebook/contriever', fields=['contents'], to_faiss=True, device=device)"
+   ]
+  },
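+  {
+   "cell_type": "markdown",
+   "id": "f0e1d2c3",
+   "metadata": {},
+   "source": [
+    "Before launching the full experiment loops, it can help to smoke-test retrieval-augmented generation on a single prompt. The cell below is a minimal sketch that reuses the same `HuggingFaceLM`/`RAGLanguageModel` API as the experiments that follow; the prompt text is purely illustrative."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a9b8c7d6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity check: retrieve from the Democrat index and generate a few continuations.\n",
+    "lm = HuggingFaceLM('openai-community/gpt2', device=device)\n",
+    "rag_lm = RAGLanguageModel(base=lm, datastore=dem_datastore, max_results=3)\n",
+    "rag_lm.generate(\n",
+    "    condition='Joe Biden is',\n",
+    "    do_sample=True,\n",
+    "    max_new_tokens=50,\n",
+    "    temperature=1.0,\n",
+    "    num_return_sequences=3,\n",
+    ")"
+   ]
+  },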
+  {
+   "cell_type": "markdown",
+   "id": "d0022efe",
+   "metadata": {},
+   "source": [
+    "## Generate Opinions using CommunityLM\n",
+    "\n",
+    "The following code generates opinions using the fine-tuned CommunityLM models."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bacd15ad",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for run in range(1, 6):\n",
+    "    for party in ['democrat', 'republican']:\n",
+    "        lm = HuggingFaceLM(f'CommunityLM/{party}-twitter-gpt2', device=device)\n",
+    "        for prompt_option in ['Prompt1', 'Prompt2', 'Prompt3', 'Prompt4']:\n",
+    "            print(f'generating {party} opinion for {prompt_option} run {run}...')\n",
+    "            output_path = f'output/CommunityLM_{party}-twitter-gpt2/run_{run}'\n",
+    "            generate_community_opinion(lm, prompt_option, output_path, run)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "cb9a0f5c",
+   "metadata": {},
+   "source": [
+    "The following code generates opinions using the base GPT-2 model with RAG over each community's tweets."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "b5685fcb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for run in range(1, 6):\n",
+    "    for party in ['democrat', 'republican']:\n",
+    "        lm = HuggingFaceLM('openai-community/gpt2', device=device)\n",
+    "        # lm = HuggingFaceLM(f'CommunityLM/{party}-twitter-gpt2', device=device)\n",
+    "        if party == 'democrat':\n",
+    "            rag_lm = RAGLanguageModel(base=lm, datastore=dem_datastore, max_results=3)\n",
+    "        else:\n",
+    "            rag_lm = RAGLanguageModel(base=lm, datastore=repub_datastore, max_results=3)\n",
+    "        for prompt_option in ['Prompt1', 'Prompt2', 'Prompt3', 'Prompt4']:\n",
+    "            print(f'generating {party} opinion for {prompt_option} run {run}...')\n",
+    "            output_path = f'output/RAG_{party}-twitter-gpt2/run_{run}'\n",
+    "            generate_community_opinion_rag(rag_lm, prompt_option, output_path, run)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "348fc5e7-aad4-4d1a-9436-0ae83585e8bb",
+   "metadata": {},
+   "source": [
+    "## Perform Group-level Sentiment Analysis"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "d2049390",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "evaluator = HuggingFaceSentimentEvaluator(\n",
+    "    \"cardiffnlp/twitter-roberta-base-sentiment-latest\",\n",
+    "    device=device\n",
+    ")\n",
+    "for party in ['democrat', 'republican']:\n",
+    "    compute_group_stance(\n",
+    "        evaluator=evaluator,\n",
+    "        data_folder=f'output/RAG_{party}-twitter-gpt2',\n",
+    "        output_filename=f'output/RAG_{party}-twitter-gpt2/stance_prediction.csv',\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "eec53be2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_dem = pd.read_csv(\"output/RAG_democrat-twitter-gpt2/stance_prediction.csv\")\n",
+    "df_repub = pd.read_csv(\"output/RAG_republican-twitter-gpt2/stance_prediction.csv\")"
+   ]
+  },
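+  {
+   "cell_type": "markdown",
+   "id": "b1c2d3e4",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check (a small addition, not part of the original analysis), inspect the first few rows of the stance predictions. The evaluation cells below assume the columns `run`, `prompt_format`, `question`, and `group_sentiment`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c2d3e4f5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Peek at the per-run, per-prompt group sentiment scores.\n",
+    "df_dem.head()"
+   ]
+  },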
+  {
+   "cell_type": "markdown",
+   "id": "6017a1d8-ae02-4adb-b3af-3d19911a61a2",
+   "metadata": {
+    "tags": []
+   },
+   "source": [
+    "## Preparing ANES 2020 Questions\n",
+    "\n",
+    "This is data from the American National Election Studies (ANES).\n",
+    "\n",
+    "- Website: https://electionstudies.org/\n",
+    "- Email: anes@electionstudies.org\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "04e5cf0c-3f2c-4cae-806a-3798f8138664",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df = pd.read_csv(\"examples/community_lm/data/anes_pilot_2020ets_csv.csv\")\n",
+    "\n",
+    "print(f\"Number of Rows Total {df.shape}\")\n",
+    "\n",
+    "# only look at self-identified partisans (2144 of 3080); pid1r: 1 = Republican, 2 = Democrat\n",
+    "df = df[df.pid1r < 3].copy()\n",
+    "df['pid1r'] = df['pid1r'].map({1: \"Republican\", 2: \"Democrat\"})\n",
+    "print(f\"Number of Rows for Partisans {df.shape}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "976e4ba7-6c58-4445-9522-fe844342df1f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 999 marks missing values; 'pid1r' encodes partisanship\n",
+    "df_politician_results = df[['pid1r']+politician_feelings+groups_feelings].replace(999, np.nan).groupby(\"pid1r\").mean().T\n",
+    "df_politician_results['is_repub_leading'] = (df_politician_results.Republican > df_politician_results.Democrat)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "2bf14bdd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_politician_results"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "a34e06de-bfdd-4475-a4d6-47a17d627bfb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_politician_results['Prompt1'] = anes_df['Prompt1'].to_list()\n",
+    "df_politician_results['Prompt2'] = anes_df['Prompt2'].to_list()\n",
+    "df_politician_results['Prompt3'] = anes_df['Prompt3'].to_list()\n",
+    "df_politician_results['Prompt4'] = anes_df['Prompt4'].to_list()\n",
+    "\n",
+    "df_politician_results['pid'] = df_politician_results.index\n",
+    "df_politician_results.to_csv(\"output/anes2020_pilot_prompt_probing.csv\", index=False)\n",
+    "df_politician_results"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "aabcbbde-38a0-4e7c-a0a3-93034ce589c0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_politician_results['diff'] = (df_politician_results.Democrat - df_politician_results.Republican).abs()\n",
+    "df_politician_results = df_politician_results.sort_values(by=['diff'])\n",
+    "df_politician_results"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "48fe8318-5836-448c-bb50-301845732f53",
+   "metadata": {},
+   "source": [
+    "## Evaluate the RAG-based GPT-2 Models\n",
+    "\n",
+    "This scores the sentiment of the completions generated for each community using a sentiment classification model, then compares the resulting partisan leanings against the ANES responses."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "68f6811b-099e-4ba5-993c-3b7ada968f2e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pandas as pd\n",
+    "from sklearn.metrics import accuracy_score\n",
+    "from sklearn.metrics import precision_recall_fscore_support\n",
+    "\n",
+    "def compute_scores(df_anes, df_dem, df_repub):\n",
+    "    # Predict that a target is Republican-leaning when the Republican community\n",
+    "    # model expresses more positive sentiment than the Democrat one.\n",
+    "    df_repub['prediction'] = (df_repub['group_sentiment'] > df_dem['group_sentiment'])\n",
+    "\n",
+    "    gold_labels = df_anes.is_repub_leading.astype(int).values\n",
+    "    rows = []\n",
+    "    for run in range(1, 6):\n",
+    "        run = f\"run_{run}\"\n",
+    "        for prompt_format in range(1, 5):\n",
+    "            prompt_format = f\"Prompt{prompt_format}\"\n",
+    "            df_ = df_repub[(df_repub.run == run) & (df_repub.prompt_format == prompt_format)]\n",
+    "            pred_labels = df_.prediction.astype(int).values\n",
+    "            acc = accuracy_score(gold_labels, pred_labels)\n",
+    "            p, r, f1, _ = precision_recall_fscore_support(gold_labels, pred_labels, average='weighted')\n",
+    "            rows.append([run, prompt_format, acc, p, r, f1])\n",
+    "    df_scores = pd.DataFrame(rows, columns=[\"run\", \"prompt_format\", \"accuracy\", \"precision\", \"recall\", \"f1\"])\n",
+    "    return df_scores"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "0d429b20",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_repub"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "d0a6ef2b-ff35-49dc-92a7-0fb984fed6ea",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df = pd.read_csv(\"output/anes2020_pilot_prompt_probing.csv\")\n",
+    "df_scores = compute_scores(df, df_dem, df_repub)\n",
+    "df_scores"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "200fb627-57f4-4d73-a375-bb87e95923c2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from scipy import stats\n",
+    "\n",
+    "# extract gold ranks\n",
+    "df_politician_results = df_politician_results.sort_values(by=[\"pid\"])\n",
+    "gold_dem_rank = df_politician_results['Democrat'].rank().values\n",
+    "gold_repub_rank = df_politician_results['Republican'].rank().values\n",
+    "\n",
+    "def extract_ranking(df_):\n",
+    "    df_ = df_.sort_values(by=['question'])\n",
+    "    return df_[df_.prompt_format == \"Prompt4\"].groupby(['question']).group_sentiment.mean().rank().values\n",
+    "\n",
+    "dem_rank = extract_ranking(df_dem)\n",
+    "repub_rank = extract_ranking(df_repub)\n",
+    "\n",
+    "gold_dem_rank"
+   ]
+  },
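+  {
+   "cell_type": "markdown",
+   "id": "e5f6a7b8",
+   "metadata": {},
+   "source": [
+    "To quantify how well each model's ranking of the targets matches the ANES gold ranking, one natural measure is Spearman's rank correlation from `scipy.stats`. The cell below is a minimal sketch added on top of the original analysis."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f6a7b8c9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Rank correlation between model-predicted and ANES gold rankings (Prompt4).\n",
+    "print(stats.spearmanr(gold_dem_rank, dem_rank))\n",
+    "print(stats.spearmanr(gold_repub_rank, repub_rank))"
+   ]
+  },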
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "c69f6bc3-b8d2-44ad-9aab-222653191c69",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "## plot the rankings\n",
+    "\n",
+    "def extract_ranking_for_politicians(df_):\n",
+    "    df_ = df_[df_.question.isin(politician_feelings)]\n",
+    "    df_ = df_.sort_values(by=['question', 'run'])\n",
+    "    return df_[df_.prompt_format == \"Prompt4\"]\n",
+    "\n",
+    "df_politician_results = df_politician_results[df_politician_results.pid.isin(politician_feelings)].sort_values(by=['pid'])\n",
+    "df_politician_results['short_name'] = df_politician_results.Prompt1.apply(lambda x: x.split(\" \")[-1])\n",
+    "\n",
+    "dem_politician_rank = extract_ranking_for_politicians(df_dem)\n",
+    "df_avg = dem_politician_rank.groupby(\"question\").group_sentiment.mean().reset_index()\n",
+    "df_avg = df_avg.rename(columns={\"group_sentiment\": \"group_avg_sentiment\"})\n",
+    "dem_politician_rank = dem_politician_rank.merge(df_politician_results, left_on=\"question\", right_on=\"pid\")\n",
+    "dem_politician_rank = dem_politician_rank.merge(df_avg, on=\"question\")\n",
+    "\n",
+    "repub_politician_rank = extract_ranking_for_politicians(df_repub)\n",
+    "df_avg = repub_politician_rank.groupby(\"question\").group_sentiment.mean().reset_index()\n",
+    "df_avg = df_avg.rename(columns={\"group_sentiment\": \"group_avg_sentiment\"})\n",
+    "repub_politician_rank = repub_politician_rank.merge(df_politician_results, left_on=\"question\", right_on=\"pid\")\n",
+    "repub_politician_rank = repub_politician_rank.merge(df_avg, on=\"question\")\n",
+    "\n",
+    "dem_politician_rank"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "015bb056-a742-49a1-97a7-dda18d203ac8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# df_politician_results.to_csv(\"rank_plots.csv\")\n",
+    "import matplotlib.pyplot as plt\n",
+    "import seaborn as sns\n",
+    "\n",
+    "sns.set(rc={'figure.figsize': (5, 5)})\n",
+    "\n",
+    "palette = sns.color_palette(\"Blues\", n_colors=20)\n",
+    "palette.reverse()\n",
+    "\n",
+    "ax = sns.barplot(data=dem_politician_rank.sort_values(by=\"group_avg_sentiment\", ascending=False), x=\"group_sentiment\", y=\"short_name\", palette=palette, estimator=np.mean, ci=90)\n",
+    "\n",
+    "ax.set_xlabel(None)\n",
+    "ax.set_ylabel(None)\n",
+    "plt.tick_params(axis='both', which='major', labelsize=14)\n",
+    "plt.tight_layout()\n",
+    "plt.savefig('output/RAG_gpt2_pred_dem_rank.png', bbox_inches=\"tight\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "ea3b9c6c-a2e1-43b4-8463-caaff9653fa9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sns.set(rc={'figure.figsize': (5, 5)})\n",
+    "\n",
+    "palette = sns.color_palette(\"Blues\", n_colors=20)\n",
+    "palette.reverse()\n",
+    "\n",
+    "ax = sns.barplot(data=dem_politician_rank.sort_values(by=\"Democrat\", ascending=False), x=\"Democrat\", y=\"short_name\", palette=palette)\n",
+    "\n",
+    "ax.set_xlabel(None)\n",
+    "ax.set_ylabel(None)\n",
+    "plt.tick_params(axis='both', which='major', labelsize=14)\n",
+    "plt.tight_layout()\n",
+    "plt.savefig('output/gold_dem_rank.png', bbox_inches=\"tight\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "id": "2bc980b7-aa69-4cbe-8778-f57b463fc909",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "palette = sns.color_palette(\"Reds\", n_colors=20)\n",
+    "palette.reverse()\n",
+    "\n",
+    "ax = sns.barplot(data=repub_politician_rank.sort_values(by=\"group_avg_sentiment\", ascending=False), x=\"group_sentiment\", y=\"short_name\", palette=palette)\n",
+    "\n",
+    "ax.set_xlabel(None)\n",
+    "ax.set_ylabel(None)\n",
+    "plt.tick_params(axis='both', which='major', labelsize=14)\n",
+    "plt.tight_layout()\n",
+    "plt.savefig('output/RAG_gpt2_pred_repub_rank.png', bbox_inches=\"tight\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "id": "6c1dd3c2-dda4-482e-99b6-6ef4316155ea",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "palette = sns.color_palette(\"Reds\", n_colors=20)\n",
+    "palette.reverse()\n",
+    "\n",
+    "ax = sns.barplot(data=repub_politician_rank.sort_values(by=\"Republican\", ascending=False), x=\"Republican\", y=\"short_name\", palette=palette)\n",
+    "\n",
+    "ax.set_xlabel(None)\n",
+    "ax.set_ylabel(None)\n",
+    "plt.tick_params(axis='both', which='major', labelsize=14)\n",
+    "plt.tight_layout()\n",
+    "plt.savefig('output/gold_repub_rank.png', bbox_inches=\"tight\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
"display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/community_lm/community_lm_utils.py b/examples/community_lm/community_lm_utils.py index a494244..6b1cd0b 100644 --- a/examples/community_lm/community_lm_utils.py +++ b/examples/community_lm/community_lm_utils.py @@ -8,8 +8,63 @@ import numpy as np from llments.lm.lm import LanguageModel +from llments.lm.rag import RAGLanguageModel from llments.eval.sentiment import SentimentEvaluator +def generate_community_opinion_rag( + model: RAGLanguageModel, + prompt_option: str, + output_path: str, + seed: int, + preceding_prompt: str | None = None, + overwrite: bool = False, +) -> None: + """Generate opinions for a given prompt with RAG. + + Args: + model: The RAG language model. + prompt_option: The prompt option. + output_path: The output path. + seed: The seed for the language model. + preceding_prompt: The preceding prompt. + overwrite: Whether to overwrite the output file if it exists. + """ + model.set_seed(seed) + + questions = anes_df.pid.values.tolist() + prompts = anes_df[prompt_option].values.tolist() + + output_folder = os.path.join(output_path, prompt_option) + Path(output_folder).mkdir(parents=True, exist_ok=True) + + for question, prompt in tqdm.tqdm( + zip(questions, prompts), total=len(questions), desc="Generating opinions" + ): + output_path = os.path.join(output_folder, question + ".txt") + if os.path.exists(output_path) and not overwrite: + continue + total_prompt = ( + " ".join([preceding_prompt, prompt]) if preceding_prompt else prompt + ) + + responses = model.generate( + condition=total_prompt, + do_sample=True, + max_new_tokens=100, + temperature=1.0, + num_return_sequences=100, + ) + + responses = [x.split("\n")[0] for x in responses] + + with open(output_path, "w") as out: + for line in responses: + line = line.replace("\n", " ") + if preceding_prompt: + line = line.replace(preceding_prompt + " ", "") + out.write(line) + out.write("\n") + def generate_community_opinion( model: LanguageModel,