diff --git a/examples/community_lm/community_lm.ipynb b/examples/community_lm/community_lm.ipynb index 181ec54..be649da 100644 --- a/examples/community_lm/community_lm.ipynb +++ b/examples/community_lm/community_lm.ipynb @@ -9,8 +9,6 @@ "\n", "This is a replication of the experiments from [CommunityLM](https://arxiv.org/abs/2209.07065) (Jiang et al. 2022), which probes partisan worldviews from language models, based on the [original repo](https://github.com/hjian42/communitylm).\n", "\n", - "Running all the experiments on a single GPU takes about 3-4 hours.\n", - "\n", "Before running the notebook, please install requirements and download the data.\n", "```bash\n", "pip install -r requirements.txt\n", @@ -25,151 +23,265 @@ "metadata": {}, "outputs": [], "source": [ - "from llments.lm.base.hugging_face import HuggingFaceLM\n", - "from llments.eval.sentiment import HuggingFaceSentimentEvaluator\n", "import pandas as pd\n", "import numpy as np\n", - "from community_lm_constants import politician_feelings, groups_feelings, anes_df\n", - "from community_lm_utils import generate_community_opinion, compute_group_stance\n", - "\n", - "device = 'cuda:0' # change to 'mps' if you have a mac, or 'cuda:0' if you have an NVIDIA GPU " + "import os" ] }, { "cell_type": "markdown", - "id": "d0022efe", - "metadata": {}, + "id": "6017a1d8-ae02-4adb-b3af-3d19911a61a2", + "metadata": { + "tags": [] + }, "source": [ - "## Generate Opinions using CommunityLM\n", + "## Preparing ANES2020 Questions\n", + "\n", + "This is data from the American National Election Study (ANES)\n", "\n", - "The following code generates opinions using CommunityLM." + "Website: https://electionstudies.org/\n", + "Email: anes@electionstudies.org\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "bacd15ad", + "id": "04e5cf0c-3f2c-4cae-806a-3798f8138664", "metadata": {}, "outputs": [], "source": [ - "for run in range(1, 6):\n", - " for party in ['democrat', 'republican']:\n", - " lm = HuggingFaceLM(f'CommunityLM/{party}-twitter-gpt2', device=device)\n", - " for prompt_option in ['Prompt1', 'Prompt2', 'Prompt3', 'Prompt4']:\n", - " print(f'generating {party} opinion for {prompt_option} run {run}...')\n", - " output_path = f'output/CommunityLM_{party}-twitter-gpt2/run_{run}'\n", - " generate_community_opinion(lm, prompt_option, output_path, run)" + "df = pd.read_csv(\"data/anes_pilot_2020ets_csv.csv\")\n", + "\n", + "print(\"Number of Rows\", df.shape)\n", + "\n", + "politician_feelings = ['fttrump1', 'ftobama1', 'ftbiden1', 'ftwarren1', 'ftsanders1', 'ftbuttigieg1', 'ftharris1', 'ftklobuchar1',\n", + " 'ftpence1', 'ftyang1', 'ftpelosi1', 'ftrubio1', 'ftocasioc1', 'fthaley1', 'ftthomas1', 'ftfauci1']\n", + "\n", + "groups_feelings = ['ftblack', 'ftwhite', 'fthisp', 'ftasian', 'ftillegal', 'ftfeminists', 'ftmetoo', 'fttransppl',\n", + " 'ftsocialists', 'ftcapitalists', 'ftbigbusiness', 'ftlaborunions', 'ftrepublicanparty', 'ftdemocraticparty'\n", + " ]\n", + "\n", + "partisanship = ['pid1r']\n", + "\n", + "# only look self identified partisans 2144/3080. 1: Repub; 2: Democrat\n", + "df = df[df.pid1r < 3]\n", + "df.pid1r = df.pid1r.map({1: \"Republican\", 2: \"Democrat\"})\n", + "df.shape" ] }, { - "cell_type": "markdown", - "id": "348fc5e7-aad4-4d1a-9436-0ae83585e8bb", + "cell_type": "code", + "execution_count": null, + "id": "7ad68fb3-e3aa-43d1-bf37-beeb92509db3", "metadata": {}, + "outputs": [], "source": [ - "## Perform Group-level Sentiment Analysis" + "df[groups_feelings]" ] }, { "cell_type": "code", "execution_count": null, - "id": "d2049390", + "id": "976e4ba7-6c58-4445-9522-fe844342df1f", "metadata": {}, "outputs": [], "source": [ - "evaluator = HuggingFaceSentimentEvaluator(\n", - " \"cardiffnlp/twitter-roberta-base-sentiment-latest\",\n", - " device=device\n", - ")\n", - "for party in ['democrat', 'republican']:\n", - " compute_group_stance(\n", - " evaluator=evaluator,\n", - " data_folder=f'output/CommunityLM_{party}-twitter-gpt2',\n", - " output_filename=f'output/CommunityLM_{party}-twitter-gpt2/stance_prediction.csv',\n", - " )" + "# 999 stands for missing values\n", + "df_politician_results = df[partisanship+politician_feelings+groups_feelings].replace(999, np.nan).groupby(\"pid1r\").mean().T\n", + "df_politician_results['is_repub_leading'] = (df_politician_results.Republican > df_politician_results.Democrat)\n", + "print(df_politician_results.shape)\n", + "df_politician_results\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "eec53be2", + "id": "af34d7ae-5766-4557-a32d-64d6b7db014b", "metadata": {}, "outputs": [], "source": [ - "df_dem = pd.read_csv(\"output/CommunityLM_democrat-twitter-gpt2/stance_prediction.csv\")\n", - "df_repub = pd.read_csv(\"output/CommunityLM_republican-twitter-gpt2/stance_prediction.csv\")" - ] - }, - { - "cell_type": "markdown", - "id": "6017a1d8-ae02-4adb-b3af-3d19911a61a2", - "metadata": { - "tags": [] - }, - "source": [ - "## Preparing ANES2020 Questions\n", - "\n", - "This is data from the American National Election Study (ANES)\n", - "\n", - "Website: https://electionstudies.org/\n", - "Email: anes@electionstudies.org\n" + "df_politician_results.shape" ] }, { "cell_type": "code", "execution_count": null, - "id": "04e5cf0c-3f2c-4cae-806a-3798f8138664", + "id": "a34e06de-bfdd-4475-a4d6-47a17d627bfb", "metadata": {}, "outputs": [], "source": [ - "df = pd.read_csv(\"data/anes_pilot_2020ets_csv.csv\")\n", - "\n", - "print(f\"Number of Rows Total {df.shape}\")\n", + "df_politician_results['Prompt1'] = [\n", + " \"Donald Trump\",\n", + " \"Barack Obama\",\n", + " \"Joe Biden\",\n", + " \"Elizabeth Warren\",\n", + " \"Bernie Sanders\",\n", + " \"Pete Buttigieg\",\n", + " \"Kamala Harris\",\n", + " \"Amy Klobuchar\",\n", + " \"Mike Pence\",\n", + " \"Andrew Yang\",\n", + " \"Nancy Pelosi\",\n", + " \"Marco Rubio\",\n", + " \"Alexandria Ocasio-Cortez\",\n", + " \"Nikki Haley\",\n", + " \"Clarence Thomas\",\n", + " \"Dr. Anthony Fauci\",\n", + " \"Black people\",\n", + " \"White people\",\n", + " \"Hispanic people\",\n", + " \"Asian people\",\n", + " \"illegal immigrants\",\n", + " \"feminists\",\n", + " \"the #MeToo movement\",\n", + " \"transgender people\",\n", + " \"socialists\",\n", + " \"capitalists\",\n", + " \"big business\",\n", + " \"labor unions\",\n", + " \"the Republican Party\",\n", + " \"the Democratic Party\"\n", + "]\n", + "\n", + "df_politician_results['Prompt2'] = [\n", + " \"Donald Trump is\",\n", + " \"Barack Obama is\",\n", + " \"Joe Biden is\",\n", + " \"Elizabeth Warren is\",\n", + " \"Bernie Sanders is\",\n", + " \"Pete Buttigieg is\",\n", + " \"Kamala Harris is\",\n", + " \"Amy Klobuchar is\",\n", + " \"Mike Pence is\",\n", + " \"Andrew Yang is\",\n", + " \"Nancy Pelosi is\",\n", + " \"Marco Rubio is\",\n", + " \"Alexandria Ocasio-Cortez is\",\n", + " \"Nikki Haley is\",\n", + " \"Clarence Thomas is\",\n", + " \"Dr. Anthony Fauci is\",\n", + " \"Black people are\",\n", + " \"White people are\",\n", + " \"Hispanic people are\",\n", + " \"Asian people are\",\n", + " \"illegal immigrants are\",\n", + " \"feminists are\",\n", + " \"the #MeToo movement is\",\n", + " \"transgender people are\",\n", + " \"socialists are\",\n", + " \"capitalists are\",\n", + " \"big business is\",\n", + " \"labor unions is\",\n", + " \"the Republican Party is\",\n", + " \"the Democratic Party is\"\n", + "]\n", + "\n", + "df_politician_results['Prompt3'] = [\n", + " \"Donald Trump is a\",\n", + " \"Barack Obama is a\",\n", + " \"Joe Biden is a\",\n", + " \"Elizabeth Warren is a\",\n", + " \"Bernie Sanders is a\",\n", + " \"Pete Buttigieg is a\",\n", + " \"Kamala Harris is a\",\n", + " \"Amy Klobuchar is a\",\n", + " \"Mike Pence is a\",\n", + " \"Andrew Yang is a\",\n", + " \"Nancy Pelosi is a\",\n", + " \"Marco Rubio is a\",\n", + " \"Alexandria Ocasio-Cortez is a\",\n", + " \"Nikki Haley is a\",\n", + " \"Clarence Thomas is a\",\n", + " \"Dr. Anthony Fauci is a\",\n", + " \"Black people are a\",\n", + " \"White people are a\",\n", + " \"Hispanic people are a\",\n", + " \"Asian people are a\",\n", + " \"illegal immigrants are a\",\n", + " \"feminists are a\",\n", + " \"the #MeToo movement is a\",\n", + " \"transgender people are a\",\n", + " \"socialists are a\",\n", + " \"capitalists are a\",\n", + " \"big business is a\",\n", + " \"labor unions is a\",\n", + " \"the Republican Party is a\",\n", + " \"the Democratic Party is a\"\n", + "]\n", + "\n", + "df_politician_results['Prompt4'] = [\n", + " \"Donald Trump is the\",\n", + " \"Barack Obama is the\",\n", + " \"Joe Biden is the\",\n", + " \"Elizabeth Warren is the\",\n", + " \"Bernie Sanders is the\",\n", + " \"Pete Buttigieg is the\",\n", + " \"Kamala Harris is the\",\n", + " \"Amy Klobuchar is the\",\n", + " \"Mike Pence is the\",\n", + " \"Andrew Yang is the\",\n", + " \"Nancy Pelosi is the\",\n", + " \"Marco Rubio is the\",\n", + " \"Alexandria Ocasio-Cortez is the\",\n", + " \"Nikki Haley is the\",\n", + " \"Clarence Thomas is the\",\n", + " \"Dr. Anthony Fauci is the\",\n", + " \"Black people are the\",\n", + " \"White people are the\",\n", + " \"Hispanic people are the\",\n", + " \"Asian people are the\",\n", + " \"illegal immigrants are the\",\n", + " \"feminists are the\",\n", + " \"the #MeToo movement is the\",\n", + " \"transgender people are the\",\n", + " \"socialists are the\",\n", + " \"capitalists are the\",\n", + " \"big business is the\",\n", + " \"labor unions is the\",\n", + " \"the Republican Party is the\",\n", + " \"the Democratic Party is the\"\n", + "]\n", "\n", - "# only look self identified partisans 2144/3080. 1: Republican; 2: Democrat\n", - "df = df[df.pid1r < 3]\n", - "df.pid1r = df.pid1r.map({1: \"Republican\", 2: \"Democrat\"})\n", - "print(f\"Number of Rows for Partisans {df.shape}\")" + "df_politician_results['pid'] = df_politician_results.index\n", + "# make the output directory if it doesn't exist\n", + "if not os.path.exists(\"output\"):\n", + " os.makedirs(\"output\")\n", + "df_politician_results.to_csv(\"output/anes2020_pilot_prompt_probing.csv\", index=False)\n", + "df_politician_results" ] }, { "cell_type": "code", "execution_count": null, - "id": "976e4ba7-6c58-4445-9522-fe844342df1f", + "id": "aabcbbde-38a0-4e7c-a0a3-93034ce589c0", "metadata": {}, "outputs": [], "source": [ - "# 999 stands for missing values and 'pid1r' is the partisanship\n", - "df_politician_results = df[['pid1r']+politician_feelings+groups_feelings].replace(999, np.nan).groupby(\"pid1r\").mean().T\n", - "df_politician_results['is_repub_leading'] = (df_politician_results.Republican > df_politician_results.Democrat)\n", - "# df_politician_results\n" + "df_politician_results['diff'] = (df_politician_results.Democrat-df_politician_results.Republican).apply(abs)\n", + "df_politician_results.sort_values(by=['diff'])" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "a34e06de-bfdd-4475-a4d6-47a17d627bfb", + "cell_type": "markdown", + "id": "4cc6922d", "metadata": {}, - "outputs": [], "source": [ - "df_politician_results['Prompt1'] = anes_df['Prompt1'].to_list()\n", - "df_politician_results['Prompt2'] = anes_df['Prompt2'].to_list()\n", - "df_politician_results['Prompt3'] = anes_df['Prompt3'].to_list()\n", - "df_politician_results['Prompt4'] = anes_df['Prompt4'].to_list()\n", + "## Generate predictions\n", "\n", - "df_politician_results['pid'] = df_politician_results.index\n", - "df_politician_results.to_csv(\"output/anes2020_pilot_prompt_probing.csv\", index=False)\n", - "# df_politician_results" + "Generate predictions from the CommunityLM models.\n", + "TODO: This is not implemented in LLMents yet, but needs to be." ] }, { "cell_type": "code", "execution_count": null, - "id": "aabcbbde-38a0-4e7c-a0a3-93034ce589c0", + "id": "369b450b", "metadata": {}, "outputs": [], "source": [ - "df_politician_results['diff'] = (df_politician_results.Democrat-df_politician_results.Republican).apply(abs)\n", - "df_politician_results.sort_values(by=['diff'])" + "# TODO: Because this is not implemented yet, we just download them from the original repo in download_data.sh\n", + "df_dem = pd.read_csv(\"output/finetuned_gpt2_2019_dem/finetuned_gpt2_group_stance_predictions.csv\")\n", + "df_repub = pd.read_csv(\"output/finetuned_gpt2_2019_repub/finetuned_gpt2_group_stance_predictions.csv\")" ] }, { @@ -212,13 +324,11 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "0d429b20", + "cell_type": "markdown", + "id": "348fc5e7-aad4-4d1a-9436-0ae83585e8bb", "metadata": {}, - "outputs": [], "source": [ - "df_repub" + "### using `cardiffnlp/twitter-roberta-base-sentiment-latest` sentiment classifier" ] }, { @@ -396,7 +506,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.11.7" } }, "nbformat": 4, diff --git a/examples/community_lm/community_lm_constants.py b/examples/community_lm/community_lm_constants.py deleted file mode 100644 index afc9215..0000000 --- a/examples/community_lm/community_lm_constants.py +++ /dev/null @@ -1,264 +0,0 @@ -import pandas as pd - -anes_data: dict[str, list] = { - "Democrat": [ - 17.664377682403433, - 81.29270386266094, - 66.49742268041237, - 59.84347826086957, - 63.769827586206894, - 52.78660714285714, - 52.12088888888889, - 50.04343891402715, - 24.091845493562232, - 49.27913279132791, - 61.758175559380376, - 31.52205220522052, - 50.59660107334526, - 29.86053130929791, - 29.95, - 66.66987740805604, - 76.21963824289406, - 71.25, - 71.26936316695353, - 68.94554883318928, - 56.1698275862069, - 61.96793760831889, - 63.73977371627503, - 63.21761658031088, - 54.00086880973067, - 46.67565217391304, - 43.1421143847487, - 60.67247386759582, - 25.01643598615917, - 71.23768366464995, - ], - "Republican": [ - 77.8331627430911, - 29.99282051282051, - 24.401234567901238, - 20.4576802507837, - 20.50359712230216, - 21.66344086021505, - 18.63300760043431, - 22.171366594360087, - 71.12152420185376, - 29.19117647058824, - 16.098663926002054, - 43.00854700854701, - 16.48975188781014, - 47.06971428571429, - 48.63186813186813, - 58.27589852008457, - 66.51030927835052, - 77.15653964984553, - 63.5958762886598, - 63.4437564499484, - 26.246646026831787, - 33.92037228541882, - 32.73085106382979, - 35.06198347107438, - 24.106029106029105, - 60.53305351521511, - 57.84782608695652, - 44.870466321243526, - 74.46564102564102, - 24.34426229508197, - ], - "is_repub_leading": [ - True, - False, - False, - False, - False, - False, - False, - False, - True, - False, - False, - True, - False, - True, - True, - False, - False, - True, - False, - False, - False, - False, - False, - False, - False, - True, - True, - False, - True, - False, - ], - "Prompt1": [ - "Donald Trump", - "Barack Obama", - "Joe Biden", - "Elizabeth Warren", - "Bernie Sanders", - "Pete Buttigieg", - "Kamala Harris", - "Amy Klobuchar", - "Mike Pence", - "Andrew Yang", - "Nancy Pelosi", - "Marco Rubio", - "Alexandria Ocasio-Cortez", - "Nikki Haley", - "Clarence Thomas", - "Dr. Anthony Fauci", - "Black people", - "White people", - "Hispanic people", - "Asian people", - "illegal immigrants", - "feminists", - "the #MeToo movement", - "transgender people", - "socialists", - "capitalists", - "big business", - "labor unions", - "the Republican Party", - "the Democratic Party", - ], - "Prompt2": [ - "Donald Trump is", - "Barack Obama is", - "Joe Biden is", - "Elizabeth Warren is", - "Bernie Sanders is", - "Pete Buttigieg is", - "Kamala Harris is", - "Amy Klobuchar is", - "Mike Pence is", - "Andrew Yang is", - "Nancy Pelosi is", - "Marco Rubio is", - "Alexandria Ocasio-Cortez is", - "Nikki Haley is", - "Clarence Thomas is", - "Dr. Anthony Fauci is", - "Black people are", - "White people are", - "Hispanic people are", - "Asian people are", - "illegal immigrants are", - "feminists are", - "the #MeToo movement is", - "transgender people are", - "socialists are", - "capitalists are", - "big business is", - "labor unions is", - "the Republican Party is", - "the Democratic Party is", - ], - "Prompt3": [ - "Donald Trump is a", - "Barack Obama is a", - "Joe Biden is a", - "Elizabeth Warren is a", - "Bernie Sanders is a", - "Pete Buttigieg is a", - "Kamala Harris is a", - "Amy Klobuchar is a", - "Mike Pence is a", - "Andrew Yang is a", - "Nancy Pelosi is a", - "Marco Rubio is a", - "Alexandria Ocasio-Cortez is a", - "Nikki Haley is a", - "Clarence Thomas is a", - "Dr. Anthony Fauci is a", - "Black people are a", - "White people are a", - "Hispanic people are a", - "Asian people are a", - "illegal immigrants are a", - "feminists are a", - "the #MeToo movement is a", - "transgender people are a", - "socialists are a", - "capitalists are a", - "big business is a", - "labor unions is a", - "the Republican Party is a", - "the Democratic Party is a", - ], - "Prompt4": [ - "Donald Trump is the", - "Barack Obama is the", - "Joe Biden is the", - "Elizabeth Warren is the", - "Bernie Sanders is the", - "Pete Buttigieg is the", - "Kamala Harris is the", - "Amy Klobuchar is the", - "Mike Pence is the", - "Andrew Yang is the", - "Nancy Pelosi is the", - "Marco Rubio is the", - "Alexandria Ocasio-Cortez is the", - "Nikki Haley is the", - "Clarence Thomas is the", - "Dr. Anthony Fauci is the", - "Black people are the", - "White people are the", - "Hispanic people are the", - "Asian people are the", - "illegal immigrants are the", - "feminists are the", - "the #MeToo movement is the", - "transgender people are the", - "socialists are the", - "capitalists are the", - "big business is the", - "labor unions is the", - "the Republican Party is the", - "the Democratic Party is the", - ], - "pid": [ - "fttrump1", - "ftobama1", - "ftbiden1", - "ftwarren1", - "ftsanders1", - "ftbuttigieg1", - "ftharris1", - "ftklobuchar1", - "ftpence1", - "ftyang1", - "ftpelosi1", - "ftrubio1", - "ftocasioc1", - "fthaley1", - "ftthomas1", - "ftfauci1", - "ftblack", - "ftwhite", - "fthisp", - "ftasian", - "ftillegal", - "ftfeminists", - "ftmetoo", - "fttransppl", - "ftsocialists", - "ftcapitalists", - "ftbigbusiness", - "ftlaborunions", - "ftrepublicanparty", - "ftdemocraticparty", - ], -} -anes_df: pd.DataFrame = pd.DataFrame(anes_data) - -politician_feelings: list[str] = anes_data["pid"][:16] -groups_feelings: list[str] = anes_data["pid"][16:] diff --git a/examples/community_lm/community_lm_utils.py b/examples/community_lm/community_lm_utils.py deleted file mode 100644 index 12241fc..0000000 --- a/examples/community_lm/community_lm_utils.py +++ /dev/null @@ -1,95 +0,0 @@ -""" - -This script uses community GPT models to generate opinions given prompts. - -""" - -import os -import tqdm -from pathlib import Path -from community_lm_constants import anes_df -import pandas as pd -import numpy as np - -from llments.lm.lm import LanguageModel -from llments.eval.sentiment import SentimentEvaluator - - -def generate_community_opinion( - model: LanguageModel, - prompt_option: str, - output_path: str, - seed: int, - preceding_prompt: str | None = None, - overwrite: bool = False, -): - model.set_seed(seed) - - questions = anes_df.pid.values.tolist() - prompts = anes_df[prompt_option].values.tolist() - - output_folder = os.path.join(output_path, prompt_option) - Path(output_folder).mkdir(parents=True, exist_ok=True) - - for question, prompt in tqdm.tqdm( - zip(questions, prompts), total=len(questions), desc="Generating opinions" - ): - output_path = os.path.join(output_folder, question + ".txt") - if os.path.exists(output_path) and not overwrite: - continue - total_prompt = ( - " ".join([preceding_prompt, prompt]) if preceding_prompt else prompt - ) - responses = model.generate( - total_prompt, - do_sample=True, - max_length=50, - temperature=1.0, - num_return_sequences=1000, - ) - responses = [x.split("\n")[0] for x in responses] - - with open(output_path, "w") as out: - for line in responses: - line = line.replace("\n", " ") - if preceding_prompt: - line = line.replace(preceding_prompt + " ", "") - out.write(line) - out.write("\n") - - -def compute_group_stance( - evaluator: SentimentEvaluator, - data_folder: str, - output_filename: str, - overwrite: bool = False, -): - if not overwrite and os.path.exists(output_filename): - return - - questions = anes_df.pid.values.tolist() - model_name = data_folder.strip("/").split("/")[-1] - - columns = ["model_name", "run", "prompt_format", "question", "group_sentiment"] - rows = [] - for run_id in range(1, 6): - run_format = f"run_{run_id}" - print(f"Processing {run_format} ...") - for prompt_id in range(1, 5): - prompt_format = f"Prompt{prompt_id}" - for question in tqdm.tqdm(questions, "Processing questions"): - file_name = os.path.join( - data_folder, run_format, prompt_format, question + ".txt" - ) - with open(file_name) as f: - file_lines = f.readlines() - sentiment_vals = evaluator.evaluate_batch( - file_lines, minibatch_size=len(file_lines) - ) - group_sentiment = np.mean(sentiment_vals) * 100 - rows.append( - [model_name, run_format, prompt_format, question, group_sentiment] - ) - - df = pd.DataFrame(rows, columns=columns) - df.to_csv(output_filename) diff --git a/llments/eval/eval.py b/llments/eval/eval.py index 54f9989..829e0ad 100644 --- a/llments/eval/eval.py +++ b/llments/eval/eval.py @@ -1,10 +1,23 @@ import abc import dataclasses -import tqdm + + +class PairwiseEvaluator: + """A class that defines an evaluation function, assessing a hypothesized string.""" + + @abc.abstractmethod + def evaluate(self, hyp: str, ref: str) -> float: + """Returns an evaluation score between 0 and 1 for two strings. + + Args: + hyp: The hypothesized string (e.g. a system output). + ref: The reference string (e.g. a gold-standard output). + """ + ... @dataclasses.dataclass -class EvalContext: +class EvaluatorMetadata: ... @@ -12,44 +25,11 @@ class Evaluator: """A class that defines an evaluation function, assessing a hypothesized string.""" @abc.abstractmethod - def evaluate(self, hyp: str, context: EvalContext | None = None) -> float: - """Returns an evaluation score (usually between 0-1) conditioned on data. + def evaluate(self, hyp: str, ref: EvaluatorMetadata) -> float: + """Returns an evaluation score between 0 and 1 for two strings. Args: hyp: The hypothesized string (e.g. a system output). - context: The reference context to condition on. - - Returns: - The evaluation score, usually between 0 and 1 inclusive. + ref: The reference string (e.g. a gold-standard output). """ ... - - def evaluate_batch( - self, - hyps: list[str], - contexts: list[EvalContext] | None = None, - minibatch_size: int | None = None, - show_progress: bool = False, - ) -> list[float]: - """Evaluate many hypotheses at once. - - Args: - hyps: A list of hypothesized strings (e.g. system outputs). - context: The reference context to condition on. - minibatch_size: The size of the minibatch to use, - None guesses a good size automatically. - show_progress: Whether to show a progress bar. - - Returns: - A list of evaluation scores, usually between 0 and 1 inclusive. - """ - if show_progress: - hyps = tqdm.tqdm(hyps, desc="Evaluating") - if contexts is not None: - if len(hyps) != len(contexts): - raise ValueError( - "The number of contexts must match the number of hypotheses." - ) - return [self.evaluate(hyp, context) for hyp, context in zip(hyps, contexts)] - else: - return [self.evaluate(hyp) for hyp in hyps] diff --git a/llments/eval/sentiment.py b/llments/eval/sentiment.py deleted file mode 100644 index 8a15604..0000000 --- a/llments/eval/sentiment.py +++ /dev/null @@ -1,104 +0,0 @@ -import abc -import warnings - -import tqdm -from llments.eval.eval import Evaluator, EvalContext - - -class SentimentEvaluator(Evaluator): - """An evaluator that evaluates the sentiment of an output.""" - - @abc.abstractmethod - def evaluate(self, hyp: str, context: EvalContext | None = None) -> float: - """Returns a sentiment score (usually between 0-1) conditioned on data. - - Args: - hyp: The hypothesized string (e.g. a system output). - context: Any additional context about the evaluation. - - Returns: - The evaluation score, usually between 0 and 1 inclusive. - """ - ... - - -class HuggingFaceSentimentEvaluator(SentimentEvaluator): - """An evaluator that uses HuggingFace to evaluate the sentiment of an output.""" - - def __init__(self, model: str | None = None, device: str | None = None): - """Initialize a HuggingFaceSentimentEvaluator. - - Args: - model: The name of the model. - device: The device to run the model on. - """ - try: - from transformers import pipeline - except ImportError: - raise ImportError( - "HuggingFaceSentimentEvaluator requires the `transformers` library." - ) - self.sentiment_pipeline = pipeline( - "sentiment-analysis", - model=model, - tokenizer=model, - device=device, - ) - self.sentiment_dict = {"negative": 0, "positive": 1, "neutral": 0.5} - - def evaluate(self, hyp: str, context: EvalContext | None = None) -> float: - """Returns a sentiment score (usually between 0-1) conditioned on data. - - Args: - hyp: The hypothesized string (e.g. a system output). - context: Not used. - - Returns: - The evaluation score, usually between 0 and 1 inclusive. - """ - if context is not None: - warnings.warn( - "HuggingFaceSentimentEvaluator does not use the context argument.", - ) - pred = self.sentiment_pipeline(hyp) - return self.sentiment_dict[pred["label"].lower()] - - def evaluate_batch( - self, - hyps: list[str], - contexts: list[EvalContext] | None = None, - minibatch_size: int | None = None, - show_progress: bool = False, - ) -> list[float]: - """Evaluate many hypotheses at once. - - Args: - hyps: A list of hypothesized strings (e.g. system outputs). - context: Not used. - show_progress: Whether to show a progress bar. - - Returns: - A list of evaluation scores, usually between 0 and 1 inclusive. - """ - if contexts is not None: - warnings.warn( - "HuggingFaceSentimentEvaluator does not use the context argument.", - ) - # TODO: we could have more intelligent guessing here - if minibatch_size is None: - minibatch_size = 128 - minibatch = [] - all_scores = [] - starts = range(0, len(hyps), minibatch_size) - if show_progress: - starts = tqdm.tqdm( - starts, desc=f"Analyzing batches of size {minibatch_size}" - ) - for i in starts: - minibatch = hyps[i : i + minibatch_size] - minibatch_scores = [ - self.sentiment_dict[x["label"]] - for x in self.sentiment_pipeline(minibatch) - ] - all_scores.extend(minibatch_scores) - return all_scores diff --git a/llments/lm/base/hugging_face.py b/llments/lm/base/hugging_face.py index a1e2b6a..d26dd66 100644 --- a/llments/lm/base/hugging_face.py +++ b/llments/lm/base/hugging_face.py @@ -1,4 +1,5 @@ from llments.lm.lm import LanguageModel +from transformers import pipeline, set_seed class HuggingFaceLM(LanguageModel): @@ -13,15 +14,7 @@ def __init__( model: The name of the model. device: The device to run the model on. """ - try: - from transformers import pipeline, TextGenerationPipeline - except ImportError: - raise ImportError( - "You need to install the `transformers` package to use this class." - ) - self.text_generator: TextGenerationPipeline = pipeline( - "text-generation", model=model, device=device - ) + self.text_generator = pipeline("text-generation", model=model, device=device) def fit( self, target: LanguageModel, task_description: str | None = None @@ -68,7 +61,6 @@ def generate( temperature=temperature, num_return_sequences=num_return_sequences, clean_up_tokenization_spaces=True, - truncation=max_length is not None, ) return [res["generated_text"] for res in results] @@ -78,12 +70,6 @@ def set_seed(self, seed: int): Args: seed: The seed to set for the language model. """ - try: - from transformers import set_seed - except ImportError: - raise ImportError( - "You need to install the `transformers` package to use this class." - ) set_seed(seed) diff --git a/pyproject.toml b/pyproject.toml index d1a57d4..73b392d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,6 @@ classifiers = [ ] dependencies = [ "pandas", - "tqdm", ] dynamic = ["version"]