From 070d01f49eac9803b95da711a645993f3bfdb599 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Mon, 28 Oct 2024 23:44:59 -0700 Subject: [PATCH] Port to LiteLLM --- .github/tests/lm_tests.py | 7 +- examples/op_examples/agg.py | 4 +- examples/op_examples/cluster.py | 4 +- examples/op_examples/filter.py | 4 +- examples/op_examples/filter_cascade.py | 6 +- examples/op_examples/join.py | 4 +- examples/op_examples/map.py | 4 +- examples/op_examples/map_fewshot.py | 4 +- examples/op_examples/partition.py | 4 +- examples/op_examples/search.py | 4 +- examples/op_examples/sim_join.py | 4 +- examples/op_examples/top_k.py | 4 +- examples/provider_examples/oai.py | 4 +- examples/provider_examples/ollama.py | 4 +- examples/provider_examples/vllm.py | 4 +- lotus/models/__init__.py | 2 - lotus/models/lm.py | 129 +++++----- lotus/models/openai_model.py | 325 ------------------------- lotus/sem_ops/cascade_utils.py | 39 +-- lotus/sem_ops/sem_agg.py | 12 +- lotus/sem_ops/sem_extract.py | 12 +- lotus/sem_ops/sem_filter.py | 55 +++-- lotus/sem_ops/sem_join.py | 26 +- lotus/sem_ops/sem_map.py | 11 +- lotus/sem_ops/sem_topk.py | 14 +- lotus/types.py | 19 +- 26 files changed, 201 insertions(+), 508 deletions(-) delete mode 100644 lotus/models/openai_model.py diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index af23d1c7..35a0a5c9 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -2,7 +2,7 @@ import pytest import lotus -from lotus.models import OpenAIModel +from lotus.models import LM # Set logger level to DEBUG lotus.logger.setLevel("DEBUG") @@ -11,8 +11,8 @@ @pytest.fixture def setup_models(): # Setup GPT models - gpt_4o_mini = OpenAIModel(model="gpt-4o-mini") - gpt_4o = OpenAIModel(model="gpt-4o") + gpt_4o_mini = LM(model="gpt-4o-mini") + gpt_4o = LM(model="gpt-4o") return gpt_4o_mini, gpt_4o @@ -57,7 +57,6 @@ def test_filter_cascade(setup_models): "Everything is going as planned, couldn't be happier.", "Feeling super motivated and ready to take on challenges!", "I appreciate all the small things that bring me joy.", - # Negative examples "I am very sad.", "Today has been really tough; I feel exhausted.", diff --git a/examples/op_examples/agg.py b/examples/op_examples/agg.py index add1711e..6f6e14b0 100644 --- a/examples/op_examples/agg.py +++ b/examples/op_examples/agg.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import E5Model, OpenAIModel +from lotus.models import LM, E5Model -lm = OpenAIModel() +lm = LM() rm = E5Model() lotus.settings.configure(lm=lm, rm=rm) diff --git a/examples/op_examples/cluster.py b/examples/op_examples/cluster.py index 7bcc307b..2d7af6f1 100644 --- a/examples/op_examples/cluster.py +++ b/examples/op_examples/cluster.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import E5Model, OpenAIModel +from lotus.models import LM, E5Model -lm = OpenAIModel() +lm = LM() rm = E5Model() lotus.settings.configure(lm=lm, rm=rm) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index b89f74f2..ee96e876 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM() lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/filter_cascade.py b/examples/op_examples/filter_cascade.py index 5af900b2..583fd78b 100644 --- a/examples/op_examples/filter_cascade.py +++ b/examples/op_examples/filter_cascade.py @@ -1,10 
+1,10 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -gpt_35_turbo = OpenAIModel("gpt-3.5-turbo") -gpt_4o = OpenAIModel("gpt-4o") +gpt_35_turbo = LM("gpt-3.5-turbo") +gpt_4o = LM("gpt-4o") lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_35_turbo) data = { diff --git a/examples/op_examples/join.py b/examples/op_examples/join.py index 2c850497..3b8fb30f 100644 --- a/examples/op_examples/join.py +++ b/examples/op_examples/join.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM() lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/map.py b/examples/op_examples/map.py index 6323899d..a3ea765b 100644 --- a/examples/op_examples/map.py +++ b/examples/op_examples/map.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM() lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/map_fewshot.py b/examples/op_examples/map_fewshot.py index fea45dc8..b3bf07fb 100644 --- a/examples/op_examples/map_fewshot.py +++ b/examples/op_examples/map_fewshot.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM() lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/partition.py b/examples/op_examples/partition.py index ca42d171..91fa185b 100644 --- a/examples/op_examples/partition.py +++ b/examples/op_examples/partition.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import E5Model, OpenAIModel +from lotus.models import LM, E5Model -lm = OpenAIModel(max_tokens=2048) +lm = LM(max_tokens=2048) rm = E5Model() lotus.settings.configure(lm=lm, rm=rm) diff --git a/examples/op_examples/search.py b/examples/op_examples/search.py index b7ebf67d..21c7fb5e 100644 --- a/examples/op_examples/search.py +++ b/examples/op_examples/search.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import CrossEncoderModel, E5Model, OpenAIModel +from lotus.models import LM, CrossEncoderModel, E5Model -lm = OpenAIModel() +lm = LM() rm = E5Model() reranker = CrossEncoderModel() diff --git a/examples/op_examples/sim_join.py b/examples/op_examples/sim_join.py index 7d3981ed..beaea582 100644 --- a/examples/op_examples/sim_join.py +++ b/examples/op_examples/sim_join.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import E5Model, OpenAIModel +from lotus.models import LM, E5Model -lm = OpenAIModel() +lm = LM() rm = E5Model() lotus.settings.configure(lm=lm, rm=rm) diff --git a/examples/op_examples/top_k.py b/examples/op_examples/top_k.py index 2930e305..8ffaf7b3 100644 --- a/examples/op_examples/top_k.py +++ b/examples/op_examples/top_k.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM() lotus.settings.configure(lm=lm) data = { diff --git a/examples/provider_examples/oai.py b/examples/provider_examples/oai.py index b89f74f2..ee96e876 100644 --- a/examples/provider_examples/oai.py +++ b/examples/provider_examples/oai.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM() lotus.settings.configure(lm=lm) data = { diff --git a/examples/provider_examples/ollama.py b/examples/provider_examples/ollama.py 
index 727add7d..8eb967ad 100644 --- a/examples/provider_examples/ollama.py +++ b/examples/provider_examples/ollama.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel( +lm = LM( api_base="http://localhost:11434/v1", model="llama3.2", hf_name="meta-llama/Llama-3.2-3B-Instruct", diff --git a/examples/provider_examples/vllm.py b/examples/provider_examples/vllm.py index 76a46884..70975f95 100644 --- a/examples/provider_examples/vllm.py +++ b/examples/provider_examples/vllm.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel( +lm = LM( model="meta-llama/Meta-Llama-3.1-70B-Instruct", api_base="http://localhost:8000/v1", provider="vllm", diff --git a/lotus/models/__init__.py b/lotus/models/__init__.py index 194d7259..4477c6e2 100644 --- a/lotus/models/__init__.py +++ b/lotus/models/__init__.py @@ -2,12 +2,10 @@ from lotus.models.cross_encoder_model import CrossEncoderModel from lotus.models.e5_model import E5Model from lotus.models.lm import LM -from lotus.models.openai_model import OpenAIModel from lotus.models.reranker import Reranker from lotus.models.rm import RM __all__ = [ - "OpenAIModel", "E5Model", "ColBERTv2Model", "CrossEncoderModel", diff --git a/lotus/models/lm.py b/lotus/models/lm.py index d39aea22..e878bb4f 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -1,62 +1,67 @@ -from abc import ABC, abstractmethod -from typing import Any - - -class LM(ABC): - """Abstract class for language models.""" - - def _init__(self): - pass - - @abstractmethod - def count_tokens(self, prompt: str | list) -> int: - """ - Counts the number of tokens in the given prompt. - - Args: - prompt (str | list): The prompt to count tokens for. This can be a string or a list of messages. - - Returns: - int: The number of tokens in the prompt. - """ - pass - - def format_logprobs_for_cascade(self, logprobs: list) -> tuple[list[list[str]], list[list[float]]]: - """ - Formats the logprobs for the cascade. - - Args: - logprobs (list): The logprobs to format. - - Returns: - tuple[list[list[str]], list[list[float]]]: A tuple containing the tokens and their corresponding confidences. - """ - pass - - @abstractmethod - def __call__( - self, messages_batch: list | list[list], **kwargs: dict[str, Any] - ) -> list[str] | tuple[list[str], list[dict[str, Any]]]: - """Invoke the LLM. - - Args: - messages_batch (list | list[list]): Either one prompt or a list of prompts in message format. - kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters. - - Returns: - list[str] | tuple[list[str], list[dict[str, Any]]]: A list of outputs for each prompt in the batch. If logprobs is specified in the keyword arguments, - then a list of logprobs is also returned. 
- """ - pass - - @property - @abstractmethod - def max_ctx_len(self) -> int: - """The maximum context length of the LLM.""" - pass - - @property - @abstractmethod - def max_tokens(self) -> int: - """The maximum number of tokens that can be generated by the LLM.""" - pass +import numpy as np +from litellm import batch_completion, token_counter +from litellm.types.utils import ChatCompletionTokenLogprob, ModelResponse + +from lotus.types import LMOutput, LogprobsForCascade, LogprobsForFilterCascade + + +class LM: + def __init__(self, model="gpt-4o-mini", temperature=0.0, max_ctx_len=128000, max_tokens=512, **kwargs): + self.model = model + self.max_ctx_len = max_ctx_len + self.max_tokens = max_tokens + self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) + self.history = [] + + def __call__(self, messages=None, **kwargs) -> LMOutput: + kwargs = {**self.kwargs, **kwargs} + if kwargs.get("logprobs", False): + kwargs["top_logprobs"] = kwargs.get("top_logprobs", 10) + + responses: list[ModelResponse] = batch_completion(model=self.model, messages=messages, **kwargs) + outputs = [self._get_top_choice(resp) for resp in responses] + logprobs = [self._get_top_choice_logprobs(resp) for resp in responses] if kwargs.get("logprobs") else None + + return LMOutput(outputs=outputs, logprobs=logprobs) + + def _get_top_choice(self, response: ModelResponse) -> str: + return response.choices[0].message.content + + def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompletionTokenLogprob]: + logprobs = response.choices[0].logprobs["content"] + return [ChatCompletionTokenLogprob(**logprob) for logprob in logprobs] + + def format_logprobs_for_cascade(self, logprobs: list[list[ChatCompletionTokenLogprob]]) -> LogprobsForCascade: + all_tokens = [] + all_confidences = [] + for resp in range(len(logprobs)): + tokens = [logprob.token for logprob in logprobs[resp]] + confidences = [np.exp(logprob.logprob) for logprob in logprobs[resp]] + all_tokens.append(tokens) + all_confidences.append(confidences) + return LogprobsForCascade(tokens=all_tokens, confidences=all_confidences) + + def format_logprobs_for_filter_cascade( + self, logprobs: list[list[ChatCompletionTokenLogprob]] + ) -> LogprobsForFilterCascade: + all_tokens = [] + all_confidences = [] + all_true_probs = [] + + for resp in range(len(logprobs)): + all_tokens.append([logprob.token for logprob in logprobs[resp]]) + all_confidences.append([np.exp(logprob.logprob) for logprob in logprobs[resp]]) + top_logprobs = {x.token: np.exp(x.logprob) for x in logprobs[resp]} + true_prob, false_prob = 0, 0 + if top_logprobs and "True" in top_logprobs and "False" in top_logprobs: + true_prob = np.exp(top_logprobs["True"]) + false_prob = np.exp(top_logprobs["False"]) + all_true_probs.append(true_prob / (true_prob + false_prob)) + else: + all_true_probs.append(1 if "True" in top_logprobs else 0) + return LogprobsForFilterCascade(tokens=all_tokens, confidences=all_confidences, true_probs=all_true_probs) + + def count_tokens(self, messages: list[dict[str, str]] | str) -> int: + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + return token_counter(model=self.model, messages=messages) diff --git a/lotus/models/openai_model.py b/lotus/models/openai_model.py deleted file mode 100644 index 57fb20eb..00000000 --- a/lotus/models/openai_model.py +++ /dev/null @@ -1,325 +0,0 @@ -import os -import threading -from typing import Any - -import backoff -import numpy as np -import openai -import tiktoken -from 
openai import OpenAI -from transformers import AutoTokenizer - -import lotus -from lotus.models.lm import LM - -ERRORS = (openai.RateLimitError, openai.APIError) - - -def backoff_hdlr(details): - """Handler from https://pypi.org/project/backoff/""" - print( - "Backing off {wait:0.1f} seconds after {tries} tries " - "calling function {target} with kwargs " - "{kwargs}".format(**details), - ) - - -class OpenAIModel(LM): - """Wrapper around OpenAI, Databricks, and vLLM OpenAI server - - Args: - model (str): The name of the model to use. - api_key (str | None): An API key (e.g. from OpenAI or Databricks). - api_base (str | None): The endpoint of the server. - provider (str): Either openai, dbrx, or vllm. - max_batch_size (int): The maximum batch size for the model. - max_ctx_len (int): The maximum context length for the model. - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters. - """ - - def __init__( - self, - model: str = "gpt-4o-mini", - hf_name: str | None = None, - api_key: str | None = None, - api_base: str | None = None, - provider: str = "openai", - max_batch_size: int = 64, - max_ctx_len: int = 4096, - **kwargs: dict[str, Any], - ): - super().__init__() - self.provider = provider - self.use_chat = provider in ["openai", "dbrx", "ollama"] - self.max_batch_size = max_batch_size - self.hf_name = hf_name if hf_name is not None else model - self.__dict__["max_ctx_len"] = max_ctx_len - - self.kwargs = { - "model": model, - "temperature": 0.0, - "max_tokens": 512, - "top_p": 1, - "n": 1, - **kwargs, - } - - api_key = api_key or os.environ.get("OPENAI_API_KEY", "None") - self.client = OpenAI(api_key=api_key if api_key else "None", base_url=api_base) - - # TODO: Refactor this - if self.provider == "openai": - self.tokenizer = tiktoken.encoding_for_model(model) - else: - self.tokenizer = AutoTokenizer.from_pretrained(self.hf_name) - - def handle_chat_request( - self, messages: list, **kwargs: dict[str, Any] - ) -> list | tuple[list[list[str]], list[list[float]]]: - """Handle single chat request to OpenAI server. - - Args: - messages_batch (list): A prompt in message format. - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify things such as the prompt, temperature, - model name, max tokens, etc. - - Returns: - list | tuple[list[list[str]], list[list[float]]]: A list of outputs for each prompt in the batch (just one in this case). If logprobs is specified in the keyword arguments, - then a list of logprobs is also returned. - """ - if kwargs.get("logprobs", False): - kwargs["top_logprobs"] = 10 - - kwargs = {**self.kwargs, **kwargs} - kwargs["messages"] = messages - response = self.chat_request(**kwargs) - - choices = response["choices"] - completions = [c["message"]["content"] for c in choices] - - if kwargs.get("logprobs", False): - logprobs = [c["logprobs"] for c in choices] - return completions, logprobs - - return completions - - def handle_completion_request( - self, messages: list, **kwargs: dict[str, Any] - ) -> list | tuple[list[list[str]], list[list[float]]]: - """Handle a potentially batched completions request to OpenAI server. - - Args: - messages_batch (list): A list of prompts in message format. - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify things such as the prompt, temperature, - model name, max tokens, etc. - - Returns: - list | tuple[list[list[str]], list[list[float]]]: A list of outputs for each prompt in the batch. 
If logprobs is specified in the keyword arguments, - then a list of logprobs is also returned. - """ - if not isinstance(messages[0], list): - prompt = [self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)] - else: - prompt = [ - self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) - for message in messages - ] - - kwargs = {**self.kwargs, **kwargs} - kwargs["prompt"] = prompt - if kwargs.get("logprobs", False): - kwargs["logprobs"] = 10 - response = self.completion_request(**kwargs) - - choices = response["choices"] - completions = [c["text"] for c in choices] - - if kwargs.get("logprobs", False): - logprobs = [c["logprobs"] for c in choices] - return completions, logprobs - - return completions - - @backoff.on_exception( - backoff.expo, - ERRORS, - max_time=1000, - on_backoff=backoff_hdlr, - ) - def request(self, messages: list, **kwargs: dict[str, Any]) -> list | tuple[list[list[str]], list[list[float]]]: - """Handle single request to OpenAI server. Decides whether chat or completion endpoint is necessary. - - Args: - messages_batch (list): A prompt in message format. - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify things such as the prompt, temperature, - model name, max tokens, etc. - - Returns: - A list of text outputs for each prompt in the batch (just one in this case). - If logprobs is specified in the keyword arguments, hen a list of logprobs is also returned (also of size one). - """ - if self.use_chat: - return self.handle_chat_request(messages, **kwargs) - else: - return self.handle_completion_request(messages, **kwargs) - - def batched_chat_request( - self, messages_batch: list, **kwargs: dict[str, Any] - ) -> list | tuple[list[list[str]], list[list[float]]]: - """Handle batched chat request to OpenAI server. - - Args: - messages_batch (list): Either one prompt or a list of prompts in message format. - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters. - - Returns: - list | tuple[list[list[str]], list[list[float]]]: A list of outputs for each prompt in the batch. If logprobs is specified in the keyword arguments, - then a list of logprobs is also returned. - """ - - batch_size = len(messages_batch) - text_ret = [None] * batch_size - logprobs_ret = [None] * batch_size - threads = [] - - def thread_function(idx, messages, kwargs): - text = self(messages, **kwargs) - if kwargs.get("logprobs", False): - text, logprobs = text - logprobs_ret[idx] = logprobs[0] - text_ret[idx] = text[0] - - for idx, messages in enumerate(messages_batch): - thread = threading.Thread(target=thread_function, args=(idx, messages, kwargs)) - threads.append(thread) - thread.start() - - for thread in threads: - thread.join() - - if kwargs.get("logprobs", False): - return text_ret, logprobs_ret - - return text_ret - - def __call__( - self, messages_batch: list | list[list], **kwargs: dict[str, Any] - ) -> list[str] | tuple[list[str], list[dict[str, Any]]]: - lotus.logger.debug(f"OpenAIModel.__call__ messages_batch: {messages_batch}") - lotus.logger.debug(f"OpenAIModel.__call__ kwargs: {kwargs}") - # Bakes max batch size into model call. # TODO: Figure out less hacky way to do this. 
- if isinstance(messages_batch[0], list) and len(messages_batch) > self.max_batch_size: - text_ret = [] - logprobs_ret = [] - for i in range(0, len(messages_batch), self.max_batch_size): - res = self(messages_batch[i : i + self.max_batch_size], **kwargs) - if kwargs.get("logprobs", False): - text, logprobs = res - logprobs_ret.extend(logprobs) - else: - text = res - text_ret.extend(text) - - if kwargs.get("logprobs", False): - return text_ret, logprobs_ret - return text_ret - - if self.use_chat and isinstance(messages_batch[0], list): - return self.batched_chat_request(messages_batch, **kwargs) - - return self.request(messages_batch, **kwargs) - - def count_tokens(self, prompt: str | list) -> int: - if isinstance(prompt, str): - if self.provider != "openai": - return len(self.tokenizer(prompt)["input_ids"]) - - return len(self.tokenizer.encode(prompt)) - else: - if self.provider != "openai": - return len(self.tokenizer.apply_chat_template(prompt, tokenize=True, add_generation_prompt=True)) - - return sum(len(self.tokenizer.encode(message["content"])) for message in prompt) - - def format_logprobs_for_cascade(self, logprobs: list) -> tuple[list[list[str]], list[list[float]]]: - all_tokens = [] - all_confidences = [] - for idx in range(len(logprobs)): - if self.provider == "vllm": - tokens = logprobs[idx]["tokens"] - confidences = np.exp(logprobs[idx]["token_logprobs"]) - elif self.provider == "openai": - content = logprobs[idx]["content"] - tokens = [content[t_idx]["token"] for t_idx in range(len(content))] - confidences = np.exp([content[t_idx]["logprob"] for t_idx in range(len(content))]) - all_tokens.append(tokens) - all_confidences.append(confidences) - - return all_tokens, all_confidences - - def format_logprobs_for_filter_cascade(self, logprobs: list) -> tuple[list[list[str]], list[list[float]]]: - all_tokens = [] - all_confidences = [] - all_true_probs = [] - for idx in range(len(logprobs)): - if self.provider == "vllm": - tokens = logprobs[idx]["tokens"] - confidences = np.exp(logprobs[idx]["token_logprobs"]) - top_logprobs = logprobs[idx]["top_logprobs"][0] - if 'True' in top_logprobs and 'False' in top_logprobs: - true_prob = np.exp(top_logprobs['True']) - false_prob = np.exp(top_logprobs['False']) - all_true_probs.append(true_prob / (true_prob + false_prob)) - else: - all_true_probs.append(1 if 'True' in top_logprobs else 0) - - elif self.provider == "openai": - content = logprobs[idx]["content"] - tokens = [content[t_idx]["token"] for t_idx in range(len(content))] - confidences = np.exp([content[t_idx]["logprob"] for t_idx in range(len(content))]) - top_logprobs = {x["token"]:x["logprob"] for x in content[0]["top_logprobs"]} - - true_prob, false_prob = 0, 0 - if top_logprobs and 'True' in top_logprobs and 'False' in top_logprobs: - true_prob = np.exp(top_logprobs['True']) - false_prob = np.exp(top_logprobs['False']) - all_true_probs.append(true_prob / (true_prob + false_prob)) - else: - all_true_probs.append(1 if 'True' in top_logprobs else 0) - - all_tokens.append(tokens) - all_confidences.append(confidences) - - return all_tokens, all_confidences, all_true_probs - - def chat_request(self, **kwargs: dict[str, Any]) -> dict[str, Any]: - """Send chat request to OpenAI server. - - Args: - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify things such as the prompt, temperature, - model name, max tokens, etc. - - Returns: - dict: OpenAI chat completion response. 
- """ - return self.client.chat.completions.create(**kwargs).model_dump() - - def completion_request(self, **kwargs: dict[str, Any]) -> dict[str, Any]: - """Send completion request to OpenAI server. - - Args: - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify things such as the prompt, temperature, - model name, max tokens, etc. - - Returns: - dict: OpenAI completion response. - """ - return self.client.completions.create(**kwargs).model_dump() - - @property - def max_tokens(self) -> int: - return self.kwargs["max_tokens"] - - @property - def max_ctx_len(self) -> int: - return self.__dict__["max_ctx_len"] diff --git a/lotus/sem_ops/cascade_utils.py b/lotus/sem_ops/cascade_utils.py index 3302a493..7aab77db 100644 --- a/lotus/sem_ops/cascade_utils.py +++ b/lotus/sem_ops/cascade_utils.py @@ -12,43 +12,47 @@ def importance_sampling( w = np.sqrt(proxy_scores) w = 0.5 * w / np.sum(w) + 0.5 * np.ones((len(proxy_scores))) / len(proxy_scores) indices = np.arange(len(proxy_scores)) - sample_size = (int) (sample_percentage * len(proxy_scores)) + sample_size = (int)(sample_percentage * len(proxy_scores)) sample_indices = np.random.choice(indices, sample_size, p=w) - correction_factors = (1/len(proxy_scores)) / w + correction_factors = (1 / len(proxy_scores)) / w return sample_indices, correction_factors + def calibrate_llm_logprobs(true_probs: list[float]) -> list[float]: """Transforms true probabilities to calibrate LLM proxies.""" num_quantiles = 50 quantile_values = np.percentile(true_probs, np.linspace(0, 100, num_quantiles + 1)) - true_probs = ((np.digitize(true_probs, quantile_values) - 1) / num_quantiles) + true_probs = (np.digitize(true_probs, quantile_values) - 1) / num_quantiles true_probs = np.clip(true_probs, 0, 1) return true_probs + def learn_cascade_thresholds( proxy_scores: list[float], oracle_outputs: list[float], sample_correction_factors: list[float], recall_target: float, precision_target: float, - delta: float + delta: float, ) -> tuple[tuple[float, float], int]: - """Learns cascade thresholds given targets and proxy scores, + """Learns cascade thresholds given targets and proxy scores, oracle outputs over the sample, and correction factors for the sample.""" def UB(mean, std_dev, s, delta): - return mean + (std_dev / (s ** 0.5)) * ((2 * np.log(1 / delta)) ** 0.5) + return mean + (std_dev / (s**0.5)) * ((2 * np.log(1 / delta)) ** 0.5) def LB(mean, std_dev, s, delta): - return mean - (std_dev / (s ** 0.5)) * ((2 * np.log(1 / delta)) ** 0.5) + return mean - (std_dev / (s**0.5)) * ((2 * np.log(1 / delta)) ** 0.5) def recall(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: helper_accepted = [x for x in sorted_pairs if x[0] >= pos_threshold or x[0] <= neg_threshold] sent_to_oracle = [x for x in sorted_pairs if x[0] < pos_threshold and x[0] > neg_threshold] total_correct = sum(pair[1] * pair[2] for pair in sorted_pairs) - recall = (sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1]) + sum(x[1] * x[2] for x in sent_to_oracle)) / total_correct + recall = ( + sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1]) + sum(x[1] * x[2] for x in sent_to_oracle) + ) / total_correct return recall def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: @@ -65,10 +69,12 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: sorted_pairs = sorted(paired_data, key=lambda x: x[0], reverse=True) sample_size = len(sorted_pairs) - best_combination = (1,0) # initial tau_+, 
tau_- + best_combination = (1, 0) # initial tau_+, tau_- # Find tau_negative based on recall - tau_neg_0 = max(x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= recall_target) + tau_neg_0 = max( + x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= recall_target + ) best_combination = (best_combination[0], tau_neg_0) # Do a statistical correction to get a new target recall @@ -80,9 +86,13 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: mean_z2 = np.mean(Z2) if Z2 else 0 std_z2 = np.std(Z2) if Z2 else 0 - corrected_recall_target = UB(mean_z1, std_z1, sample_size, delta/2)/(UB(mean_z1, std_z1, sample_size, delta/2) + LB(mean_z2, std_z2, sample_size, delta/2)) + corrected_recall_target = UB(mean_z1, std_z1, sample_size, delta / 2) / ( + UB(mean_z1, std_z1, sample_size, delta / 2) + LB(mean_z2, std_z2, sample_size, delta / 2) + ) corrected_recall_target = min(1, corrected_recall_target) - tau_neg_prime = max(x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= corrected_recall_target) + tau_neg_prime = max( + x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= corrected_recall_target + ) best_combination = (best_combination[0], tau_neg_prime) # Do a statistical correction to get a target satisfying precision @@ -92,7 +102,7 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: Z = [int(x[1]) for x in sorted_pairs if x[0] >= possible_threshold] mean_z = np.mean(Z) if Z else 0 std_z = np.std(Z) if Z else 0 - p_l = LB(mean_z, std_z, len(Z), delta/len(sorted_pairs)) + p_l = LB(mean_z, std_z, len(Z), delta / len(sorted_pairs)) if p_l > precision_target: candidate_thresholds.append(possible_threshold) @@ -105,6 +115,7 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: return best_combination, oracle_calls + def calibrate_sem_sim_join(true_score: list[float]) -> list[float]: true_score = np.clip(true_score, 0, 1) - return true_score \ No newline at end of file + return true_score diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index 6d77f8fe..6fd9e8b0 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -2,9 +2,9 @@ import pandas as pd -import lotus +import lotus.models from lotus.templates import task_instructions -from lotus.types import SemanticAggOutput +from lotus.types import LMOutput, SemanticAggOutput def sem_agg( @@ -108,13 +108,9 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str: lotus.logger.debug(f"Prompt added to batch: {prompt}") batch.append([{"role": "user", "content": prompt}]) new_partition_ids.append(cur_partition_id) - result = model(batch) + lm_output: LMOutput = model(batch) - # TODO: this is a weird hack for model typing - if isinstance(result, tuple): - summaries, _ = result - else: - summaries = result + summaries = lm_output.outputs partition_ids = new_partition_ids new_partition_ids = [] diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py index 6336b196..82e82a98 100644 --- a/lotus/sem_ops/sem_extract.py +++ b/lotus/sem_ops/sem_extract.py @@ -4,7 +4,7 @@ import lotus from lotus.templates import task_instructions -from lotus.types import SemanticExtractOutput, SemanticExtractPostprocessOutput +from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput from .postprocessors import extract_postprocess @@ -36,15 +36,11 @@ def sem_extract( 
inputs.append(prompt) # call model - raw_outputs = model(inputs) - if isinstance(raw_outputs, tuple): - raw_outputs, _ = raw_outputs - else: - assert isinstance(raw_outputs, list) + lm_output: LMOutput = model(inputs) # post process results - postprocess_output = postprocessor(raw_outputs) - lotus.logger.debug(f"raw_outputs: {raw_outputs}") + postprocess_output = postprocessor(lm_output.outputs) + lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"quotes: {postprocess_output.quotes}") diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index aafd6aa6..76e8b9b4 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -4,9 +4,9 @@ import lotus from lotus.templates import task_instructions -from lotus.types import SemanticFilterOutput +from lotus.types import LMOutput, LogprobsForFilterCascade, SemanticFilterOutput -from .cascade_utils import calibrate_llm_logprobs, importance_sampling, learn_cascade_thresholds +from .cascade_utils import importance_sampling, learn_cascade_thresholds from .postprocessors import filter_postprocess @@ -45,20 +45,17 @@ def sem_filter( lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) kwargs: dict[str, Any] = {"logprobs": logprobs} - res = model(inputs, **kwargs) - if logprobs: - assert isinstance(res, tuple) - raw_outputs, raw_logprobs = res - else: - assert isinstance(res, list) - raw_outputs = res - - postprocess_output = filter_postprocess(raw_outputs, default=default, cot_reasoning=strategy in ["cot", "zs-cot"]) + lm_output: LMOutput = model(inputs, **kwargs) + + postprocess_output = filter_postprocess( + lm_output.outputs, default=default, cot_reasoning=strategy in ["cot", "zs-cot"] + ) lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") - return SemanticFilterOutput(**postprocess_output.model_dump(), logprobs=raw_logprobs if logprobs else None) + return SemanticFilterOutput(**postprocess_output.model_dump(), logprobs=lm_output.logprobs if logprobs else None) + def learn_filter_cascade_thresholds( sample_df_txt: str, @@ -75,8 +72,8 @@ def learn_filter_cascade_thresholds( cot_reasoning: list | None = None, strategy: str | None = None, ) -> tuple[float, float]: - """Automatically learns the cascade thresholds for a cascade - filter given a sample of data and doing a search across threshold + """Automatically learns the cascade thresholds for a cascade + filter given a sample of data and doing a search across threshold to see what threshold gives the best accuracy.""" try: @@ -97,7 +94,7 @@ def learn_filter_cascade_thresholds( sample_correction_factors=sample_correction_factors, recall_target=recall_target, precision_target=precision_target, - delta=delta + delta=delta, ) lotus.logger.info(f"Learned cascade thresholds: {best_combination}") @@ -107,6 +104,7 @@ def learn_filter_cascade_thresholds( lotus.logger.error(f"Error while learning filter cascade thresholds: {e}") return None + @pd.api.extensions.register_dataframe_accessor("sem_filter") class SemFilterDataframe: """DataFrame accessor for semantic filter.""" @@ -198,14 +196,16 @@ def __call__( if helper_strategy == "cot": helper_cot_reasoning = examples["Reasoning"].tolist() - + if learn_cascade_threshold_sample_percentage and lotus.settings.helper_lm: if helper_strategy == "cot": lotus.logger.error("CoT not 
supported for helper models in cascades.") raise Exception if recall_target is None or precision_target is None or failure_probability is None: - lotus.logger.error("Recall target, precision target, and confidence need to be specified for learned thresholds.") + lotus.logger.error( + "Recall target, precision target, and confidence need to be specified for learned thresholds." + ) raise Exception # Run small LM and get logits @@ -221,11 +221,14 @@ def __call__( strategy=helper_strategy, ) helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs - _, _, helper_true_probs = lotus.settings.helper_lm.format_logprobs_for_filter_cascade(helper_logprobs) - - helper_true_probs = calibrate_llm_logprobs(helper_true_probs) + formatted_helper_logprobs: LogprobsForFilterCascade = ( + lotus.settings.helper_lm.format_logprobs_for_filter_cascade(helper_logprobs) + ) + helper_true_probs = formatted_helper_logprobs.true_probs - sample_indices, correction_factors = importance_sampling(helper_true_probs, learn_cascade_threshold_sample_percentage) + sample_indices, correction_factors = importance_sampling( + helper_true_probs, learn_cascade_threshold_sample_percentage + ) sample_df = self._obj.loc[sample_indices] sample_df_txt = task_instructions.df2text(sample_df, col_li) sample_helper_true_probs = [helper_true_probs[i] for i in sample_indices] @@ -238,7 +241,7 @@ def __call__( default=default, recall_target=recall_target, precision_target=precision_target, - delta=failure_probability/2, + delta=failure_probability / 2, helper_true_probs=sample_helper_true_probs, sample_correction_factors=sample_correction_factors, examples_df_txt=examples_df_txt, @@ -261,7 +264,13 @@ def __call__( true_prob = helper_true_probs[idx_i] if true_prob >= pos_cascade_threshold or true_prob <= neg_cascade_threshold: high_conf_idxs.add(idx_i) - helper_outputs[idx_i] = True if true_prob >= pos_cascade_threshold else False if true_prob <= neg_cascade_threshold else helper_outputs[idx_i] + helper_outputs[idx_i] = ( + True + if true_prob >= pos_cascade_threshold + else False + if true_prob <= neg_cascade_threshold + else helper_outputs[idx_i] + ) lotus.logger.info(f"Num routed to smaller model: {len(high_conf_idxs)}") stats["num_routed_to_helper_model"] = len(high_conf_idxs) diff --git a/lotus/sem_ops/sem_join.py b/lotus/sem_ops/sem_join.py index 05f6bacc..0d01cc55 100644 --- a/lotus/sem_ops/sem_join.py +++ b/lotus/sem_ops/sem_join.py @@ -1,6 +1,5 @@ from typing import Any -import numpy as np import pandas as pd import lotus @@ -136,24 +135,17 @@ def sem_join_cascade( assert helper_logprobs is not None high_conf_idxs = set() - for idx_i in range(len(helper_outputs)): - tokens: list[str] - confidences: np.ndarray[Any, np.dtype[np.float64]] - # Get the logprobs - if lotus.settings.helper_lm.provider == "vllm": - tokens = helper_logprobs[idx_i]["tokens"] - confidences = np.exp(helper_logprobs[idx_i]["token_logprobs"]) - elif lotus.settings.helper_lm.provider == "openai": - content: list[dict[str, Any]] = helper_logprobs[idx_i]["content"] - tokens = [content[t_idx]["token"] for t_idx in range(len(content))] - confidences = np.exp([content[t_idx]["logprob"] for t_idx in range(len(content))]) + # Get the logprobs in a standardized format + formatted_logprobs = lotus.settings.helper_lm.format_logprobs_for_cascade(helper_logprobs) + tokens, confidences = formatted_logprobs.tokens, formatted_logprobs.confidences + for doc_idx in range(len(helper_outputs)): # Find where true/false is said and look at confidence - for idx_j 
in range(len(tokens) - 1, -1, -1): - if tokens[idx_j].strip(" \n").lower() in ["true", "false"]: - conf = confidences[idx_j] - if conf >= cascade_threshold: - high_conf_idxs.add(idx_i) + for token_idx in range(len(tokens[doc_idx]) - 1, -1, -1): + if tokens[doc_idx][token_idx].strip(" \n").lower() in ["true", "false"]: + confidence = confidences[doc_idx][token_idx] + if confidence >= cascade_threshold: + high_conf_idxs.add(doc_idx) # Send low confidence samples to large LM low_conf_idxs = sorted([i for i in range(len(helper_outputs)) if i not in high_conf_idxs]) diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py index 7ee84e62..9074c094 100644 --- a/lotus/sem_ops/sem_map.py +++ b/lotus/sem_ops/sem_map.py @@ -4,7 +4,7 @@ import lotus from lotus.templates import task_instructions -from lotus.types import SemanticMapOutput, SemanticMapPostprocessOutput +from lotus.types import LMOutput, SemanticMapOutput, SemanticMapPostprocessOutput from .postprocessors import map_postprocess @@ -45,14 +45,11 @@ def sem_map( inputs.append(prompt) # call model - raw_outputs = model(inputs) - assert isinstance(raw_outputs, list) and all( - isinstance(item, str) for item in raw_outputs - ), "Model must return a list of strings" + lm_output: LMOutput = model(inputs) # post process results - postprocess_output = postprocessor(raw_outputs, strategy in ["cot", "zs-cot"]) - lotus.logger.debug(f"raw_outputs: {raw_outputs}") + postprocess_output = postprocessor(lm_output.outputs, strategy in ["cot", "zs-cot"]) + lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index a2b98a44..4db77a63 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -7,7 +7,7 @@ import lotus from lotus.templates import task_instructions -from lotus.types import SemanticTopKOutput +from lotus.types import LMOutput, SemanticTopKOutput def get_match_prompt_binary( @@ -65,8 +65,8 @@ def compare_batch_binary( match_prompts.append(get_match_prompt_binary(doc1, doc2, user_instruction, strategy=strategy)) tokens += lotus.settings.lm.count_tokens(match_prompts[-1]) - results = lotus.settings.lm(match_prompts) - results = list(map(parse_ans_binary, results)) + results: LMOutput = lotus.settings.lm(match_prompts) + results = list(map(parse_ans_binary, results.outputs)) return results, tokens @@ -109,8 +109,8 @@ def compare_batch_binary_cascade( large_match_prompts.append(match_prompts[i]) large_tokens += lotus.settings.lm.count_tokens(large_match_prompts[-1]) - results = lotus.settings.lm(large_match_prompts) - for idx, res in enumerate(results): + results: LMOutput = lotus.settings.lm(large_match_prompts) + for idx, res in enumerate(results.outputs): new_idx = low_conf_idxs[idx] parsed_res = parse_ans_binary(res) parsed_results[new_idx] = parsed_res @@ -268,8 +268,8 @@ def __lt__(self, other: "HeapDoc") -> bool: prompt = get_match_prompt_binary(self.doc, other.doc, self.user_instruction, strategy=self.strategy) HeapDoc.num_calls += 1 HeapDoc.total_tokens += lotus.settings.lm.count_tokens(prompt) - result = lotus.settings.lm(prompt) - return parse_ans_binary(result[0]) + result: LMOutput = lotus.settings.lm([prompt]) + return parse_ans_binary(result.outputs[0]) def llm_heapsort( diff --git a/lotus/types.py b/lotus/types.py index a33339d5..7852754f 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -1,5 +1,6 @@ from typing 
import Any +from litellm.types.utils import ChatCompletionTokenLogprob from pydantic import BaseModel @@ -9,10 +10,11 @@ class StatsMixin(BaseModel): # TODO: Figure out better logprobs type class LogprobsMixin(BaseModel): - logprobs: list[dict[str, Any]] | None = None + # for each response, we have a list of tokens, and for each token, we have a ChatCompletionTokenLogprob + logprobs: list[list[ChatCompletionTokenLogprob]] | None = None -class SemanticMapPostprocessOutput(StatsMixin, LogprobsMixin): +class SemanticMapPostprocessOutput(BaseModel): raw_outputs: list[str] outputs: list[str] explanations: list[str | None] @@ -55,3 +57,16 @@ class SemanticJoinOutput(StatsMixin): class SemanticTopKOutput(StatsMixin): indexes: list[int] + + +class LMOutput(LogprobsMixin): + outputs: list[str] + + +class LogprobsForCascade(BaseModel): + tokens: list[list[str]] + confidences: list[list[float]] + + +class LogprobsForFilterCascade(LogprobsForCascade): + true_probs: list[float]
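
A minimal usage sketch of the new LiteLLM-backed LM wrapper introduced in this patch, assuming the patched lotus package is installed and LiteLLM can reach a provider (for example, OPENAI_API_KEY is set); the prompts and printed fields below are illustrative only and are not part of the patch itself.

    import lotus
    from lotus.models import LM

    # Configure lotus with the new wrapper (same pattern as the updated op_examples).
    lm = LM(model="gpt-4o-mini", temperature=0.0, max_tokens=512)
    lotus.settings.configure(lm=lm)

    # __call__ takes a batch of prompts in message format and returns an LMOutput.
    batch = [
        [{"role": "user", "content": "Answer with one word: what color is the sky?"}],
        [{"role": "user", "content": "Answer True or False: 2 + 2 = 4"}],
    ]
    result = lm(batch, logprobs=True)

    print(result.outputs)  # one completion string per prompt in the batch
    print(lm.count_tokens("How many tokens is this?"))  # token count via litellm

    # Per-response token logprobs can be reshaped for the cascade operators.
    cascade = lm.format_logprobs_for_cascade(result.logprobs)
    print(cascade.tokens[0][:5], cascade.confidences[0][:5])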