From eacaa4c418eea1ba72c252dca6b0ab51669d62c3 Mon Sep 17 00:00:00 2001
From: Zaid Sheikh
Date: Fri, 27 Sep 2024 11:21:16 -0400
Subject: [PATCH] Add support for API-based LLMs in biasmonkey notebook (#71)

---
 examples/bias_monkey/bias_monkey.ipynb    | 35 ++++++++---
 examples/bias_monkey/bias_monkey_utils.py |  6 +-
 llments/lm/base/api.py                    | 74 ++++++++++++++---------
 3 files changed, 75 insertions(+), 40 deletions(-)

diff --git a/examples/bias_monkey/bias_monkey.ipynb b/examples/bias_monkey/bias_monkey.ipynb
index 5e5fd9a..15f544c 100644
--- a/examples/bias_monkey/bias_monkey.ipynb
+++ b/examples/bias_monkey/bias_monkey.ipynb
@@ -18,13 +18,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "import os\n",
     "from glob import glob\n",
+    "\n",
     "from llments.lm.base.hugging_face import HuggingFaceLM\n",
+    "from llments.lm.base.api import APIBasedLM\n",
     "import torch, gc\n",
     "from pathlib import Path\n",
     "from bias_monkey_utils import (\n",
@@ -61,23 +63,42 @@
     "    \"llama2-13b-chat\": \"meta-llama/Llama-2-13b-chat-hf\",\n",
     "    \"llama2-70b\": \"meta-llama/Llama-2-70b-hf\",\n",
     "    \"llama2-70b-chat\": \"meta-llama/Llama-2-70b-chat-hf\",\n",
+    "    \"gpt-4o-mini-2024-07-18\": \"openai/neulab/gpt-4o-mini-2024-07-18\",\n",
     "}\n",
+    "\n",
+    "\n",
+    "def is_chat_model(model):\n",
+    "    return \"chat\" in model or \"gpt\" in model\n",
+    "\n",
+    "\n",
     "for model in model_paths:\n",
     "    print(f\"Loading {model}\")\n",
-    "    lm = HuggingFaceLM(model_paths[model], device=device)\n",
+    "    if model_paths[model].startswith(\"openai/neulab/\"):\n",
+    "        assert os.environ.get(\"OPENAI_API_KEY\"), \"Please set OPENAI_API_KEY\"\n",
+    "        assert os.environ.get(\"LITELLM_API_BASE\"), \"Please set LITELLM_API_BASE\"\n",
+    "        lm = APIBasedLM(\n",
+    "            model_name=model_paths[model], api_base=os.environ[\"LITELLM_API_BASE\"]\n",
+    "        )\n",
+    "    else:\n",
+    "        lm = HuggingFaceLM(model_paths[model], device=device)\n",
     "    for csv_file in sorted(glob(\"BiasMonkey/prompts/*.csv\")):\n",
     "        print(f\"Processing {csv_file}\")\n",
     "        filename = os.path.basename(csv_file.removesuffix(\".csv\"))\n",
-    "        bias_type, perturbation = filename.split(\"-\")\n",
+    "        if \"-\" not in filename:\n",
+    "            bias_type, perturbation = filename, None\n",
+    "        else:\n",
+    "            bias_type, perturbation = filename.split(\"-\", 1)\n",
+    "        if bias_type not in bias_types:\n",
+    "            continue\n",
     "        df = generate_survey_responses(\n",
     "            model=lm,\n",
     "            prompts_file=csv_file,\n",
-    "            bias_type=bias_type,\n",
-    "            perturbation=perturbation,\n",
     "            output_path=f\"results/{model}/{filename}.pickle\",\n",
     "            output_csv=f\"results/{model}/csv/{filename}.csv\",\n",
-    "            is_chat_model=\"chat\" in model,\n",
-    "            seed=1,\n",
+    "            bias_type=bias_type,\n",
+    "            perturbation=perturbation,\n",
+    "            is_chat_model=is_chat_model(model),\n",
+    "            seed=None if is_chat_model(model) else 1,\n",
     "            num_samples=50,\n",
     "            batch_size=25,\n",
     "            overwrite=True,\n",
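The updated cell dispatches on the model-path prefix: entries under "openai/neulab/" are served through a litellm-compatible API, everything else is loaded locally with HuggingFaceLM. A minimal sketch of the same pattern outside the notebook JSON (the load_lm helper name is hypothetical; the prefix, environment-variable checks, and constructor calls are taken from the cell above):

    import os

    from llments.lm.base.api import APIBasedLM
    from llments.lm.base.hugging_face import HuggingFaceLM


    def load_lm(model_path: str, device: str = "cuda"):
        """Return an API-backed LM for 'openai/neulab/' paths, else a local HF LM."""
        if model_path.startswith("openai/neulab/"):
            # The API key and proxy base URL must be set before the cell runs.
            assert os.environ.get("OPENAI_API_KEY"), "Please set OPENAI_API_KEY"
            assert os.environ.get("LITELLM_API_BASE"), "Please set LITELLM_API_BASE"
            return APIBasedLM(
                model_name=model_path, api_base=os.environ["LITELLM_API_BASE"]
            )
        return HuggingFaceLM(model_path, device=device)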
diff --git a/examples/bias_monkey/bias_monkey_utils.py b/examples/bias_monkey/bias_monkey_utils.py
index d271644..c240641 100644
--- a/examples/bias_monkey/bias_monkey_utils.py
+++ b/examples/bias_monkey/bias_monkey_utils.py
@@ -238,10 +238,10 @@ def format_df(
 def generate_survey_responses(
     model: LanguageModel,
     prompts_file: str,
-    bias_type: str,
-    perturbation: str,
     output_path: str,
     output_csv: str,
+    bias_type: str,
+    perturbation: str | None = None,
     is_chat_model: bool = True,
     seed: int | None = None,
     num_samples: int = 50,
@@ -256,7 +256,7 @@
         model: The language model.
         prompts_file: The csv file containing the prompts.
         bias_type: one of ["acquiescence", "allow_forbid", "odd_even", "response_order", "opinion_float"]
-        perturbation: one of ["key_typo", "middle_random", "letter_swap"]
+        perturbation: one of ["key_typo", "middle_random", "letter_swap"] or None
         output_path: output path (pickle file).
         output_csv: output csv file.
         is_chat_model: Whether the model is a chat model.
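generate_survey_responses now takes bias_type and perturbation after the output paths, and perturbation defaults to None so unperturbed prompt files can be processed. The notebook derives both values from the csv filename; a sketch of that contract as a standalone helper (parse_prompt_filename is a hypothetical name, the splitting logic is verbatim from the cell):

    import os


    def parse_prompt_filename(csv_file: str) -> tuple[str, str | None]:
        # "acquiescence.csv"          -> ("acquiescence", None)
        # "allow_forbid-key_typo.csv" -> ("allow_forbid", "key_typo")
        filename = os.path.basename(csv_file.removesuffix(".csv"))
        if "-" not in filename:
            return filename, None
        bias_type, perturbation = filename.split("-", 1)
        return bias_type, perturbation


    assert parse_prompt_filename("prompts/odd_even.csv") == ("odd_even", None)
    assert parse_prompt_filename("prompts/odd_even-letter_swap.csv") == (
        "odd_even",
        "letter_swap",
    )

Splitting on the first "-" only is the safe choice here: bias types and perturbations use underscores internally, so a single hyphen can only be the separator between the two.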
diff --git a/llments/lm/base/api.py b/llments/lm/base/api.py
index 3d54190..525cf54 100644
--- a/llments/lm/base/api.py
+++ b/llments/lm/base/api.py
@@ -1,10 +1,13 @@
 """Base class for API-Based Language Models."""
-import os
 import abc
+import os
 import warnings
+
+from litellm import ModelResponse, batch_completion, completion
+
 from llments.lm.lm import LanguageModel
-from litellm import completion, batch_completion, ModelResponse
+
 
 class APIBasedLM(LanguageModel):
     """Base class for API-Based Language Models.
@@ -54,8 +57,8 @@ def generate(
         max_length: int | None = None,
         max_new_tokens: int | None = None,
         temperature: float = 1.0,
-        num_return_sequences: int = 1
-    ) -> list[str]:
+        num_return_sequences: int = 1,
+    ) -> list[str]:
         """Generate a response based on the given prompt.
 
         This method sends a prompt to the language model API and retrieves
@@ -68,30 +71,36 @@
             max_length (int): The maximum length of the output sequence,
                 (defaults to model max).
             max_new_tokens (int): The maximum number of tokens to generate
                 in the chat completion.
-            temperature (float): The sampling temperature to be used, between 0 and 2.
+            temperature (float): The sampling temperature to be used, between 0 and 2.
             num_return_sequences (int): The number of chat completion choices
                 to generate for each input message.
 
         Returns:
             list[str]: Sampled output sequences from the language model.
         """
         if condition is not None:
-            warnings.warn("A non-default value for 'condition' was provided.", UserWarning)
+            warnings.warn(
+                "A non-default value for 'condition' was provided.", UserWarning
+            )
         if do_sample:
-            warnings.warn("A non-default value for 'do_sample' was provided.", UserWarning)
+            warnings.warn(
+                "A non-default value for 'do_sample' was provided.", UserWarning
+            )
         if max_length is not None:
-            warnings.warn("A non-default value for 'max_length' was provided.", UserWarning)
-
+            warnings.warn(
+                "A non-default value for 'max_length' was provided.", UserWarning
+            )
+
         responses = []
         response = completion(
-            model = self.model_name,
-            temperature = temperature,
-            max_tokens = max_new_tokens,
-            n = num_return_sequences,
-            api_base = self.api_base,
-            messages=[{"content": condition, "role": "user"}]
+            model=self.model_name,
+            temperature=temperature,
+            max_tokens=max_new_tokens,
+            n=num_return_sequences,
+            api_base=self.api_base,
+            messages=[{"content": condition, "role": "user"}],
         )
-        for choice in response['choices']:
-            responses.append(choice['message']['content'])
+        for choice in response["choices"]:
+            responses.append(choice["message"]["content"])
         return responses
 
     @abc.abstractmethod
@@ -102,8 +111,8 @@ def chat_generate(
         self,
         messages: list[dict[str, str]],
         do_sample: bool = False,
         max_length: int | None = None,
         max_new_tokens: int | None = None,
         temperature: float = 1.0,
-        num_return_sequences: int = 1
-    ) -> list[list[dict[str, str]]]:
+        num_return_sequences: int = 1,
+    ) -> list[list[dict[str, str]]]:
         """Generate responses to multiple prompts using the batch_completion function.
 
         This method sends multiple prompts to the language model API and retrieves
@@ -127,27 +136,32 @@
             max_length (int): The maximum length of the output sequence,
                 (defaults to model max).
             max_new_tokens (int): The maximum number of tokens to generate
                 in the chat completion.
-            temperature (float): The sampling temperature to be used, between 0 and 2.
+            temperature (float): The sampling temperature to be used, between 0 and 2.
             num_return_sequences (int): The number of chat completion choices
                 to generate for each input message.
 
         Returns:
             list[list[dict[str, str]]]: list of chat contexts with the generated responses.
         """
         if do_sample:
-            warnings.warn("A non-default value for 'do_sample' was provided.", UserWarning)
+            warnings.warn(
+                "A non-default value for 'do_sample' was provided.", UserWarning
+            )
         if max_length is not None:
-            warnings.warn("A non-default value for 'max_length' was provided.", UserWarning)
-
+            warnings.warn(
+                "A non-default value for 'max_length' was provided.", UserWarning
+            )
+
         responses = batch_completion(
-            model = self.model_name,
-            temperature = temperature,
-            max_tokens = max_new_tokens,
-            n = num_return_sequences,
-            api_base = self.api_base,
-            messages=messages
+            model=self.model_name,
+            temperature=temperature,
+            max_tokens=max_new_tokens,
+            n=num_return_sequences,
+            api_base=self.api_base,
+            messages=[messages],
         )
+        responses = [r["message"]["content"] for r in responses[0]["choices"]]
         return [messages + [{"role": "assistant", "content": r}] for r in responses]
-
+
     @abc.abstractmethod
     def set_seed(self, seed: int) -> None:
         """Set the seed for the language model.
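The reworked chat_generate wraps the single chat context as [messages] so that litellm's batch_completion receives a batch of one, then unpacks the n sampled choices from the first (and only) response and appends each as an assistant turn. A usage sketch under the same assumptions the notebook makes (a litellm proxy reachable at LITELLM_API_BASE, with the model name registered there):

    import os

    from llments.lm.base.api import APIBasedLM

    lm = APIBasedLM(
        model_name="openai/neulab/gpt-4o-mini-2024-07-18",
        api_base=os.environ["LITELLM_API_BASE"],
    )
    messages = [{"role": "user", "content": "Do you favor or oppose this policy?"}]
    contexts = lm.chat_generate(messages, max_new_tokens=64, num_return_sequences=2)
    # Each element extends the input context with one sampled assistant turn:
    # [[{'role': 'user', ...}, {'role': 'assistant', 'content': '...'}], ...]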