Add support for API-based LLMs in biasmonkey notebook (#71)
zaidsheikh authored Sep 27, 2024
1 parent 9dae41c commit eacaa4c
Showing 3 changed files with 75 additions and 40 deletions.
35 changes: 28 additions & 7 deletions examples/bias_monkey/bias_monkey.ipynb
@@ -18,13 +18,15 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from glob import glob\n",
"\n",
"from llments.lm.base.hugging_face import HuggingFaceLM\n",
"from llments.lm.base.api import APIBasedLM\n",
"import torch, gc\n",
"from pathlib import Path\n",
"from bias_monkey_utils import (\n",
@@ -61,23 +63,42 @@
" \"llama2-13b-chat\": \"meta-llama/Llama-2-13b-chat-hf\",\n",
" \"llama2-70b\": \"meta-llama/Llama-2-70b-hf\",\n",
" \"llama2-70b-chat\": \"meta-llama/Llama-2-70b-chat-hf\",\n",
" \"gpt-4o-mini-2024-07-18\": \"openai/neulab/gpt-4o-mini-2024-07-18\",\n",
"}\n",
"\n",
"\n",
"def is_chat_model(model):\n",
" return \"chat\" in model or \"gpt\" in model\n",
"\n",
"\n",
"for model in model_paths:\n",
" print(f\"Loading {model}\")\n",
" lm = HuggingFaceLM(model_paths[model], device=device)\n",
" if model_paths[model].startswith(\"openai/neulab/\"):\n",
" assert os.environ.get(\"OPENAI_API_KEY\"), \"Please set OPENAI_API_KEY\"\n",
" assert os.environ.get(\"LITELLM_API_BASE\"), \"Please set LITELLM_API_BASE\"\n",
" lm = APIBasedLM(\n",
" model_name=model_paths[model], api_base=os.environ[\"LITELLM_API_BASE\"]\n",
" )\n",
" else:\n",
" lm = HuggingFaceLM(model_paths[model], device=device)\n",
" for csv_file in sorted(glob(\"BiasMonkey/prompts/*.csv\")):\n",
" print(f\"Processing {csv_file}\")\n",
" filename = os.path.basename(csv_file.removesuffix(\".csv\"))\n",
" bias_type, perturbation = filename.split(\"-\")\n",
" if \"-\" not in filename:\n",
" bias_type, perturbation = filename, None\n",
" else:\n",
" bias_type, perturbation = filename.split(\"-\", 1)\n",
" if bias_type not in bias_types:\n",
" continue\n",
" df = generate_survey_responses(\n",
" model=lm,\n",
" prompts_file=csv_file,\n",
" bias_type=bias_type,\n",
" perturbation=perturbation,\n",
" output_path=f\"results/{model}/{filename}.pickle\",\n",
" output_csv=f\"results/{model}/csv/{filename}.csv\",\n",
" is_chat_model=\"chat\" in model,\n",
" seed=1,\n",
" bias_type=bias_type,\n",
" perturbation=perturbation,\n",
" is_chat_model=is_chat_model(model),\n",
" seed=None if is_chat_model(model) else 1,\n",
" num_samples=50,\n",
" batch_size=25,\n",
" overwrite=True,\n",
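Note on the cell above: model loading now dispatches on the model-path prefix, so API-hosted and local Hugging Face models share one loop. A minimal sketch of that pattern (the `load_lm` helper is hypothetical; the prefix check, environment-variable assertions, and constructor calls mirror the diff above):

```python
import os

from llments.lm.base.api import APIBasedLM
from llments.lm.base.hugging_face import HuggingFaceLM


def load_lm(model_path: str, device: str = "cuda"):
    """Return an API-backed LM for proxy-hosted models, else a local HF model."""
    if model_path.startswith("openai/neulab/"):
        # API-based models need credentials and a LiteLLM proxy endpoint.
        assert os.environ.get("OPENAI_API_KEY"), "Please set OPENAI_API_KEY"
        assert os.environ.get("LITELLM_API_BASE"), "Please set LITELLM_API_BASE"
        return APIBasedLM(
            model_name=model_path, api_base=os.environ["LITELLM_API_BASE"]
        )
    return HuggingFaceLM(model_path, device=device)
```

The filename-parsing change in the same cell splits on the first hyphen only and falls back to `perturbation = None`, so prompt files without a perturbation suffix (e.g. a hypothetical `acquiescence.csv`) no longer raise a ValueError on unpacking.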
6 changes: 3 additions & 3 deletions examples/bias_monkey/bias_monkey_utils.py
@@ -238,10 +238,10 @@ def format_df(
def generate_survey_responses(
model: LanguageModel,
prompts_file: str,
bias_type: str,
perturbation: str,
output_path: str,
output_csv: str,
bias_type: str,
perturbation: str | None = None,
is_chat_model: bool = True,
seed: int | None = None,
num_samples: int = 50,
@@ -256,7 +256,7 @@ def generate_survey_responses(
model: The language model.
prompts_file: The csv file containing the prompts.
bias_type: one of ["acquiescence", "allow_forbid", "odd_even", "response_order", "opinion_float"]
perturbation: one of ["key_typo", "middle_random", "letter_swap"]
perturbation: one of ["key_typo", "middle_random", "letter_swap"] or None
output_path: output path (pickle file).
output_csv: output csv file.
is_chat_model: Whether the model is a chat model.
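With the reordered signature, the output paths now precede `bias_type`, and `perturbation` is optional, defaulting to `None` for unperturbed prompt files. A hypothetical call matching the new parameter order (file names and paths are illustrative):

```python
df = generate_survey_responses(
    model=lm,
    prompts_file="BiasMonkey/prompts/acquiescence.csv",
    output_path="results/llama2-7b-chat/acquiescence.pickle",
    output_csv="results/llama2-7b-chat/csv/acquiescence.csv",
    bias_type="acquiescence",
    perturbation=None,  # no perturbation suffix in the prompts file name
    is_chat_model=True,
    seed=None,  # the notebook passes a fixed seed only for non-chat models
    num_samples=50,
)
```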
74 changes: 44 additions & 30 deletions llments/lm/base/api.py
@@ -1,10 +1,13 @@
"""Base class for API-Based Language Models."""

import os
import abc
import os
import warnings

from litellm import ModelResponse, batch_completion, completion

from llments.lm.lm import LanguageModel
from litellm import completion, batch_completion, ModelResponse


class APIBasedLM(LanguageModel):
"""Base class for API-Based Language Models.
@@ -54,8 +57,8 @@ def generate(
max_length: int | None = None,
max_new_tokens: int | None = None,
temperature: float = 1.0,
num_return_sequences: int = 1
) -> list[str]:
num_return_sequences: int = 1,
) -> list[str]:
"""Generate a response based on the given prompt.
This method sends a prompt to the language model API and retrieves
@@ -68,30 +71,36 @@
max_length (int): The maximum length of the output sequence,
(defaults to model max).
max_new_tokens (int): The maximum number of tokens to generate in the chat completion.
temperature (float): The sampling temperature to be used, between 0 and 2.
temperature (float): The sampling temperature to be used, between 0 and 2.
num_return_sequences (int): The number of chat completion choices to generate for each input message.
Returns:
list[str]: Sampled output sequences from the language model.
"""
if condition is not None:
warnings.warn("A non-default value for 'condition' was provided.", UserWarning)
warnings.warn(
"A non-default value for 'condition' was provided.", UserWarning
)
if do_sample:
warnings.warn("A non-default value for 'do_sample' was provided.", UserWarning)
warnings.warn(
"A non-default value for 'do_sample' was provided.", UserWarning
)
if max_length is not None:
warnings.warn("A non-default value for 'max_length' was provided.", UserWarning)

warnings.warn(
"A non-default value for 'max_length' was provided.", UserWarning
)

responses = []
response = completion(
model = self.model_name,
temperature = temperature,
max_tokens = max_new_tokens,
n = num_return_sequences,
api_base = self.api_base,
messages=[{"content": condition, "role": "user"}]
model=self.model_name,
temperature=temperature,
max_tokens=max_new_tokens,
n=num_return_sequences,
api_base=self.api_base,
messages=[{"content": condition, "role": "user"}],
)
for choice in response['choices']:
responses.append(choice['message']['content'])
for choice in response["choices"]:
responses.append(choice["message"]["content"])
return responses

@abc.abstractmethod
@@ -102,8 +111,8 @@ def chat_generate(
max_length: int | None = None,
max_new_tokens: int | None = None,
temperature: float = 1.0,
num_return_sequences: int = 1
) -> list[list[dict[str, str]]]:
num_return_sequences: int = 1,
) -> list[list[dict[str, str]]]:
"""Generate responses to multiple prompts using the batch_completion function.
This method sends multiple prompts to the language model API and retrieves
@@ -127,27 +136,32 @@
max_length (int): The maximum length of the output sequence,
(defaults to model max).
max_new_tokens (int): The maximum number of tokens to generate in the chat completion.
temperature (float): The sampling temperature to be used, between 0 and 2.
temperature (float): The sampling temperature to be used, between 0 and 2.
num_return_sequences (int): The number of chat completion choices to generate for each input message.
Returns:
list[list[dict[str, str]]]: list of chat contexts with the generated responses.
"""
if do_sample:
warnings.warn("A non-default value for 'do_sample' was provided.", UserWarning)
warnings.warn(
"A non-default value for 'do_sample' was provided.", UserWarning
)
if max_length is not None:
warnings.warn("A non-default value for 'max_length' was provided.", UserWarning)

warnings.warn(
"A non-default value for 'max_length' was provided.", UserWarning
)

responses = batch_completion(
model = self.model_name,
temperature = temperature,
max_tokens = max_new_tokens,
n = num_return_sequences,
api_base = self.api_base,
messages=messages
model=self.model_name,
temperature=temperature,
max_tokens=max_new_tokens,
n=num_return_sequences,
api_base=self.api_base,
messages=[messages],
)
responses = [r["message"]["content"] for r in responses[0]["choices"]]
return [messages + [{"role": "assistant", "content": r}] for r in responses]

@abc.abstractmethod
def set_seed(self, seed: int) -> None:
"""Set the seed for the language model.
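For context, a hedged sketch of how the two reworked methods would be driven, based on the calls in the notebook diff (the model name and endpoint are assumptions; the constructor call mirrors the notebook):

```python
import os

from llments.lm.base.api import APIBasedLM

lm = APIBasedLM(
    model_name="openai/neulab/gpt-4o-mini-2024-07-18",
    api_base=os.environ["LITELLM_API_BASE"],
)

# generate() wraps litellm.completion and returns one string per choice.
# Note: the current implementation emits a UserWarning when the prompt is
# passed via `condition`, but still uses it as the user message.
completions = lm.generate(
    condition="Answer with a single letter: A or B?",
    max_new_tokens=16,
    temperature=1.0,
    num_return_sequences=2,
)

# chat_generate() wraps litellm.batch_completion and returns the input
# context extended with one assistant turn per choice.
chats = lm.chat_generate(
    messages=[{"role": "user", "content": "Answer with a single letter: A or B?"}],
    max_new_tokens=16,
    num_return_sequences=2,
)
```

Note that `chat_generate` takes a single conversation and wraps it as `[messages]` before calling `batch_completion`, so only `responses[0]` is consumed; batching several conversations in one call would require a different call shape.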
