-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revert "CommunityLM Sentiment Analysis (#20)"
This reverts commit 9fcc5d5.
- Loading branch information
1 parent
9fcc5d5
commit f50fbc4
Showing
7 changed files
with
215 additions
and
603 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,8 +9,6 @@ | |
"\n", | ||
"This is a replication of the experiments from [CommunityLM](https://arxiv.org/abs/2209.07065) (Jiang et al. 2022), which probes partisan worldviews from language models, based on the [original repo](https://github.com/hjian42/communitylm).\n", | ||
"\n", | ||
"Running all the experiments on a single GPU takes about 3-4 hours.\n", | ||
"\n", | ||
"Before running the notebook, please install requirements and download the data.\n", | ||
"```bash\n", | ||
"pip install -r requirements.txt\n", | ||
|
@@ -25,151 +23,265 @@ | |
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from llments.lm.base.hugging_face import HuggingFaceLM\n", | ||
"from llments.eval.sentiment import HuggingFaceSentimentEvaluator\n", | ||
"import pandas as pd\n", | ||
"import numpy as np\n", | ||
"from community_lm_constants import politician_feelings, groups_feelings, anes_df\n", | ||
"from community_lm_utils import generate_community_opinion, compute_group_stance\n", | ||
"\n", | ||
"device = 'cuda:0' # change to 'mps' if you have a mac, or 'cuda:0' if you have an NVIDIA GPU " | ||
"import os" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d0022efe", | ||
"metadata": {}, | ||
"id": "6017a1d8-ae02-4adb-b3af-3d19911a61a2", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"## Generate Opinions using CommunityLM\n", | ||
"## Preparing ANES2020 Questions\n", | ||
"\n", | ||
"This is data from the American National Election Study (ANES)\n", | ||
"\n", | ||
"The following code generates opinions using CommunityLM." | ||
"Website: https://electionstudies.org/\n", | ||
"Email: [email protected]\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "bacd15ad", | ||
"id": "04e5cf0c-3f2c-4cae-806a-3798f8138664", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"for run in range(1, 6):\n", | ||
" for party in ['democrat', 'republican']:\n", | ||
" lm = HuggingFaceLM(f'CommunityLM/{party}-twitter-gpt2', device=device)\n", | ||
" for prompt_option in ['Prompt1', 'Prompt2', 'Prompt3', 'Prompt4']:\n", | ||
" print(f'generating {party} opinion for {prompt_option} run {run}...')\n", | ||
" output_path = f'output/CommunityLM_{party}-twitter-gpt2/run_{run}'\n", | ||
" generate_community_opinion(lm, prompt_option, output_path, run)" | ||
"df = pd.read_csv(\"data/anes_pilot_2020ets_csv.csv\")\n", | ||
"\n", | ||
"print(\"Number of Rows\", df.shape)\n", | ||
"\n", | ||
"politician_feelings = ['fttrump1', 'ftobama1', 'ftbiden1', 'ftwarren1', 'ftsanders1', 'ftbuttigieg1', 'ftharris1', 'ftklobuchar1',\n", | ||
" 'ftpence1', 'ftyang1', 'ftpelosi1', 'ftrubio1', 'ftocasioc1', 'fthaley1', 'ftthomas1', 'ftfauci1']\n", | ||
"\n", | ||
"groups_feelings = ['ftblack', 'ftwhite', 'fthisp', 'ftasian', 'ftillegal', 'ftfeminists', 'ftmetoo', 'fttransppl',\n", | ||
" 'ftsocialists', 'ftcapitalists', 'ftbigbusiness', 'ftlaborunions', 'ftrepublicanparty', 'ftdemocraticparty'\n", | ||
" ]\n", | ||
"\n", | ||
"partisanship = ['pid1r']\n", | ||
"\n", | ||
"# only look self identified partisans 2144/3080. 1: Repub; 2: Democrat\n", | ||
"df = df[df.pid1r < 3]\n", | ||
"df.pid1r = df.pid1r.map({1: \"Republican\", 2: \"Democrat\"})\n", | ||
"df.shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "348fc5e7-aad4-4d1a-9436-0ae83585e8bb", | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "7ad68fb3-e3aa-43d1-bf37-beeb92509db3", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"## Perform Group-level Sentiment Analysis" | ||
"df[groups_feelings]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d2049390", | ||
"id": "976e4ba7-6c58-4445-9522-fe844342df1f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"evaluator = HuggingFaceSentimentEvaluator(\n", | ||
" \"cardiffnlp/twitter-roberta-base-sentiment-latest\",\n", | ||
" device=device\n", | ||
")\n", | ||
"for party in ['democrat', 'republican']:\n", | ||
" compute_group_stance(\n", | ||
" evaluator=evaluator,\n", | ||
" data_folder=f'output/CommunityLM_{party}-twitter-gpt2',\n", | ||
" output_filename=f'output/CommunityLM_{party}-twitter-gpt2/stance_prediction.csv',\n", | ||
" )" | ||
"# 999 stands for missing values\n", | ||
"df_politician_results = df[partisanship+politician_feelings+groups_feelings].replace(999, np.nan).groupby(\"pid1r\").mean().T\n", | ||
"df_politician_results['is_repub_leading'] = (df_politician_results.Republican > df_politician_results.Democrat)\n", | ||
"print(df_politician_results.shape)\n", | ||
"df_politician_results\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "eec53be2", | ||
"id": "af34d7ae-5766-4557-a32d-64d6b7db014b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"df_dem = pd.read_csv(\"output/CommunityLM_democrat-twitter-gpt2/stance_prediction.csv\")\n", | ||
"df_repub = pd.read_csv(\"output/CommunityLM_republican-twitter-gpt2/stance_prediction.csv\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "6017a1d8-ae02-4adb-b3af-3d19911a61a2", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"## Preparing ANES2020 Questions\n", | ||
"\n", | ||
"This is data from the American National Election Study (ANES)\n", | ||
"\n", | ||
"Website: https://electionstudies.org/\n", | ||
"Email: [email protected]\n" | ||
"df_politician_results.shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "04e5cf0c-3f2c-4cae-806a-3798f8138664", | ||
"id": "a34e06de-bfdd-4475-a4d6-47a17d627bfb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"df = pd.read_csv(\"data/anes_pilot_2020ets_csv.csv\")\n", | ||
"\n", | ||
"print(f\"Number of Rows Total {df.shape}\")\n", | ||
"df_politician_results['Prompt1'] = [\n", | ||
" \"Donald Trump\",\n", | ||
" \"Barack Obama\",\n", | ||
" \"Joe Biden\",\n", | ||
" \"Elizabeth Warren\",\n", | ||
" \"Bernie Sanders\",\n", | ||
" \"Pete Buttigieg\",\n", | ||
" \"Kamala Harris\",\n", | ||
" \"Amy Klobuchar\",\n", | ||
" \"Mike Pence\",\n", | ||
" \"Andrew Yang\",\n", | ||
" \"Nancy Pelosi\",\n", | ||
" \"Marco Rubio\",\n", | ||
" \"Alexandria Ocasio-Cortez\",\n", | ||
" \"Nikki Haley\",\n", | ||
" \"Clarence Thomas\",\n", | ||
" \"Dr. Anthony Fauci\",\n", | ||
" \"Black people\",\n", | ||
" \"White people\",\n", | ||
" \"Hispanic people\",\n", | ||
" \"Asian people\",\n", | ||
" \"illegal immigrants\",\n", | ||
" \"feminists\",\n", | ||
" \"the #MeToo movement\",\n", | ||
" \"transgender people\",\n", | ||
" \"socialists\",\n", | ||
" \"capitalists\",\n", | ||
" \"big business\",\n", | ||
" \"labor unions\",\n", | ||
" \"the Republican Party\",\n", | ||
" \"the Democratic Party\"\n", | ||
"]\n", | ||
"\n", | ||
"df_politician_results['Prompt2'] = [\n", | ||
" \"Donald Trump is\",\n", | ||
" \"Barack Obama is\",\n", | ||
" \"Joe Biden is\",\n", | ||
" \"Elizabeth Warren is\",\n", | ||
" \"Bernie Sanders is\",\n", | ||
" \"Pete Buttigieg is\",\n", | ||
" \"Kamala Harris is\",\n", | ||
" \"Amy Klobuchar is\",\n", | ||
" \"Mike Pence is\",\n", | ||
" \"Andrew Yang is\",\n", | ||
" \"Nancy Pelosi is\",\n", | ||
" \"Marco Rubio is\",\n", | ||
" \"Alexandria Ocasio-Cortez is\",\n", | ||
" \"Nikki Haley is\",\n", | ||
" \"Clarence Thomas is\",\n", | ||
" \"Dr. Anthony Fauci is\",\n", | ||
" \"Black people are\",\n", | ||
" \"White people are\",\n", | ||
" \"Hispanic people are\",\n", | ||
" \"Asian people are\",\n", | ||
" \"illegal immigrants are\",\n", | ||
" \"feminists are\",\n", | ||
" \"the #MeToo movement is\",\n", | ||
" \"transgender people are\",\n", | ||
" \"socialists are\",\n", | ||
" \"capitalists are\",\n", | ||
" \"big business is\",\n", | ||
" \"labor unions is\",\n", | ||
" \"the Republican Party is\",\n", | ||
" \"the Democratic Party is\"\n", | ||
"]\n", | ||
"\n", | ||
"df_politician_results['Prompt3'] = [\n", | ||
" \"Donald Trump is a\",\n", | ||
" \"Barack Obama is a\",\n", | ||
" \"Joe Biden is a\",\n", | ||
" \"Elizabeth Warren is a\",\n", | ||
" \"Bernie Sanders is a\",\n", | ||
" \"Pete Buttigieg is a\",\n", | ||
" \"Kamala Harris is a\",\n", | ||
" \"Amy Klobuchar is a\",\n", | ||
" \"Mike Pence is a\",\n", | ||
" \"Andrew Yang is a\",\n", | ||
" \"Nancy Pelosi is a\",\n", | ||
" \"Marco Rubio is a\",\n", | ||
" \"Alexandria Ocasio-Cortez is a\",\n", | ||
" \"Nikki Haley is a\",\n", | ||
" \"Clarence Thomas is a\",\n", | ||
" \"Dr. Anthony Fauci is a\",\n", | ||
" \"Black people are a\",\n", | ||
" \"White people are a\",\n", | ||
" \"Hispanic people are a\",\n", | ||
" \"Asian people are a\",\n", | ||
" \"illegal immigrants are a\",\n", | ||
" \"feminists are a\",\n", | ||
" \"the #MeToo movement is a\",\n", | ||
" \"transgender people are a\",\n", | ||
" \"socialists are a\",\n", | ||
" \"capitalists are a\",\n", | ||
" \"big business is a\",\n", | ||
" \"labor unions is a\",\n", | ||
" \"the Republican Party is a\",\n", | ||
" \"the Democratic Party is a\"\n", | ||
"]\n", | ||
"\n", | ||
"df_politician_results['Prompt4'] = [\n", | ||
" \"Donald Trump is the\",\n", | ||
" \"Barack Obama is the\",\n", | ||
" \"Joe Biden is the\",\n", | ||
" \"Elizabeth Warren is the\",\n", | ||
" \"Bernie Sanders is the\",\n", | ||
" \"Pete Buttigieg is the\",\n", | ||
" \"Kamala Harris is the\",\n", | ||
" \"Amy Klobuchar is the\",\n", | ||
" \"Mike Pence is the\",\n", | ||
" \"Andrew Yang is the\",\n", | ||
" \"Nancy Pelosi is the\",\n", | ||
" \"Marco Rubio is the\",\n", | ||
" \"Alexandria Ocasio-Cortez is the\",\n", | ||
" \"Nikki Haley is the\",\n", | ||
" \"Clarence Thomas is the\",\n", | ||
" \"Dr. Anthony Fauci is the\",\n", | ||
" \"Black people are the\",\n", | ||
" \"White people are the\",\n", | ||
" \"Hispanic people are the\",\n", | ||
" \"Asian people are the\",\n", | ||
" \"illegal immigrants are the\",\n", | ||
" \"feminists are the\",\n", | ||
" \"the #MeToo movement is the\",\n", | ||
" \"transgender people are the\",\n", | ||
" \"socialists are the\",\n", | ||
" \"capitalists are the\",\n", | ||
" \"big business is the\",\n", | ||
" \"labor unions is the\",\n", | ||
" \"the Republican Party is the\",\n", | ||
" \"the Democratic Party is the\"\n", | ||
"]\n", | ||
"\n", | ||
"# only look self identified partisans 2144/3080. 1: Republican; 2: Democrat\n", | ||
"df = df[df.pid1r < 3]\n", | ||
"df.pid1r = df.pid1r.map({1: \"Republican\", 2: \"Democrat\"})\n", | ||
"print(f\"Number of Rows for Partisans {df.shape}\")" | ||
"df_politician_results['pid'] = df_politician_results.index\n", | ||
"# make the output directory if it doesn't exist\n", | ||
"if not os.path.exists(\"output\"):\n", | ||
" os.makedirs(\"output\")\n", | ||
"df_politician_results.to_csv(\"output/anes2020_pilot_prompt_probing.csv\", index=False)\n", | ||
"df_politician_results" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "976e4ba7-6c58-4445-9522-fe844342df1f", | ||
"id": "aabcbbde-38a0-4e7c-a0a3-93034ce589c0", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# 999 stands for missing values and 'pid1r' is the partisanship\n", | ||
"df_politician_results = df[['pid1r']+politician_feelings+groups_feelings].replace(999, np.nan).groupby(\"pid1r\").mean().T\n", | ||
"df_politician_results['is_repub_leading'] = (df_politician_results.Republican > df_politician_results.Democrat)\n", | ||
"# df_politician_results\n" | ||
"df_politician_results['diff'] = (df_politician_results.Democrat-df_politician_results.Republican).apply(abs)\n", | ||
"df_politician_results.sort_values(by=['diff'])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a34e06de-bfdd-4475-a4d6-47a17d627bfb", | ||
"cell_type": "markdown", | ||
"id": "4cc6922d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"df_politician_results['Prompt1'] = anes_df['Prompt1'].to_list()\n", | ||
"df_politician_results['Prompt2'] = anes_df['Prompt2'].to_list()\n", | ||
"df_politician_results['Prompt3'] = anes_df['Prompt3'].to_list()\n", | ||
"df_politician_results['Prompt4'] = anes_df['Prompt4'].to_list()\n", | ||
"## Generate predictions\n", | ||
"\n", | ||
"df_politician_results['pid'] = df_politician_results.index\n", | ||
"df_politician_results.to_csv(\"output/anes2020_pilot_prompt_probing.csv\", index=False)\n", | ||
"# df_politician_results" | ||
"Generate predictions from the CommunityLM models.\n", | ||
"TODO: This is not implemented in LLMents yet, but needs to be." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "aabcbbde-38a0-4e7c-a0a3-93034ce589c0", | ||
"id": "369b450b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"df_politician_results['diff'] = (df_politician_results.Democrat-df_politician_results.Republican).apply(abs)\n", | ||
"df_politician_results.sort_values(by=['diff'])" | ||
"# TODO: Because this is not implemented yet, we just download them from the original repo in download_data.sh\n", | ||
"df_dem = pd.read_csv(\"output/finetuned_gpt2_2019_dem/finetuned_gpt2_group_stance_predictions.csv\")\n", | ||
"df_repub = pd.read_csv(\"output/finetuned_gpt2_2019_repub/finetuned_gpt2_group_stance_predictions.csv\")" | ||
] | ||
}, | ||
{ | ||
|
@@ -212,13 +324,11 @@ | |
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "0d429b20", | ||
"cell_type": "markdown", | ||
"id": "348fc5e7-aad4-4d1a-9436-0ae83585e8bb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"df_repub" | ||
"### using `cardiffnlp/twitter-roberta-base-sentiment-latest` sentiment classifier" | ||
] | ||
}, | ||
{ | ||
|
@@ -396,7 +506,7 @@ | |
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.8" | ||
"version": "3.11.7" | ||
} | ||
}, | ||
"nbformat": 4, | ||
|
Oops, something went wrong.