From 070d01f49eac9803b95da711a645993f3bfdb599 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Mon, 28 Oct 2024 23:44:59 -0700 Subject: [PATCH 01/28] Port to LiteLLM --- .github/tests/lm_tests.py | 7 +- examples/op_examples/agg.py | 4 +- examples/op_examples/cluster.py | 4 +- examples/op_examples/filter.py | 4 +- examples/op_examples/filter_cascade.py | 6 +- examples/op_examples/join.py | 4 +- examples/op_examples/map.py | 4 +- examples/op_examples/map_fewshot.py | 4 +- examples/op_examples/partition.py | 4 +- examples/op_examples/search.py | 4 +- examples/op_examples/sim_join.py | 4 +- examples/op_examples/top_k.py | 4 +- examples/provider_examples/oai.py | 4 +- examples/provider_examples/ollama.py | 4 +- examples/provider_examples/vllm.py | 4 +- lotus/models/__init__.py | 2 - lotus/models/lm.py | 129 +++++----- lotus/models/openai_model.py | 325 ------------------------- lotus/sem_ops/cascade_utils.py | 39 +-- lotus/sem_ops/sem_agg.py | 12 +- lotus/sem_ops/sem_extract.py | 12 +- lotus/sem_ops/sem_filter.py | 55 +++-- lotus/sem_ops/sem_join.py | 26 +- lotus/sem_ops/sem_map.py | 11 +- lotus/sem_ops/sem_topk.py | 14 +- lotus/types.py | 19 +- 26 files changed, 201 insertions(+), 508 deletions(-) delete mode 100644 lotus/models/openai_model.py diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index af23d1c7..35a0a5c9 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -2,7 +2,7 @@ import pytest import lotus -from lotus.models import OpenAIModel +from lotus.models import LM # Set logger level to DEBUG lotus.logger.setLevel("DEBUG") @@ -11,8 +11,8 @@ @pytest.fixture def setup_models(): # Setup GPT models - gpt_4o_mini = OpenAIModel(model="gpt-4o-mini") - gpt_4o = OpenAIModel(model="gpt-4o") + gpt_4o_mini = LM(model="gpt-4o-mini") + gpt_4o = LM(model="gpt-4o") return gpt_4o_mini, gpt_4o @@ -57,7 +57,6 @@ def test_filter_cascade(setup_models): "Everything is going as planned, couldn't be happier.", "Feeling super motivated and ready to take on challenges!", "I appreciate all the small things that bring me joy.", - # Negative examples "I am very sad.", "Today has been really tough; I feel exhausted.", diff --git a/examples/op_examples/agg.py b/examples/op_examples/agg.py index add1711e..6f6e14b0 100644 --- a/examples/op_examples/agg.py +++ b/examples/op_examples/agg.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import E5Model, OpenAIModel +from lotus.models import LM, E5Model -lm = OpenAIModel() +lm = LM() rm = E5Model() lotus.settings.configure(lm=lm, rm=rm) diff --git a/examples/op_examples/cluster.py b/examples/op_examples/cluster.py index 7bcc307b..2d7af6f1 100644 --- a/examples/op_examples/cluster.py +++ b/examples/op_examples/cluster.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import E5Model, OpenAIModel +from lotus.models import LM, E5Model -lm = OpenAIModel() +lm = LM() rm = E5Model() lotus.settings.configure(lm=lm, rm=rm) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index b89f74f2..ee96e876 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM() lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/filter_cascade.py b/examples/op_examples/filter_cascade.py index 5af900b2..583fd78b 100644 --- a/examples/op_examples/filter_cascade.py +++ b/examples/op_examples/filter_cascade.py @@ -1,10 +1,10 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -gpt_35_turbo = OpenAIModel("gpt-3.5-turbo") -gpt_4o = OpenAIModel("gpt-4o") +gpt_35_turbo = LM("gpt-3.5-turbo") +gpt_4o = LM("gpt-4o") lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_35_turbo) data = { diff --git a/examples/op_examples/join.py b/examples/op_examples/join.py index 2c850497..3b8fb30f 100644 --- a/examples/op_examples/join.py +++ b/examples/op_examples/join.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM() lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/map.py b/examples/op_examples/map.py index 6323899d..a3ea765b 100644 --- a/examples/op_examples/map.py +++ b/examples/op_examples/map.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM() lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/map_fewshot.py b/examples/op_examples/map_fewshot.py index fea45dc8..b3bf07fb 100644 --- a/examples/op_examples/map_fewshot.py +++ b/examples/op_examples/map_fewshot.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM() lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/partition.py b/examples/op_examples/partition.py index ca42d171..91fa185b 100644 --- a/examples/op_examples/partition.py +++ b/examples/op_examples/partition.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import E5Model, OpenAIModel +from lotus.models import LM, E5Model -lm = OpenAIModel(max_tokens=2048) +lm = LM(max_tokens=2048) rm = E5Model() lotus.settings.configure(lm=lm, rm=rm) diff --git a/examples/op_examples/search.py b/examples/op_examples/search.py index b7ebf67d..21c7fb5e 100644 --- a/examples/op_examples/search.py +++ b/examples/op_examples/search.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import CrossEncoderModel, E5Model, OpenAIModel +from lotus.models import LM, CrossEncoderModel, E5Model -lm = OpenAIModel() +lm = LM() rm = E5Model() reranker = CrossEncoderModel() diff --git a/examples/op_examples/sim_join.py b/examples/op_examples/sim_join.py index 7d3981ed..beaea582 100644 --- a/examples/op_examples/sim_join.py +++ b/examples/op_examples/sim_join.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import E5Model, OpenAIModel +from lotus.models import LM, E5Model -lm = OpenAIModel() +lm = LM() rm = E5Model() lotus.settings.configure(lm=lm, rm=rm) diff --git a/examples/op_examples/top_k.py b/examples/op_examples/top_k.py index 2930e305..8ffaf7b3 100644 --- a/examples/op_examples/top_k.py +++ b/examples/op_examples/top_k.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM() lotus.settings.configure(lm=lm) data = { diff --git a/examples/provider_examples/oai.py b/examples/provider_examples/oai.py index b89f74f2..ee96e876 100644 --- a/examples/provider_examples/oai.py +++ b/examples/provider_examples/oai.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel() +lm = LM() lotus.settings.configure(lm=lm) data = { diff --git a/examples/provider_examples/ollama.py b/examples/provider_examples/ollama.py index 727add7d..8eb967ad 100644 --- a/examples/provider_examples/ollama.py +++ b/examples/provider_examples/ollama.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel( +lm = LM( api_base="http://localhost:11434/v1", model="llama3.2", hf_name="meta-llama/Llama-3.2-3B-Instruct", diff --git a/examples/provider_examples/vllm.py b/examples/provider_examples/vllm.py index 76a46884..70975f95 100644 --- a/examples/provider_examples/vllm.py +++ b/examples/provider_examples/vllm.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM -lm = OpenAIModel( +lm = LM( model="meta-llama/Meta-Llama-3.1-70B-Instruct", api_base="http://localhost:8000/v1", provider="vllm", diff --git a/lotus/models/__init__.py b/lotus/models/__init__.py index 194d7259..4477c6e2 100644 --- a/lotus/models/__init__.py +++ b/lotus/models/__init__.py @@ -2,12 +2,10 @@ from lotus.models.cross_encoder_model import CrossEncoderModel from lotus.models.e5_model import E5Model from lotus.models.lm import LM -from lotus.models.openai_model import OpenAIModel from lotus.models.reranker import Reranker from lotus.models.rm import RM __all__ = [ - "OpenAIModel", "E5Model", "ColBERTv2Model", "CrossEncoderModel", diff --git a/lotus/models/lm.py b/lotus/models/lm.py index d39aea22..e878bb4f 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -1,62 +1,67 @@ -from abc import ABC, abstractmethod -from typing import Any - - -class LM(ABC): - """Abstract class for language models.""" - - def _init__(self): - pass - - @abstractmethod - def count_tokens(self, prompt: str | list) -> int: - """ - Counts the number of tokens in the given prompt. - - Args: - prompt (str | list): The prompt to count tokens for. This can be a string or a list of messages. - - Returns: - int: The number of tokens in the prompt. - """ - pass - - def format_logprobs_for_cascade(self, logprobs: list) -> tuple[list[list[str]], list[list[float]]]: - """ - Formats the logprobs for the cascade. - - Args: - logprobs (list): The logprobs to format. - - Returns: - tuple[list[list[str]], list[list[float]]]: A tuple containing the tokens and their corresponding confidences. - """ - pass - - @abstractmethod - def __call__( - self, messages_batch: list | list[list], **kwargs: dict[str, Any] - ) -> list[str] | tuple[list[str], list[dict[str, Any]]]: - """Invoke the LLM. - - Args: - messages_batch (list | list[list]): Either one prompt or a list of prompts in message format. - kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters. - - Returns: - list[str] | tuple[list[str], list[dict[str, Any]]]: A list of outputs for each prompt in the batch. If logprobs is specified in the keyword arguments, - then a list of logprobs is also returned. - """ - pass - - @property - @abstractmethod - def max_ctx_len(self) -> int: - """The maximum context length of the LLM.""" - pass - - @property - @abstractmethod - def max_tokens(self) -> int: - """The maximum number of tokens that can be generated by the LLM.""" - pass +import numpy as np +from litellm import batch_completion, token_counter +from litellm.types.utils import ChatCompletionTokenLogprob, ModelResponse + +from lotus.types import LMOutput, LogprobsForCascade, LogprobsForFilterCascade + + +class LM: + def __init__(self, model="gpt-4o-mini", temperature=0.0, max_ctx_len=128000, max_tokens=512, **kwargs): + self.model = model + self.max_ctx_len = max_ctx_len + self.max_tokens = max_tokens + self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) + self.history = [] + + def __call__(self, messages=None, **kwargs) -> LMOutput: + kwargs = {**self.kwargs, **kwargs} + if kwargs.get("logprobs", False): + kwargs["top_logprobs"] = kwargs.get("top_logprobs", 10) + + responses: list[ModelResponse] = batch_completion(model=self.model, messages=messages, **kwargs) + outputs = [self._get_top_choice(resp) for resp in responses] + logprobs = [self._get_top_choice_logprobs(resp) for resp in responses] if kwargs.get("logprobs") else None + + return LMOutput(outputs=outputs, logprobs=logprobs) + + def _get_top_choice(self, response: ModelResponse) -> str: + return response.choices[0].message.content + + def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompletionTokenLogprob]: + logprobs = response.choices[0].logprobs["content"] + return [ChatCompletionTokenLogprob(**logprob) for logprob in logprobs] + + def format_logprobs_for_cascade(self, logprobs: list[list[ChatCompletionTokenLogprob]]) -> LogprobsForCascade: + all_tokens = [] + all_confidences = [] + for resp in range(len(logprobs)): + tokens = [logprob.token for logprob in logprobs[resp]] + confidences = [np.exp(logprob.logprob) for logprob in logprobs[resp]] + all_tokens.append(tokens) + all_confidences.append(confidences) + return LogprobsForCascade(tokens=all_tokens, confidences=all_confidences) + + def format_logprobs_for_filter_cascade( + self, logprobs: list[list[ChatCompletionTokenLogprob]] + ) -> LogprobsForFilterCascade: + all_tokens = [] + all_confidences = [] + all_true_probs = [] + + for resp in range(len(logprobs)): + all_tokens.append([logprob.token for logprob in logprobs[resp]]) + all_confidences.append([np.exp(logprob.logprob) for logprob in logprobs[resp]]) + top_logprobs = {x.token: np.exp(x.logprob) for x in logprobs[resp]} + true_prob, false_prob = 0, 0 + if top_logprobs and "True" in top_logprobs and "False" in top_logprobs: + true_prob = np.exp(top_logprobs["True"]) + false_prob = np.exp(top_logprobs["False"]) + all_true_probs.append(true_prob / (true_prob + false_prob)) + else: + all_true_probs.append(1 if "True" in top_logprobs else 0) + return LogprobsForFilterCascade(tokens=all_tokens, confidences=all_confidences, true_probs=all_true_probs) + + def count_tokens(self, messages: list[dict[str, str]] | str) -> int: + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + return token_counter(model=self.model, messages=messages) diff --git a/lotus/models/openai_model.py b/lotus/models/openai_model.py deleted file mode 100644 index 57fb20eb..00000000 --- a/lotus/models/openai_model.py +++ /dev/null @@ -1,325 +0,0 @@ -import os -import threading -from typing import Any - -import backoff -import numpy as np -import openai -import tiktoken -from openai import OpenAI -from transformers import AutoTokenizer - -import lotus -from lotus.models.lm import LM - -ERRORS = (openai.RateLimitError, openai.APIError) - - -def backoff_hdlr(details): - """Handler from https://pypi.org/project/backoff/""" - print( - "Backing off {wait:0.1f} seconds after {tries} tries " - "calling function {target} with kwargs " - "{kwargs}".format(**details), - ) - - -class OpenAIModel(LM): - """Wrapper around OpenAI, Databricks, and vLLM OpenAI server - - Args: - model (str): The name of the model to use. - api_key (str | None): An API key (e.g. from OpenAI or Databricks). - api_base (str | None): The endpoint of the server. - provider (str): Either openai, dbrx, or vllm. - max_batch_size (int): The maximum batch size for the model. - max_ctx_len (int): The maximum context length for the model. - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters. - """ - - def __init__( - self, - model: str = "gpt-4o-mini", - hf_name: str | None = None, - api_key: str | None = None, - api_base: str | None = None, - provider: str = "openai", - max_batch_size: int = 64, - max_ctx_len: int = 4096, - **kwargs: dict[str, Any], - ): - super().__init__() - self.provider = provider - self.use_chat = provider in ["openai", "dbrx", "ollama"] - self.max_batch_size = max_batch_size - self.hf_name = hf_name if hf_name is not None else model - self.__dict__["max_ctx_len"] = max_ctx_len - - self.kwargs = { - "model": model, - "temperature": 0.0, - "max_tokens": 512, - "top_p": 1, - "n": 1, - **kwargs, - } - - api_key = api_key or os.environ.get("OPENAI_API_KEY", "None") - self.client = OpenAI(api_key=api_key if api_key else "None", base_url=api_base) - - # TODO: Refactor this - if self.provider == "openai": - self.tokenizer = tiktoken.encoding_for_model(model) - else: - self.tokenizer = AutoTokenizer.from_pretrained(self.hf_name) - - def handle_chat_request( - self, messages: list, **kwargs: dict[str, Any] - ) -> list | tuple[list[list[str]], list[list[float]]]: - """Handle single chat request to OpenAI server. - - Args: - messages_batch (list): A prompt in message format. - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify things such as the prompt, temperature, - model name, max tokens, etc. - - Returns: - list | tuple[list[list[str]], list[list[float]]]: A list of outputs for each prompt in the batch (just one in this case). If logprobs is specified in the keyword arguments, - then a list of logprobs is also returned. - """ - if kwargs.get("logprobs", False): - kwargs["top_logprobs"] = 10 - - kwargs = {**self.kwargs, **kwargs} - kwargs["messages"] = messages - response = self.chat_request(**kwargs) - - choices = response["choices"] - completions = [c["message"]["content"] for c in choices] - - if kwargs.get("logprobs", False): - logprobs = [c["logprobs"] for c in choices] - return completions, logprobs - - return completions - - def handle_completion_request( - self, messages: list, **kwargs: dict[str, Any] - ) -> list | tuple[list[list[str]], list[list[float]]]: - """Handle a potentially batched completions request to OpenAI server. - - Args: - messages_batch (list): A list of prompts in message format. - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify things such as the prompt, temperature, - model name, max tokens, etc. - - Returns: - list | tuple[list[list[str]], list[list[float]]]: A list of outputs for each prompt in the batch. If logprobs is specified in the keyword arguments, - then a list of logprobs is also returned. - """ - if not isinstance(messages[0], list): - prompt = [self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)] - else: - prompt = [ - self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) - for message in messages - ] - - kwargs = {**self.kwargs, **kwargs} - kwargs["prompt"] = prompt - if kwargs.get("logprobs", False): - kwargs["logprobs"] = 10 - response = self.completion_request(**kwargs) - - choices = response["choices"] - completions = [c["text"] for c in choices] - - if kwargs.get("logprobs", False): - logprobs = [c["logprobs"] for c in choices] - return completions, logprobs - - return completions - - @backoff.on_exception( - backoff.expo, - ERRORS, - max_time=1000, - on_backoff=backoff_hdlr, - ) - def request(self, messages: list, **kwargs: dict[str, Any]) -> list | tuple[list[list[str]], list[list[float]]]: - """Handle single request to OpenAI server. Decides whether chat or completion endpoint is necessary. - - Args: - messages_batch (list): A prompt in message format. - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify things such as the prompt, temperature, - model name, max tokens, etc. - - Returns: - A list of text outputs for each prompt in the batch (just one in this case). - If logprobs is specified in the keyword arguments, hen a list of logprobs is also returned (also of size one). - """ - if self.use_chat: - return self.handle_chat_request(messages, **kwargs) - else: - return self.handle_completion_request(messages, **kwargs) - - def batched_chat_request( - self, messages_batch: list, **kwargs: dict[str, Any] - ) -> list | tuple[list[list[str]], list[list[float]]]: - """Handle batched chat request to OpenAI server. - - Args: - messages_batch (list): Either one prompt or a list of prompts in message format. - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters. - - Returns: - list | tuple[list[list[str]], list[list[float]]]: A list of outputs for each prompt in the batch. If logprobs is specified in the keyword arguments, - then a list of logprobs is also returned. - """ - - batch_size = len(messages_batch) - text_ret = [None] * batch_size - logprobs_ret = [None] * batch_size - threads = [] - - def thread_function(idx, messages, kwargs): - text = self(messages, **kwargs) - if kwargs.get("logprobs", False): - text, logprobs = text - logprobs_ret[idx] = logprobs[0] - text_ret[idx] = text[0] - - for idx, messages in enumerate(messages_batch): - thread = threading.Thread(target=thread_function, args=(idx, messages, kwargs)) - threads.append(thread) - thread.start() - - for thread in threads: - thread.join() - - if kwargs.get("logprobs", False): - return text_ret, logprobs_ret - - return text_ret - - def __call__( - self, messages_batch: list | list[list], **kwargs: dict[str, Any] - ) -> list[str] | tuple[list[str], list[dict[str, Any]]]: - lotus.logger.debug(f"OpenAIModel.__call__ messages_batch: {messages_batch}") - lotus.logger.debug(f"OpenAIModel.__call__ kwargs: {kwargs}") - # Bakes max batch size into model call. # TODO: Figure out less hacky way to do this. - if isinstance(messages_batch[0], list) and len(messages_batch) > self.max_batch_size: - text_ret = [] - logprobs_ret = [] - for i in range(0, len(messages_batch), self.max_batch_size): - res = self(messages_batch[i : i + self.max_batch_size], **kwargs) - if kwargs.get("logprobs", False): - text, logprobs = res - logprobs_ret.extend(logprobs) - else: - text = res - text_ret.extend(text) - - if kwargs.get("logprobs", False): - return text_ret, logprobs_ret - return text_ret - - if self.use_chat and isinstance(messages_batch[0], list): - return self.batched_chat_request(messages_batch, **kwargs) - - return self.request(messages_batch, **kwargs) - - def count_tokens(self, prompt: str | list) -> int: - if isinstance(prompt, str): - if self.provider != "openai": - return len(self.tokenizer(prompt)["input_ids"]) - - return len(self.tokenizer.encode(prompt)) - else: - if self.provider != "openai": - return len(self.tokenizer.apply_chat_template(prompt, tokenize=True, add_generation_prompt=True)) - - return sum(len(self.tokenizer.encode(message["content"])) for message in prompt) - - def format_logprobs_for_cascade(self, logprobs: list) -> tuple[list[list[str]], list[list[float]]]: - all_tokens = [] - all_confidences = [] - for idx in range(len(logprobs)): - if self.provider == "vllm": - tokens = logprobs[idx]["tokens"] - confidences = np.exp(logprobs[idx]["token_logprobs"]) - elif self.provider == "openai": - content = logprobs[idx]["content"] - tokens = [content[t_idx]["token"] for t_idx in range(len(content))] - confidences = np.exp([content[t_idx]["logprob"] for t_idx in range(len(content))]) - all_tokens.append(tokens) - all_confidences.append(confidences) - - return all_tokens, all_confidences - - def format_logprobs_for_filter_cascade(self, logprobs: list) -> tuple[list[list[str]], list[list[float]]]: - all_tokens = [] - all_confidences = [] - all_true_probs = [] - for idx in range(len(logprobs)): - if self.provider == "vllm": - tokens = logprobs[idx]["tokens"] - confidences = np.exp(logprobs[idx]["token_logprobs"]) - top_logprobs = logprobs[idx]["top_logprobs"][0] - if 'True' in top_logprobs and 'False' in top_logprobs: - true_prob = np.exp(top_logprobs['True']) - false_prob = np.exp(top_logprobs['False']) - all_true_probs.append(true_prob / (true_prob + false_prob)) - else: - all_true_probs.append(1 if 'True' in top_logprobs else 0) - - elif self.provider == "openai": - content = logprobs[idx]["content"] - tokens = [content[t_idx]["token"] for t_idx in range(len(content))] - confidences = np.exp([content[t_idx]["logprob"] for t_idx in range(len(content))]) - top_logprobs = {x["token"]:x["logprob"] for x in content[0]["top_logprobs"]} - - true_prob, false_prob = 0, 0 - if top_logprobs and 'True' in top_logprobs and 'False' in top_logprobs: - true_prob = np.exp(top_logprobs['True']) - false_prob = np.exp(top_logprobs['False']) - all_true_probs.append(true_prob / (true_prob + false_prob)) - else: - all_true_probs.append(1 if 'True' in top_logprobs else 0) - - all_tokens.append(tokens) - all_confidences.append(confidences) - - return all_tokens, all_confidences, all_true_probs - - def chat_request(self, **kwargs: dict[str, Any]) -> dict[str, Any]: - """Send chat request to OpenAI server. - - Args: - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify things such as the prompt, temperature, - model name, max tokens, etc. - - Returns: - dict: OpenAI chat completion response. - """ - return self.client.chat.completions.create(**kwargs).model_dump() - - def completion_request(self, **kwargs: dict[str, Any]) -> dict[str, Any]: - """Send completion request to OpenAI server. - - Args: - **kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify things such as the prompt, temperature, - model name, max tokens, etc. - - Returns: - dict: OpenAI completion response. - """ - return self.client.completions.create(**kwargs).model_dump() - - @property - def max_tokens(self) -> int: - return self.kwargs["max_tokens"] - - @property - def max_ctx_len(self) -> int: - return self.__dict__["max_ctx_len"] diff --git a/lotus/sem_ops/cascade_utils.py b/lotus/sem_ops/cascade_utils.py index 3302a493..7aab77db 100644 --- a/lotus/sem_ops/cascade_utils.py +++ b/lotus/sem_ops/cascade_utils.py @@ -12,43 +12,47 @@ def importance_sampling( w = np.sqrt(proxy_scores) w = 0.5 * w / np.sum(w) + 0.5 * np.ones((len(proxy_scores))) / len(proxy_scores) indices = np.arange(len(proxy_scores)) - sample_size = (int) (sample_percentage * len(proxy_scores)) + sample_size = (int)(sample_percentage * len(proxy_scores)) sample_indices = np.random.choice(indices, sample_size, p=w) - correction_factors = (1/len(proxy_scores)) / w + correction_factors = (1 / len(proxy_scores)) / w return sample_indices, correction_factors + def calibrate_llm_logprobs(true_probs: list[float]) -> list[float]: """Transforms true probabilities to calibrate LLM proxies.""" num_quantiles = 50 quantile_values = np.percentile(true_probs, np.linspace(0, 100, num_quantiles + 1)) - true_probs = ((np.digitize(true_probs, quantile_values) - 1) / num_quantiles) + true_probs = (np.digitize(true_probs, quantile_values) - 1) / num_quantiles true_probs = np.clip(true_probs, 0, 1) return true_probs + def learn_cascade_thresholds( proxy_scores: list[float], oracle_outputs: list[float], sample_correction_factors: list[float], recall_target: float, precision_target: float, - delta: float + delta: float, ) -> tuple[tuple[float, float], int]: - """Learns cascade thresholds given targets and proxy scores, + """Learns cascade thresholds given targets and proxy scores, oracle outputs over the sample, and correction factors for the sample.""" def UB(mean, std_dev, s, delta): - return mean + (std_dev / (s ** 0.5)) * ((2 * np.log(1 / delta)) ** 0.5) + return mean + (std_dev / (s**0.5)) * ((2 * np.log(1 / delta)) ** 0.5) def LB(mean, std_dev, s, delta): - return mean - (std_dev / (s ** 0.5)) * ((2 * np.log(1 / delta)) ** 0.5) + return mean - (std_dev / (s**0.5)) * ((2 * np.log(1 / delta)) ** 0.5) def recall(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: helper_accepted = [x for x in sorted_pairs if x[0] >= pos_threshold or x[0] <= neg_threshold] sent_to_oracle = [x for x in sorted_pairs if x[0] < pos_threshold and x[0] > neg_threshold] total_correct = sum(pair[1] * pair[2] for pair in sorted_pairs) - recall = (sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1]) + sum(x[1] * x[2] for x in sent_to_oracle)) / total_correct + recall = ( + sum(1 for x in helper_accepted if x[0] >= pos_threshold and x[1]) + sum(x[1] * x[2] for x in sent_to_oracle) + ) / total_correct return recall def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: @@ -65,10 +69,12 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: sorted_pairs = sorted(paired_data, key=lambda x: x[0], reverse=True) sample_size = len(sorted_pairs) - best_combination = (1,0) # initial tau_+, tau_- + best_combination = (1, 0) # initial tau_+, tau_- # Find tau_negative based on recall - tau_neg_0 = max(x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= recall_target) + tau_neg_0 = max( + x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= recall_target + ) best_combination = (best_combination[0], tau_neg_0) # Do a statistical correction to get a new target recall @@ -80,9 +86,13 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: mean_z2 = np.mean(Z2) if Z2 else 0 std_z2 = np.std(Z2) if Z2 else 0 - corrected_recall_target = UB(mean_z1, std_z1, sample_size, delta/2)/(UB(mean_z1, std_z1, sample_size, delta/2) + LB(mean_z2, std_z2, sample_size, delta/2)) + corrected_recall_target = UB(mean_z1, std_z1, sample_size, delta / 2) / ( + UB(mean_z1, std_z1, sample_size, delta / 2) + LB(mean_z2, std_z2, sample_size, delta / 2) + ) corrected_recall_target = min(1, corrected_recall_target) - tau_neg_prime = max(x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= corrected_recall_target) + tau_neg_prime = max( + x[0] for x in sorted_pairs[::-1] if recall(best_combination[0], x[0], sorted_pairs) >= corrected_recall_target + ) best_combination = (best_combination[0], tau_neg_prime) # Do a statistical correction to get a target satisfying precision @@ -92,7 +102,7 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: Z = [int(x[1]) for x in sorted_pairs if x[0] >= possible_threshold] mean_z = np.mean(Z) if Z else 0 std_z = np.std(Z) if Z else 0 - p_l = LB(mean_z, std_z, len(Z), delta/len(sorted_pairs)) + p_l = LB(mean_z, std_z, len(Z), delta / len(sorted_pairs)) if p_l > precision_target: candidate_thresholds.append(possible_threshold) @@ -105,6 +115,7 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: return best_combination, oracle_calls + def calibrate_sem_sim_join(true_score: list[float]) -> list[float]: true_score = np.clip(true_score, 0, 1) - return true_score \ No newline at end of file + return true_score diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index 6d77f8fe..6fd9e8b0 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -2,9 +2,9 @@ import pandas as pd -import lotus +import lotus.models from lotus.templates import task_instructions -from lotus.types import SemanticAggOutput +from lotus.types import LMOutput, SemanticAggOutput def sem_agg( @@ -108,13 +108,9 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str: lotus.logger.debug(f"Prompt added to batch: {prompt}") batch.append([{"role": "user", "content": prompt}]) new_partition_ids.append(cur_partition_id) - result = model(batch) + lm_output: LMOutput = model(batch) - # TODO: this is a weird hack for model typing - if isinstance(result, tuple): - summaries, _ = result - else: - summaries = result + summaries = lm_output.outputs partition_ids = new_partition_ids new_partition_ids = [] diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py index 6336b196..82e82a98 100644 --- a/lotus/sem_ops/sem_extract.py +++ b/lotus/sem_ops/sem_extract.py @@ -4,7 +4,7 @@ import lotus from lotus.templates import task_instructions -from lotus.types import SemanticExtractOutput, SemanticExtractPostprocessOutput +from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput from .postprocessors import extract_postprocess @@ -36,15 +36,11 @@ def sem_extract( inputs.append(prompt) # call model - raw_outputs = model(inputs) - if isinstance(raw_outputs, tuple): - raw_outputs, _ = raw_outputs - else: - assert isinstance(raw_outputs, list) + lm_output: LMOutput = model(inputs) # post process results - postprocess_output = postprocessor(raw_outputs) - lotus.logger.debug(f"raw_outputs: {raw_outputs}") + postprocess_output = postprocessor(lm_output.outputs) + lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"quotes: {postprocess_output.quotes}") diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index aafd6aa6..76e8b9b4 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -4,9 +4,9 @@ import lotus from lotus.templates import task_instructions -from lotus.types import SemanticFilterOutput +from lotus.types import LMOutput, LogprobsForFilterCascade, SemanticFilterOutput -from .cascade_utils import calibrate_llm_logprobs, importance_sampling, learn_cascade_thresholds +from .cascade_utils import importance_sampling, learn_cascade_thresholds from .postprocessors import filter_postprocess @@ -45,20 +45,17 @@ def sem_filter( lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) kwargs: dict[str, Any] = {"logprobs": logprobs} - res = model(inputs, **kwargs) - if logprobs: - assert isinstance(res, tuple) - raw_outputs, raw_logprobs = res - else: - assert isinstance(res, list) - raw_outputs = res - - postprocess_output = filter_postprocess(raw_outputs, default=default, cot_reasoning=strategy in ["cot", "zs-cot"]) + lm_output: LMOutput = model(inputs, **kwargs) + + postprocess_output = filter_postprocess( + lm_output.outputs, default=default, cot_reasoning=strategy in ["cot", "zs-cot"] + ) lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") - return SemanticFilterOutput(**postprocess_output.model_dump(), logprobs=raw_logprobs if logprobs else None) + return SemanticFilterOutput(**postprocess_output.model_dump(), logprobs=lm_output.logprobs if logprobs else None) + def learn_filter_cascade_thresholds( sample_df_txt: str, @@ -75,8 +72,8 @@ def learn_filter_cascade_thresholds( cot_reasoning: list | None = None, strategy: str | None = None, ) -> tuple[float, float]: - """Automatically learns the cascade thresholds for a cascade - filter given a sample of data and doing a search across threshold + """Automatically learns the cascade thresholds for a cascade + filter given a sample of data and doing a search across threshold to see what threshold gives the best accuracy.""" try: @@ -97,7 +94,7 @@ def learn_filter_cascade_thresholds( sample_correction_factors=sample_correction_factors, recall_target=recall_target, precision_target=precision_target, - delta=delta + delta=delta, ) lotus.logger.info(f"Learned cascade thresholds: {best_combination}") @@ -107,6 +104,7 @@ def learn_filter_cascade_thresholds( lotus.logger.error(f"Error while learning filter cascade thresholds: {e}") return None + @pd.api.extensions.register_dataframe_accessor("sem_filter") class SemFilterDataframe: """DataFrame accessor for semantic filter.""" @@ -198,14 +196,16 @@ def __call__( if helper_strategy == "cot": helper_cot_reasoning = examples["Reasoning"].tolist() - + if learn_cascade_threshold_sample_percentage and lotus.settings.helper_lm: if helper_strategy == "cot": lotus.logger.error("CoT not supported for helper models in cascades.") raise Exception if recall_target is None or precision_target is None or failure_probability is None: - lotus.logger.error("Recall target, precision target, and confidence need to be specified for learned thresholds.") + lotus.logger.error( + "Recall target, precision target, and confidence need to be specified for learned thresholds." + ) raise Exception # Run small LM and get logits @@ -221,11 +221,14 @@ def __call__( strategy=helper_strategy, ) helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs - _, _, helper_true_probs = lotus.settings.helper_lm.format_logprobs_for_filter_cascade(helper_logprobs) - - helper_true_probs = calibrate_llm_logprobs(helper_true_probs) + formatted_helper_logprobs: LogprobsForFilterCascade = ( + lotus.settings.helper_lm.format_logprobs_for_filter_cascade(helper_logprobs) + ) + helper_true_probs = formatted_helper_logprobs.true_probs - sample_indices, correction_factors = importance_sampling(helper_true_probs, learn_cascade_threshold_sample_percentage) + sample_indices, correction_factors = importance_sampling( + helper_true_probs, learn_cascade_threshold_sample_percentage + ) sample_df = self._obj.loc[sample_indices] sample_df_txt = task_instructions.df2text(sample_df, col_li) sample_helper_true_probs = [helper_true_probs[i] for i in sample_indices] @@ -238,7 +241,7 @@ def __call__( default=default, recall_target=recall_target, precision_target=precision_target, - delta=failure_probability/2, + delta=failure_probability / 2, helper_true_probs=sample_helper_true_probs, sample_correction_factors=sample_correction_factors, examples_df_txt=examples_df_txt, @@ -261,7 +264,13 @@ def __call__( true_prob = helper_true_probs[idx_i] if true_prob >= pos_cascade_threshold or true_prob <= neg_cascade_threshold: high_conf_idxs.add(idx_i) - helper_outputs[idx_i] = True if true_prob >= pos_cascade_threshold else False if true_prob <= neg_cascade_threshold else helper_outputs[idx_i] + helper_outputs[idx_i] = ( + True + if true_prob >= pos_cascade_threshold + else False + if true_prob <= neg_cascade_threshold + else helper_outputs[idx_i] + ) lotus.logger.info(f"Num routed to smaller model: {len(high_conf_idxs)}") stats["num_routed_to_helper_model"] = len(high_conf_idxs) diff --git a/lotus/sem_ops/sem_join.py b/lotus/sem_ops/sem_join.py index 05f6bacc..0d01cc55 100644 --- a/lotus/sem_ops/sem_join.py +++ b/lotus/sem_ops/sem_join.py @@ -1,6 +1,5 @@ from typing import Any -import numpy as np import pandas as pd import lotus @@ -136,24 +135,17 @@ def sem_join_cascade( assert helper_logprobs is not None high_conf_idxs = set() - for idx_i in range(len(helper_outputs)): - tokens: list[str] - confidences: np.ndarray[Any, np.dtype[np.float64]] - # Get the logprobs - if lotus.settings.helper_lm.provider == "vllm": - tokens = helper_logprobs[idx_i]["tokens"] - confidences = np.exp(helper_logprobs[idx_i]["token_logprobs"]) - elif lotus.settings.helper_lm.provider == "openai": - content: list[dict[str, Any]] = helper_logprobs[idx_i]["content"] - tokens = [content[t_idx]["token"] for t_idx in range(len(content))] - confidences = np.exp([content[t_idx]["logprob"] for t_idx in range(len(content))]) + # Get the logprobs in a standardized format + formatted_logprobs = lotus.settings.helper_lm.format_logprobs_for_cascade(helper_logprobs) + tokens, confidences = formatted_logprobs.tokens, formatted_logprobs.confidences + for doc_idx in range(len(helper_outputs)): # Find where true/false is said and look at confidence - for idx_j in range(len(tokens) - 1, -1, -1): - if tokens[idx_j].strip(" \n").lower() in ["true", "false"]: - conf = confidences[idx_j] - if conf >= cascade_threshold: - high_conf_idxs.add(idx_i) + for token_idx in range(len(tokens[doc_idx]) - 1, -1, -1): + if tokens[doc_idx][token_idx].strip(" \n").lower() in ["true", "false"]: + confidence = confidences[doc_idx][token_idx] + if confidence >= cascade_threshold: + high_conf_idxs.add(doc_idx) # Send low confidence samples to large LM low_conf_idxs = sorted([i for i in range(len(helper_outputs)) if i not in high_conf_idxs]) diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py index 7ee84e62..9074c094 100644 --- a/lotus/sem_ops/sem_map.py +++ b/lotus/sem_ops/sem_map.py @@ -4,7 +4,7 @@ import lotus from lotus.templates import task_instructions -from lotus.types import SemanticMapOutput, SemanticMapPostprocessOutput +from lotus.types import LMOutput, SemanticMapOutput, SemanticMapPostprocessOutput from .postprocessors import map_postprocess @@ -45,14 +45,11 @@ def sem_map( inputs.append(prompt) # call model - raw_outputs = model(inputs) - assert isinstance(raw_outputs, list) and all( - isinstance(item, str) for item in raw_outputs - ), "Model must return a list of strings" + lm_output: LMOutput = model(inputs) # post process results - postprocess_output = postprocessor(raw_outputs, strategy in ["cot", "zs-cot"]) - lotus.logger.debug(f"raw_outputs: {raw_outputs}") + postprocess_output = postprocessor(lm_output.outputs, strategy in ["cot", "zs-cot"]) + lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index a2b98a44..4db77a63 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -7,7 +7,7 @@ import lotus from lotus.templates import task_instructions -from lotus.types import SemanticTopKOutput +from lotus.types import LMOutput, SemanticTopKOutput def get_match_prompt_binary( @@ -65,8 +65,8 @@ def compare_batch_binary( match_prompts.append(get_match_prompt_binary(doc1, doc2, user_instruction, strategy=strategy)) tokens += lotus.settings.lm.count_tokens(match_prompts[-1]) - results = lotus.settings.lm(match_prompts) - results = list(map(parse_ans_binary, results)) + results: LMOutput = lotus.settings.lm(match_prompts) + results = list(map(parse_ans_binary, results.outputs)) return results, tokens @@ -109,8 +109,8 @@ def compare_batch_binary_cascade( large_match_prompts.append(match_prompts[i]) large_tokens += lotus.settings.lm.count_tokens(large_match_prompts[-1]) - results = lotus.settings.lm(large_match_prompts) - for idx, res in enumerate(results): + results: LMOutput = lotus.settings.lm(large_match_prompts) + for idx, res in enumerate(results.outputs): new_idx = low_conf_idxs[idx] parsed_res = parse_ans_binary(res) parsed_results[new_idx] = parsed_res @@ -268,8 +268,8 @@ def __lt__(self, other: "HeapDoc") -> bool: prompt = get_match_prompt_binary(self.doc, other.doc, self.user_instruction, strategy=self.strategy) HeapDoc.num_calls += 1 HeapDoc.total_tokens += lotus.settings.lm.count_tokens(prompt) - result = lotus.settings.lm(prompt) - return parse_ans_binary(result[0]) + result: LMOutput = lotus.settings.lm([prompt]) + return parse_ans_binary(result.outputs[0]) def llm_heapsort( diff --git a/lotus/types.py b/lotus/types.py index a33339d5..7852754f 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -1,5 +1,6 @@ from typing import Any +from litellm.types.utils import ChatCompletionTokenLogprob from pydantic import BaseModel @@ -9,10 +10,11 @@ class StatsMixin(BaseModel): # TODO: Figure out better logprobs type class LogprobsMixin(BaseModel): - logprobs: list[dict[str, Any]] | None = None + # for each response, we have a list of tokens, and for each token, we have a ChatCompletionTokenLogprob + logprobs: list[list[ChatCompletionTokenLogprob]] | None = None -class SemanticMapPostprocessOutput(StatsMixin, LogprobsMixin): +class SemanticMapPostprocessOutput(BaseModel): raw_outputs: list[str] outputs: list[str] explanations: list[str | None] @@ -55,3 +57,16 @@ class SemanticJoinOutput(StatsMixin): class SemanticTopKOutput(StatsMixin): indexes: list[int] + + +class LMOutput(LogprobsMixin): + outputs: list[str] + + +class LogprobsForCascade(BaseModel): + tokens: list[list[str]] + confidences: list[list[float]] + + +class LogprobsForFilterCascade(LogprobsForCascade): + true_probs: list[float] From 761df36e783bd6558ed2f932c04cfb3869ff08d7 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Wed, 30 Oct 2024 17:52:04 -0700 Subject: [PATCH 02/28] Update requirements --- docs/requirements-docs.txt | 2 +- pyproject.toml | 2 +- requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 101058ec..404a5c89 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -3,8 +3,8 @@ sphinx-rtd-theme==2.0.0 backoff==2.2.1 faiss-cpu==1.8.0.post1 +litellm==1.51.0 numpy==1.26.4 -openai==1.35.13 pandas==2.2.2 sentence-transformers==3.0.1 tiktoken==0.7.0 diff --git a/pyproject.toml b/pyproject.toml index 1ce25c08..3db30d79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,8 +24,8 @@ classifiers = [ dependencies = [ "backoff>=2.2.1,<3.0.0", "faiss-cpu>=1.8.0.post1,<2.0.0", + "litellm>=1.51.0,<2.0.0", "numpy>=1.25.0,<2.0.0", - "openai>=1.35.13,<2.0.0", "pandas>=2.0.0,<3.0.0", "sentence-transformers>=3.0.1,<4.0.0", "tiktoken>=0.7.0,<1.0.0", diff --git a/requirements.txt b/requirements.txt index ba74caf9..655dde54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ backoff==2.2.1 faiss-cpu==1.8.0.post1 +litellm==1.51.0 numpy==1.26.4 -openai==1.35.13 pandas==2.2.2 sentence-transformers==3.0.1 tiktoken==0.7.0 From bf8c73a50d17919f9827e788a048c6abcc930763 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sat, 2 Nov 2024 15:25:32 -0700 Subject: [PATCH 03/28] token counting + tests --- .github/tests/lm_tests.py | 53 ++++++++++++++++++++++++-------- lotus/models/lm.py | 64 +++++++++++++++++++++++++++------------ lotus/types.py | 1 - 3 files changed, 85 insertions(+), 33 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 35a0a5c9..2d661902 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -3,21 +3,22 @@ import lotus from lotus.models import LM +from tokenizers import Tokenizer # Set logger level to DEBUG lotus.logger.setLevel("DEBUG") @pytest.fixture -def setup_models(): +def setup_gpt_models(): # Setup GPT models gpt_4o_mini = LM(model="gpt-4o-mini") gpt_4o = LM(model="gpt-4o") return gpt_4o_mini, gpt_4o -def test_filter_operation(setup_models): - gpt_4o_mini, _ = setup_models +def test_filter_operation(setup_gpt_models): + gpt_4o_mini, _ = setup_gpt_models lotus.settings.configure(lm=gpt_4o_mini) # Test filter operation on an easy dataframe @@ -30,8 +31,8 @@ def test_filter_operation(setup_models): assert filtered_df.equals(expected_df) -def test_filter_cascade(setup_models): - gpt_4o_mini, gpt_4o = setup_models +def test_filter_cascade(setup_gpt_models): + gpt_4o_mini, gpt_4o = setup_gpt_models lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini) data = { @@ -99,8 +100,8 @@ def test_filter_cascade(setup_models): assert stats["filters_resolved_by_helper_model"] > 0, stats -def test_top_k(setup_models): - gpt_4o_mini, _ = setup_models +def test_top_k(setup_gpt_models): + gpt_4o_mini, _ = setup_gpt_models lotus.settings.configure(lm=gpt_4o_mini) data = { @@ -120,8 +121,8 @@ def test_top_k(setup_models): assert top_2_expected == top_2_actual -def test_join(setup_models): - gpt_4o_mini, _ = setup_models +def test_join(setup_gpt_models): + gpt_4o_mini, _ = setup_gpt_models lotus.settings.configure(lm=gpt_4o_mini) data1 = {"School": ["UC Berkeley", "Stanford"]} @@ -136,8 +137,8 @@ def test_join(setup_models): assert joined_pairs == expected_pairs -def test_join_cascade(setup_models): - gpt_4o_mini, gpt_4o = setup_models +def test_join_cascade(setup_gpt_models): + gpt_4o_mini, gpt_4o = setup_gpt_models lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini) data1 = {"School": ["UC Berkeley", "Stanford"]} @@ -163,8 +164,8 @@ def test_join_cascade(setup_models): assert stats["filters_resolved_by_helper_model"] == 0, stats -def test_map_fewshot(setup_models): - gpt_4o_mini, _ = setup_models +def test_map_fewshot(setup_gpt_models): + gpt_4o_mini, _ = setup_gpt_models lotus.settings.configure(lm=gpt_4o_mini) data = {"School": ["UC Berkeley", "Carnegie Mellon"]} @@ -177,3 +178,29 @@ def test_map_fewshot(setup_models): pairs = set(zip(df["School"], df["State"])) expected_pairs = set([("UC Berkeley", "CA"), ("Carnegie Mellon", "PA")]) assert pairs == expected_pairs + +def test_agg_then_map(setup_gpt_models): + _, gpt_4o = setup_gpt_models + lotus.settings.configure(lm=gpt_4o) + + data = {"Text": ["My name is John", "My name is Jane", "My name is John"]} + df = pd.DataFrame(data) + agg_instruction = "What is the most common name in {Text}?" + agg_df = df.sem_agg(agg_instruction, suffix="draft_output") + map_instruction = f"{{draft_output}} is a draft answer to the question 'What is the most common name?'. Clean up the draft answer so that there is just a single name. Your answer MUST be on word" + cleaned_df = agg_df.sem_map(map_instruction, suffix="final_output") + assert cleaned_df["final_output"].values[0] == "John" + +def test_count_tokens(setup_gpt_models): + gpt_4o_mini, _ = setup_gpt_models + lotus.settings.configure(lm=gpt_4o_mini) + + tokens = gpt_4o_mini.count_tokens("Hello, world!") + assert gpt_4o_mini.count_tokens([{"role": "user", "content": "Hello, world!"}]) == tokens + assert tokens < 100 + + custom_tokenizer = Tokenizer.from_pretrained("gpt2") + custom_lm = LM(model="doesn't matter", tokenizer=custom_tokenizer) + tokens = custom_lm.count_tokens("Hello, world!") + assert custom_lm.count_tokens([{"role": "user", "content": "Hello, world!"}]) == tokens + assert tokens < 100 diff --git a/lotus/models/lm.py b/lotus/models/lm.py index e878bb4f..795ebcf6 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -1,15 +1,17 @@ import numpy as np from litellm import batch_completion, token_counter from litellm.types.utils import ChatCompletionTokenLogprob, ModelResponse +from tokenizers import Tokenizer from lotus.types import LMOutput, LogprobsForCascade, LogprobsForFilterCascade class LM: - def __init__(self, model="gpt-4o-mini", temperature=0.0, max_ctx_len=128000, max_tokens=512, **kwargs): + def __init__(self, model: str = "gpt-4o-mini", temperature: float = 0.0, max_ctx_len: int = 128000, max_tokens: int = 512, tokenizer: Tokenizer = None, **kwargs): self.model = model self.max_ctx_len = max_ctx_len self.max_tokens = max_tokens + self.tokenizer = tokenizer self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) self.history = [] @@ -34,9 +36,9 @@ def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompleti def format_logprobs_for_cascade(self, logprobs: list[list[ChatCompletionTokenLogprob]]) -> LogprobsForCascade: all_tokens = [] all_confidences = [] - for resp in range(len(logprobs)): - tokens = [logprob.token for logprob in logprobs[resp]] - confidences = [np.exp(logprob.logprob) for logprob in logprobs[resp]] + for resp_logprobs in logprobs: + tokens = [logprob.token for logprob in resp_logprobs] + confidences = [np.exp(logprob.logprob) for logprob in resp_logprobs] all_tokens.append(tokens) all_confidences.append(confidences) return LogprobsForCascade(tokens=all_tokens, confidences=all_confidences) @@ -44,24 +46,48 @@ def format_logprobs_for_cascade(self, logprobs: list[list[ChatCompletionTokenLog def format_logprobs_for_filter_cascade( self, logprobs: list[list[ChatCompletionTokenLogprob]] ) -> LogprobsForFilterCascade: - all_tokens = [] - all_confidences = [] + # Get base cascade format first + base_cascade = self.format_logprobs_for_cascade(logprobs) all_true_probs = [] - for resp in range(len(logprobs)): - all_tokens.append([logprob.token for logprob in logprobs[resp]]) - all_confidences.append([np.exp(logprob.logprob) for logprob in logprobs[resp]]) - top_logprobs = {x.token: np.exp(x.logprob) for x in logprobs[resp]} - true_prob, false_prob = 0, 0 - if top_logprobs and "True" in top_logprobs and "False" in top_logprobs: - true_prob = np.exp(top_logprobs["True"]) - false_prob = np.exp(top_logprobs["False"]) - all_true_probs.append(true_prob / (true_prob + false_prob)) - else: - all_true_probs.append(1 if "True" in top_logprobs else 0) - return LogprobsForFilterCascade(tokens=all_tokens, confidences=all_confidences, true_probs=all_true_probs) + def get_normalized_true_prob(token_probs: dict[str, float]) -> float | None: + if "True" in token_probs and "False" in token_probs: + true_prob = token_probs["True"] + false_prob = token_probs["False"] + return true_prob / (true_prob + false_prob) + return None + + # Get true probabilities for filter cascade + for resp_idx, response_logprobs in enumerate(logprobs): + true_prob = None + for logprob in response_logprobs: + token_probs = {top.token: np.exp(top.logprob) for top in logprob.top_logprobs} + true_prob = get_normalized_true_prob(token_probs) + if true_prob is not None: + break + + # Default to 1 if "True" in tokens, 0 if not + if true_prob is None: + true_prob = 1 if "True" in base_cascade.tokens[resp_idx] else 0 + + all_true_probs.append(true_prob) + + return LogprobsForFilterCascade( + tokens=base_cascade.tokens, + confidences=base_cascade.confidences, + true_probs=all_true_probs + ) def count_tokens(self, messages: list[dict[str, str]] | str) -> int: + """Count tokens in messages using either custom tokenizer or model's default tokenizer""" if isinstance(messages, str): messages = [{"role": "user", "content": messages}] - return token_counter(model=self.model, messages=messages) + + kwargs = {"model": self.model, "messages": messages} + if self.tokenizer: + kwargs["custom_tokenizer"] = { + "type": "huggingface_tokenizer", + "tokenizer": self.tokenizer + } + + return token_counter(**kwargs) diff --git a/lotus/types.py b/lotus/types.py index 7852754f..6e11f93a 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -8,7 +8,6 @@ class StatsMixin(BaseModel): stats: dict[str, Any] | None = None -# TODO: Figure out better logprobs type class LogprobsMixin(BaseModel): # for each response, we have a list of tokens, and for each token, we have a ChatCompletionTokenLogprob logprobs: list[list[ChatCompletionTokenLogprob]] | None = None From 058a1727d3e5c3b06077de2a241567c9f777d232 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sat, 2 Nov 2024 15:46:19 -0700 Subject: [PATCH 04/28] Fix formatting --- .github/tests/lm_tests.py | 6 ++++-- lotus/models/lm.py | 25 ++++++++++++++----------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 2d661902..d81db54e 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -1,9 +1,9 @@ import pandas as pd import pytest +from tokenizers import Tokenizer import lotus from lotus.models import LM -from tokenizers import Tokenizer # Set logger level to DEBUG lotus.logger.setLevel("DEBUG") @@ -179,6 +179,7 @@ def test_map_fewshot(setup_gpt_models): expected_pairs = set([("UC Berkeley", "CA"), ("Carnegie Mellon", "PA")]) assert pairs == expected_pairs + def test_agg_then_map(setup_gpt_models): _, gpt_4o = setup_gpt_models lotus.settings.configure(lm=gpt_4o) @@ -187,10 +188,11 @@ def test_agg_then_map(setup_gpt_models): df = pd.DataFrame(data) agg_instruction = "What is the most common name in {Text}?" agg_df = df.sem_agg(agg_instruction, suffix="draft_output") - map_instruction = f"{{draft_output}} is a draft answer to the question 'What is the most common name?'. Clean up the draft answer so that there is just a single name. Your answer MUST be on word" + map_instruction = "{draft_output} is a draft answer to the question 'What is the most common name?'. Clean up the draft answer so that there is just a single name. Your answer MUST be on word" cleaned_df = agg_df.sem_map(map_instruction, suffix="final_output") assert cleaned_df["final_output"].values[0] == "John" + def test_count_tokens(setup_gpt_models): gpt_4o_mini, _ = setup_gpt_models lotus.settings.configure(lm=gpt_4o_mini) diff --git a/lotus/models/lm.py b/lotus/models/lm.py index 795ebcf6..84262c89 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -7,7 +7,15 @@ class LM: - def __init__(self, model: str = "gpt-4o-mini", temperature: float = 0.0, max_ctx_len: int = 128000, max_tokens: int = 512, tokenizer: Tokenizer = None, **kwargs): + def __init__( + self, + model: str = "gpt-4o-mini", + temperature: float = 0.0, + max_ctx_len: int = 128000, + max_tokens: int = 512, + tokenizer: Tokenizer = None, + **kwargs, + ): self.model = model self.max_ctx_len = max_ctx_len self.max_tokens = max_tokens @@ -69,25 +77,20 @@ def get_normalized_true_prob(token_probs: dict[str, float]) -> float | None: # Default to 1 if "True" in tokens, 0 if not if true_prob is None: true_prob = 1 if "True" in base_cascade.tokens[resp_idx] else 0 - + all_true_probs.append(true_prob) return LogprobsForFilterCascade( - tokens=base_cascade.tokens, - confidences=base_cascade.confidences, - true_probs=all_true_probs + tokens=base_cascade.tokens, confidences=base_cascade.confidences, true_probs=all_true_probs ) def count_tokens(self, messages: list[dict[str, str]] | str) -> int: """Count tokens in messages using either custom tokenizer or model's default tokenizer""" if isinstance(messages, str): messages = [{"role": "user", "content": messages}] - + kwargs = {"model": self.model, "messages": messages} if self.tokenizer: - kwargs["custom_tokenizer"] = { - "type": "huggingface_tokenizer", - "tokenizer": self.tokenizer - } - + kwargs["custom_tokenizer"] = {"type": "huggingface_tokenizer", "tokenizer": self.tokenizer} + return token_counter(**kwargs) From 5fdb9717fd5dc9346eca18982c1cfe7d6c9f405e Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sat, 2 Nov 2024 17:25:56 -0700 Subject: [PATCH 05/28] Work more on typing --- .github/tests/lm_tests.py | 15 ++++--- examples/provider_examples/oai.py | 20 --------- examples/provider_examples/ollama.py | 25 ----------- examples/provider_examples/vllm.py | 24 ----------- lotus/__init__.py | 3 +- lotus/models/colbertv2_model.py | 12 +++--- lotus/models/e5_model.py | 9 ++-- lotus/models/lm.py | 64 ++++++++++++++++++++-------- lotus/models/reranker.py | 2 +- lotus/models/rm.py | 10 ++--- lotus/sem_ops/cascade_utils.py | 51 +++++++++------------- lotus/sem_ops/sem_filter.py | 15 ++++--- lotus/sem_ops/sem_topk.py | 10 ++--- lotus/settings.py | 2 + 14 files changed, 109 insertions(+), 153 deletions(-) delete mode 100644 examples/provider_examples/oai.py delete mode 100644 examples/provider_examples/ollama.py delete mode 100644 examples/provider_examples/vllm.py diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index d81db54e..e6a4c38c 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -114,11 +114,14 @@ def test_top_k(setup_gpt_models): } df = pd.DataFrame(data) user_instruction = "Which {Text} is most related to basketball?" - sorted_df = df.sem_topk(user_instruction, K=2) - top_2_expected = set(["Michael Jordan is a good basketball player", "Steph Curry is a good basketball player"]) - top_2_actual = set(sorted_df["Text"].values) - assert top_2_expected == top_2_actual + + strategies = ["quick", "heap", "naive"] + for strategy in strategies: + sorted_df = df.sem_topk(user_instruction, K=2, strategy=strategy) + + top_2_actual = set(sorted_df["Text"].values) + assert top_2_expected == top_2_actual def test_join(setup_gpt_models): @@ -181,8 +184,8 @@ def test_map_fewshot(setup_gpt_models): def test_agg_then_map(setup_gpt_models): - _, gpt_4o = setup_gpt_models - lotus.settings.configure(lm=gpt_4o) + gpt_4o_mini, _ = setup_gpt_models + lotus.settings.configure(lm=gpt_4o_mini) data = {"Text": ["My name is John", "My name is Jane", "My name is John"]} df = pd.DataFrame(data) diff --git a/examples/provider_examples/oai.py b/examples/provider_examples/oai.py deleted file mode 100644 index ee96e876..00000000 --- a/examples/provider_examples/oai.py +++ /dev/null @@ -1,20 +0,0 @@ -import pandas as pd - -import lotus -from lotus.models import LM - -lm = LM() - -lotus.settings.configure(lm=lm) -data = { - "Course Name": [ - "Probability and Random Processes", - "Optimization Methods in Engineering", - "Digital Design and Integrated Circuits", - "Computer Security", - ] -} -df = pd.DataFrame(data) -user_instruction = "{Course Name} requires a lot of math" -df = df.sem_filter(user_instruction) -print(df) diff --git a/examples/provider_examples/ollama.py b/examples/provider_examples/ollama.py deleted file mode 100644 index 8eb967ad..00000000 --- a/examples/provider_examples/ollama.py +++ /dev/null @@ -1,25 +0,0 @@ -import pandas as pd - -import lotus -from lotus.models import LM - -lm = LM( - api_base="http://localhost:11434/v1", - model="llama3.2", - hf_name="meta-llama/Llama-3.2-3B-Instruct", - provider="ollama", -) - -lotus.settings.configure(lm=lm) -data = { - "Course Name": [ - "Probability and Random Processes", - "Optimization Methods in Engineering", - "Digital Design and Integrated Circuits", - "Computer Security", - ] -} -df = pd.DataFrame(data) -user_instruction = "{Course Name} requires a lot of math" -df = df.sem_filter(user_instruction) -print(df) diff --git a/examples/provider_examples/vllm.py b/examples/provider_examples/vllm.py deleted file mode 100644 index 70975f95..00000000 --- a/examples/provider_examples/vllm.py +++ /dev/null @@ -1,24 +0,0 @@ -import pandas as pd - -import lotus -from lotus.models import LM - -lm = LM( - model="meta-llama/Meta-Llama-3.1-70B-Instruct", - api_base="http://localhost:8000/v1", - provider="vllm", -) - -lotus.settings.configure(lm=lm) -data = { - "Course Name": [ - "Probability and Random Processes", - "Optimization Methods in Engineering", - "Digital Design and Integrated Circuits", - "Computer Security", - ] -} -df = pd.DataFrame(data) -user_instruction = "{Course Name} requires a lot of math" -df = df.sem_filter(user_instruction) -print(df) diff --git a/lotus/__init__.py b/lotus/__init__.py index 190d9d52..58f4575e 100644 --- a/lotus/__init__.py +++ b/lotus/__init__.py @@ -19,7 +19,8 @@ sem_dedup, sem_topk, ) -from lotus.settings import settings +from lotus.settings import settings # type: ignore[attr-defined] + logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO) logger = logging.getLogger(__name__) diff --git a/lotus/models/colbertv2_model.py b/lotus/models/colbertv2_model.py index 2c407a2c..a706d67b 100644 --- a/lotus/models/colbertv2_model.py +++ b/lotus/models/colbertv2_model.py @@ -2,14 +2,16 @@ from typing import Any from lotus.models.rm import RM +from numpy.typing import NDArray +import numpy as np class ColBERTv2Model(RM): """ColBERTv2 Model""" - def __init__(self, **kwargs): + def __init__(self) -> None: self.docs: list[str] | None = None - self.kwargs: dict[str, Any] = {"doc_maxlen": 300, "nbits": 2, **kwargs} + self.kwargs: dict[str, Any] = {"doc_maxlen": 300, "nbits": 2} self.index_dir: str | None = None from colbert import Indexer, Searcher @@ -41,7 +43,7 @@ def load_index(self, index_dir: str) -> None: with open(f"experiments/lotus/indexes/{index_dir}/index/docs", "rb") as fp: self.docs = pickle.load(fp) - def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> list: + def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float_]: raise NotImplementedError("This method is not implemented for ColBERTv2Model") def __call__( @@ -57,8 +59,8 @@ def __call__( searcher = self.Searcher(index=f"{self.index_dir}/index", collection=self.docs) # make queries a dict with keys as query ids - queries = {i: q for i, q in enumerate(queries)} - all_results = searcher.search_all(queries, k=k).todict() + queries_dict = {i: q for i, q in enumerate(queries)} + all_results = searcher.search_all(queries_dict, k=k).todict() indices = [[result[0] for result in all_results[qid]] for qid in all_results.keys()] distances = [[result[2] for result in all_results[qid]] for qid in all_results.keys()] diff --git a/lotus/models/e5_model.py b/lotus/models/e5_model.py index 310a2428..cc8dc2ee 100644 --- a/lotus/models/e5_model.py +++ b/lotus/models/e5_model.py @@ -3,6 +3,7 @@ from typing import Any import numpy as np +from numpy.typing import NDArray import torch import torch.nn.functional as F from tqdm import tqdm @@ -25,7 +26,7 @@ def __init__(self, model: str = "intfloat/e5-base-v2", device: str | None = None self.docs: list[str] | None = None self.kwargs: dict[str, Any] = {"normalize": True, "index_type": "Flat", **kwargs} self.batch_size: int = 100 - self.vecs: np.ndarray[Any, np.dtype[np.float32]] | None = None + self.vecs: NDArray[np.float_] | None = None import faiss @@ -45,7 +46,7 @@ def average_pool(self, last_hidden_states: torch.Tensor, attention_mask: torch.T last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] - def embed(self, docs: list[str], **kwargs: dict[str, Any]) -> np.ndarray[Any, np.dtype[np.float32]]: + def embed(self, docs: list[str], **kwargs: dict[str, Any]) -> NDArray[np.float_]: """Run the embedding model. Args: @@ -111,9 +112,9 @@ def load_index(self, index_dir: str) -> None: self.vecs = pickle.load(fp) @classmethod - def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> list[np.ndarray[Any, np.dtype[np.float32]]]: + def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.float_]: with open(f"{index_dir}/vecs", "rb") as fp: - vecs: np.ndarray[Any, np.dtype[np.float32]] = pickle.load(fp) + vecs: NDArray[np.float_] = pickle.load(fp) return vecs[ids] diff --git a/lotus/models/lm.py b/lotus/models/lm.py index 84262c89..9eb03c1d 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -1,6 +1,9 @@ +from typing import Any + import numpy as np -from litellm import batch_completion, token_counter -from litellm.types.utils import ChatCompletionTokenLogprob, ModelResponse +from litellm import batch_completion +from litellm.utils import token_counter +from litellm.types.utils import ChatCompletionTokenLogprob, ModelResponse, Choices from tokenizers import Tokenizer from lotus.types import LMOutput, LogprobsForCascade, LogprobsForFilterCascade @@ -13,32 +16,52 @@ def __init__( temperature: float = 0.0, max_ctx_len: int = 128000, max_tokens: int = 512, - tokenizer: Tokenizer = None, - **kwargs, + tokenizer: Tokenizer | None = None, + **kwargs: dict[str, Any], ): self.model = model self.max_ctx_len = max_ctx_len self.max_tokens = max_tokens self.tokenizer = tokenizer self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) - self.history = [] - - def __call__(self, messages=None, **kwargs) -> LMOutput: - kwargs = {**self.kwargs, **kwargs} - if kwargs.get("logprobs", False): - kwargs["top_logprobs"] = kwargs.get("top_logprobs", 10) - responses: list[ModelResponse] = batch_completion(model=self.model, messages=messages, **kwargs) + def __call__( + self, messages: list[dict[str, str]] | list[list[dict[str, str]]], **kwargs: dict[str, Any] + ) -> LMOutput: + kwargs_for_batch = self._format_batch_kwargs(kwargs) + responses: list[ModelResponse] = batch_completion( + self.model, + messages, + temperature=kwargs_for_batch.get("temperature"), + max_tokens=kwargs_for_batch.get("max_tokens"), + top_logprobs=kwargs_for_batch.get("top_logprobs"), + logprobs=kwargs_for_batch.get("logprobs") + ) outputs = [self._get_top_choice(resp) for resp in responses] - logprobs = [self._get_top_choice_logprobs(resp) for resp in responses] if kwargs.get("logprobs") else None + logprobs = [self._get_top_choice_logprobs(resp) for resp in responses] if kwargs_for_batch.get("logprobs") else None return LMOutput(outputs=outputs, logprobs=logprobs) + def _format_batch_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: + all_kwargs = {**self.kwargs, **kwargs} + if all_kwargs.get("logprobs", False): + all_kwargs["top_logprobs"] = all_kwargs.get("top_logprobs", 10) + return { + k: v for k, v in all_kwargs.items() + if k in ["temperature", "max_tokens", "top_logprobs", "logprobs"] + } + def _get_top_choice(self, response: ModelResponse) -> str: - return response.choices[0].message.content + choice = response.choices[0] + assert isinstance(choice, Choices) + if choice.message.content is None: + raise ValueError(f"No content in response: {response}") + return choice.message.content def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompletionTokenLogprob]: - logprobs = response.choices[0].logprobs["content"] + choice = response.choices[0] + assert isinstance(choice, Choices) + logprobs = choice.logprobs["content"] return [ChatCompletionTokenLogprob(**logprob) for logprob in logprobs] def format_logprobs_for_cascade(self, logprobs: list[list[ChatCompletionTokenLogprob]]) -> LogprobsForCascade: @@ -89,8 +112,13 @@ def count_tokens(self, messages: list[dict[str, str]] | str) -> int: if isinstance(messages, str): messages = [{"role": "user", "content": messages}] - kwargs = {"model": self.model, "messages": messages} + custom_tokenizer: dict[str, Any] | None = None if self.tokenizer: - kwargs["custom_tokenizer"] = {"type": "huggingface_tokenizer", "tokenizer": self.tokenizer} - - return token_counter(**kwargs) + custom_tokenizer = dict(type="huggingface_tokenizer", tokenizer=self.tokenizer) + + # Pass values directly rather than using kwargs dict to preserve typing + return token_counter( + custom_tokenizer=custom_tokenizer, + model=self.model, + messages=messages, + ) \ No newline at end of file diff --git a/lotus/models/reranker.py b/lotus/models/reranker.py index 736656f4..4e2f54ee 100644 --- a/lotus/models/reranker.py +++ b/lotus/models/reranker.py @@ -4,7 +4,7 @@ class Reranker(ABC): """Abstract class for reranker models.""" - def _init__(self): + def __init__(self) -> None: pass @abstractmethod diff --git a/lotus/models/rm.py b/lotus/models/rm.py index ed7b70e2..2301f298 100644 --- a/lotus/models/rm.py +++ b/lotus/models/rm.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod from typing import Any - - +from numpy.typing import NDArray +import numpy as np class RM(ABC): """Abstract class for retriever models.""" - def _init__(self): + def __init__(self) -> None: pass @abstractmethod @@ -28,7 +28,7 @@ def load_index(self, index_dir: str) -> None: pass @abstractmethod - def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> list: + def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.float_]: """Get the vectors from the index. Args: @@ -36,7 +36,7 @@ def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> list: ids (list[int]): The ids of the vectors to retrieve Returns: - list: The vectors matching the specified ids. + NDArray[np.float_]: The vectors matching the specified ids. """ pass diff --git a/lotus/sem_ops/cascade_utils.py b/lotus/sem_ops/cascade_utils.py index 7aab77db..d083112e 100644 --- a/lotus/sem_ops/cascade_utils.py +++ b/lotus/sem_ops/cascade_utils.py @@ -1,4 +1,5 @@ import numpy as np +from numpy.typing import NDArray import lotus @@ -6,7 +7,7 @@ def importance_sampling( proxy_scores: list[float], sample_percentage: float, -) -> tuple[list[int], list[float]]: +) -> tuple[NDArray[np.int_], NDArray[np.float_]]: """Uses importance sampling and returns the list of indices from which to learn cascade thresholds.""" w = np.sqrt(proxy_scores) @@ -19,19 +20,10 @@ def importance_sampling( return sample_indices, correction_factors -def calibrate_llm_logprobs(true_probs: list[float]) -> list[float]: - """Transforms true probabilities to calibrate LLM proxies.""" - num_quantiles = 50 - quantile_values = np.percentile(true_probs, np.linspace(0, 100, num_quantiles + 1)) - true_probs = (np.digitize(true_probs, quantile_values) - 1) / num_quantiles - true_probs = np.clip(true_probs, 0, 1) - return true_probs - - def learn_cascade_thresholds( proxy_scores: list[float], - oracle_outputs: list[float], - sample_correction_factors: list[float], + oracle_outputs: list[bool], + sample_correction_factors: NDArray[np.float_], recall_target: float, precision_target: float, delta: float, @@ -40,13 +32,13 @@ def learn_cascade_thresholds( oracle outputs over the sample, and correction factors for the sample.""" - def UB(mean, std_dev, s, delta): - return mean + (std_dev / (s**0.5)) * ((2 * np.log(1 / delta)) ** 0.5) + def UB(mean: float, std_dev: float, s: int, delta: float) -> float: + return float(mean + (std_dev / (s**0.5)) * ((2 * np.log(1 / delta)) ** 0.5)) - def LB(mean, std_dev, s, delta): - return mean - (std_dev / (s**0.5)) * ((2 * np.log(1 / delta)) ** 0.5) + def LB(mean: float, std_dev: float, s: int, delta: float) -> float: + return float(mean - (std_dev / (s**0.5)) * ((2 * np.log(1 / delta)) ** 0.5)) - def recall(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: + def recall(pos_threshold: float, neg_threshold: float, sorted_pairs: list[tuple[float, bool, float]]) -> float: helper_accepted = [x for x in sorted_pairs if x[0] >= pos_threshold or x[0] <= neg_threshold] sent_to_oracle = [x for x in sorted_pairs if x[0] < pos_threshold and x[0] > neg_threshold] total_correct = sum(pair[1] * pair[2] for pair in sorted_pairs) @@ -55,7 +47,7 @@ def recall(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: ) / total_correct return recall - def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: + def precision(pos_threshold: float, neg_threshold: float, sorted_pairs: list[tuple[float, bool, float]]) -> float: helper_accepted = [x for x in sorted_pairs if x[0] >= pos_threshold or x[0] <= neg_threshold] sent_to_oracle = [x for x in sorted_pairs if pos_threshold > x[0] > neg_threshold] oracle_positive = sum(x[1] for x in sent_to_oracle) @@ -69,7 +61,7 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: sorted_pairs = sorted(paired_data, key=lambda x: x[0], reverse=True) sample_size = len(sorted_pairs) - best_combination = (1, 0) # initial tau_+, tau_- + best_combination = (1.0, 0.0) # initial tau_+, tau_- # Find tau_negative based on recall tau_neg_0 = max( @@ -81,10 +73,10 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: Z1 = [int(x[1]) * x[2] for x in sorted_pairs if x[0] >= best_combination[1]] Z2 = [int(x[1]) * x[2] for x in sorted_pairs if x[0] < best_combination[1]] - mean_z1 = np.mean(Z1) if Z1 else 0 - std_z1 = np.std(Z1) if Z1 else 0 - mean_z2 = np.mean(Z2) if Z2 else 0 - std_z2 = np.std(Z2) if Z2 else 0 + mean_z1 = float(np.mean(Z1)) if Z1 else 0.0 + std_z1 = float(np.std(Z1)) if Z1 else 0.0 + mean_z2 = float(np.mean(Z2)) if Z2 else 0.0 + std_z2 = float(np.std(Z2)) if Z2 else 0.0 corrected_recall_target = UB(mean_z1, std_z1, sample_size, delta / 2) / ( UB(mean_z1, std_z1, sample_size, delta / 2) + LB(mean_z2, std_z2, sample_size, delta / 2) @@ -96,12 +88,12 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: best_combination = (best_combination[0], tau_neg_prime) # Do a statistical correction to get a target satisfying precision - candidate_thresholds = [1] + candidate_thresholds: list[float] = [1.0] for pair in sorted_pairs: possible_threshold = pair[0] Z = [int(x[1]) for x in sorted_pairs if x[0] >= possible_threshold] - mean_z = np.mean(Z) if Z else 0 - std_z = np.std(Z) if Z else 0 + mean_z = float(np.mean(Z)) if Z else 0.0 + std_z = float(np.std(Z)) if Z else 0.0 p_l = LB(mean_z, std_z, len(Z), delta / len(sorted_pairs)) if p_l > precision_target: candidate_thresholds.append(possible_threshold) @@ -109,13 +101,8 @@ def precision(pos_threshold: float, neg_threshold: float, sorted_pairs) -> bool: best_combination = (max(best_combination[1], min(candidate_thresholds)), best_combination[1]) oracle_calls = sum(1 for x in proxy_scores if best_combination[0] > x > best_combination[1]) - no_correction_sorted_pairs = [tup[:2] + (1,) for tup in sorted_pairs] + no_correction_sorted_pairs = [tup[:2] + (1.0,) for tup in sorted_pairs] lotus.logger.info(f"Sample recall: {recall(best_combination[0], best_combination[1], no_correction_sorted_pairs)}") lotus.logger.info(f"Sample precision: {precision(best_combination[0], best_combination[1], sorted_pairs)}") return best_combination, oracle_calls - - -def calibrate_sem_sim_join(true_score: list[float]) -> list[float]: - true_score = np.clip(true_score, 0, 1) - return true_score diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 76e8b9b4..142cfb43 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -2,6 +2,9 @@ import pandas as pd +import numpy as np +from numpy.typing import NDArray + import lotus from lotus.templates import task_instructions from lotus.types import LMOutput, LogprobsForFilterCascade, SemanticFilterOutput @@ -58,7 +61,7 @@ def sem_filter( def learn_filter_cascade_thresholds( - sample_df_txt: str, + sample_df_txt: list[str], lm: lotus.models.LM, formatted_usr_instr: str, default: bool, @@ -66,10 +69,10 @@ def learn_filter_cascade_thresholds( precision_target: float, delta: float, helper_true_probs: list[float], - sample_correction_factors: list[float], - examples_df_txt: str | None = None, - examples_answers: str | None = None, - cot_reasoning: list | None = None, + sample_correction_factors: NDArray[np.float_], + examples_df_txt: list[str] | None = None, + examples_answers: list[bool] | None = None, + cot_reasoning: list[str] | None = None, strategy: str | None = None, ) -> tuple[float, float]: """Automatically learns the cascade thresholds for a cascade @@ -102,7 +105,7 @@ def learn_filter_cascade_thresholds( except Exception as e: lotus.logger.error(f"Error while learning filter cascade thresholds: {e}") - return None + raise e @pd.api.extensions.register_dataframe_accessor("sem_filter") diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index 4db77a63..43190e9a 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -59,14 +59,12 @@ def compare_batch_binary( pairs: list[tuple[str, str]], user_instruction: str, strategy: str | None = None ) -> tuple[list[bool], int]: match_prompts = [] - results = [] tokens = 0 for doc1, doc2 in pairs: match_prompts.append(get_match_prompt_binary(doc1, doc2, user_instruction, strategy=strategy)) tokens += lotus.settings.lm.count_tokens(match_prompts[-1]) - - results: LMOutput = lotus.settings.lm(match_prompts) - results = list(map(parse_ans_binary, results.outputs)) + lm_results: LMOutput = lotus.settings.lm(match_prompts) + results: list[bool] = list(map(parse_ans_binary, lm_results.outputs)) return results, tokens @@ -109,8 +107,8 @@ def compare_batch_binary_cascade( large_match_prompts.append(match_prompts[i]) large_tokens += lotus.settings.lm.count_tokens(large_match_prompts[-1]) - results: LMOutput = lotus.settings.lm(large_match_prompts) - for idx, res in enumerate(results.outputs): + large_lm_results: LMOutput = lotus.settings.lm(large_match_prompts) + for idx, res in enumerate(large_lm_results.outputs): new_idx = low_conf_idxs[idx] parsed_res = parse_ans_binary(res) parsed_results[new_idx] = parsed_res diff --git a/lotus/settings.py b/lotus/settings.py index eb0e9feb..0be755b3 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -1,3 +1,5 @@ +# type: ignore + import copy import threading from contextlib import contextmanager From 252e0dce4b336e23e849c3e374e480d10cbe4267 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sat, 2 Nov 2024 17:26:20 -0700 Subject: [PATCH 06/28] Format --- lotus/models/colbertv2_model.py | 5 +++-- lotus/models/e5_model.py | 2 +- lotus/models/lm.py | 15 +++++++-------- lotus/models/rm.py | 5 ++++- lotus/sem_ops/sem_filter.py | 3 +-- 5 files changed, 16 insertions(+), 14 deletions(-) diff --git a/lotus/models/colbertv2_model.py b/lotus/models/colbertv2_model.py index a706d67b..3fd3bacf 100644 --- a/lotus/models/colbertv2_model.py +++ b/lotus/models/colbertv2_model.py @@ -1,9 +1,10 @@ import pickle from typing import Any -from lotus.models.rm import RM -from numpy.typing import NDArray import numpy as np +from numpy.typing import NDArray + +from lotus.models.rm import RM class ColBERTv2Model(RM): diff --git a/lotus/models/e5_model.py b/lotus/models/e5_model.py index cc8dc2ee..d354e7f5 100644 --- a/lotus/models/e5_model.py +++ b/lotus/models/e5_model.py @@ -3,9 +3,9 @@ from typing import Any import numpy as np -from numpy.typing import NDArray import torch import torch.nn.functional as F +from numpy.typing import NDArray from tqdm import tqdm from transformers import AutoModel, AutoTokenizer diff --git a/lotus/models/lm.py b/lotus/models/lm.py index 9eb03c1d..149c0714 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -2,8 +2,8 @@ import numpy as np from litellm import batch_completion +from litellm.types.utils import ChatCompletionTokenLogprob, Choices, ModelResponse from litellm.utils import token_counter -from litellm.types.utils import ChatCompletionTokenLogprob, ModelResponse, Choices from tokenizers import Tokenizer from lotus.types import LMOutput, LogprobsForCascade, LogprobsForFilterCascade @@ -35,10 +35,12 @@ def __call__( temperature=kwargs_for_batch.get("temperature"), max_tokens=kwargs_for_batch.get("max_tokens"), top_logprobs=kwargs_for_batch.get("top_logprobs"), - logprobs=kwargs_for_batch.get("logprobs") + logprobs=kwargs_for_batch.get("logprobs"), ) outputs = [self._get_top_choice(resp) for resp in responses] - logprobs = [self._get_top_choice_logprobs(resp) for resp in responses] if kwargs_for_batch.get("logprobs") else None + logprobs = ( + [self._get_top_choice_logprobs(resp) for resp in responses] if kwargs_for_batch.get("logprobs") else None + ) return LMOutput(outputs=outputs, logprobs=logprobs) @@ -46,10 +48,7 @@ def _format_batch_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: all_kwargs = {**self.kwargs, **kwargs} if all_kwargs.get("logprobs", False): all_kwargs["top_logprobs"] = all_kwargs.get("top_logprobs", 10) - return { - k: v for k, v in all_kwargs.items() - if k in ["temperature", "max_tokens", "top_logprobs", "logprobs"] - } + return {k: v for k, v in all_kwargs.items() if k in ["temperature", "max_tokens", "top_logprobs", "logprobs"]} def _get_top_choice(self, response: ModelResponse) -> str: choice = response.choices[0] @@ -121,4 +120,4 @@ def count_tokens(self, messages: list[dict[str, str]] | str) -> int: custom_tokenizer=custom_tokenizer, model=self.model, messages=messages, - ) \ No newline at end of file + ) diff --git a/lotus/models/rm.py b/lotus/models/rm.py index 2301f298..e4a158da 100644 --- a/lotus/models/rm.py +++ b/lotus/models/rm.py @@ -1,7 +1,10 @@ from abc import ABC, abstractmethod from typing import Any -from numpy.typing import NDArray + import numpy as np +from numpy.typing import NDArray + + class RM(ABC): """Abstract class for retriever models.""" diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 142cfb43..aa9141c8 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -1,8 +1,7 @@ from typing import Any -import pandas as pd - import numpy as np +import pandas as pd from numpy.typing import NDArray import lotus From 0aa1321e063ae64fdcddca30211a4361a93e287d Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sat, 2 Nov 2024 17:34:35 -0700 Subject: [PATCH 07/28] Typing --- lotus/sem_ops/cascade_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lotus/sem_ops/cascade_utils.py b/lotus/sem_ops/cascade_utils.py index 5e692672..681e7ca9 100644 --- a/lotus/sem_ops/cascade_utils.py +++ b/lotus/sem_ops/cascade_utils.py @@ -25,7 +25,7 @@ def calibrate_llm_logprobs(true_probs: list[float]) -> list[float]: num_quantiles = lotus.settings.cascade_num_calibration_quantiles quantile_values = np.percentile(true_probs, np.linspace(0, 100, num_quantiles + 1)) true_probs = ((np.digitize(true_probs, quantile_values) - 1) / num_quantiles) - true_probs = np.clip(true_probs, 0, 1) + true_probs = list(np.clip(true_probs, 0, 1)) return true_probs def learn_cascade_thresholds( From edb8cc6502a6cd22488df7f9af1c73a1a2e53cdd Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sat, 2 Nov 2024 17:34:48 -0700 Subject: [PATCH 08/28] Ruff format --- lotus/sem_ops/cascade_utils.py | 4 +++- lotus/settings.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/lotus/sem_ops/cascade_utils.py b/lotus/sem_ops/cascade_utils.py index 681e7ca9..2bee2c77 100644 --- a/lotus/sem_ops/cascade_utils.py +++ b/lotus/sem_ops/cascade_utils.py @@ -20,14 +20,16 @@ def importance_sampling( return sample_indices, correction_factors + def calibrate_llm_logprobs(true_probs: list[float]) -> list[float]: """Transforms true probabilities to calibrate LLM proxies.""" num_quantiles = lotus.settings.cascade_num_calibration_quantiles quantile_values = np.percentile(true_probs, np.linspace(0, 100, num_quantiles + 1)) - true_probs = ((np.digitize(true_probs, quantile_values) - 1) / num_quantiles) + true_probs = (np.digitize(true_probs, quantile_values) - 1) / num_quantiles true_probs = list(np.clip(true_probs, 0, 1)) return true_probs + def learn_cascade_thresholds( proxy_scores: list[float], oracle_outputs: list[bool], diff --git a/lotus/settings.py b/lotus/settings.py index 765a04ea..a928880c 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -115,4 +115,4 @@ def __repr__(self) -> str: # set defaults settings = Settings() -settings.configure(cascade_is_weight=0.5, cascade_num_calibration_quantiles=50) \ No newline at end of file +settings.configure(cascade_is_weight=0.5, cascade_num_calibration_quantiles=50) From 8550730b8dd9ec95c28e3b8260798ba5f666dbc5 Mon Sep 17 00:00:00 2001 From: sidjha1 Date: Sat, 2 Nov 2024 18:25:34 -0700 Subject: [PATCH 09/28] Fix types --- lotus/models/colbertv2_model.py | 4 ++-- lotus/models/cross_encoder_model.py | 12 ++++++------ lotus/models/e5_model.py | 28 ++++++++-------------------- lotus/models/rm.py | 10 +++++----- 4 files changed, 21 insertions(+), 33 deletions(-) diff --git a/lotus/models/colbertv2_model.py b/lotus/models/colbertv2_model.py index 3fd3bacf..595b970c 100644 --- a/lotus/models/colbertv2_model.py +++ b/lotus/models/colbertv2_model.py @@ -49,10 +49,10 @@ def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.f def __call__( self, - queries: str | list[str] | list[list[float]], + queries: str | list[str] | NDArray[np.float_], k: int, **kwargs: dict[str, Any], - ) -> tuple[list[float], list[int]]: + ) -> tuple[list[list[float]], list[list[int]]]: if isinstance(queries, str): queries = [queries] diff --git a/lotus/models/cross_encoder_model.py b/lotus/models/cross_encoder_model.py index 1f4c9512..f49aa59a 100644 --- a/lotus/models/cross_encoder_model.py +++ b/lotus/models/cross_encoder_model.py @@ -16,14 +16,14 @@ def __init__( self, model: str = "mixedbread-ai/mxbai-rerank-large-v1", device: str | None = None, - **kwargs, + batch_size: int = 32, ): if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" - self.device = device - self.model = CrossEncoder(model, device=device, **kwargs) + self.device: str = device + self.batch_size: int = batch_size + self.model = CrossEncoder(model, device=device) def __call__(self, query: str, docs: list[str], k: int) -> list[int]: - results = self.model.rank(query, docs, top_k=k) - results = [result["corpus_id"] for result in results] - return results + results = self.model.rank(query, docs, top_k=k, batch_size=self.batch_size) + return [int(result["corpus_id"]) for result in results] diff --git a/lotus/models/e5_model.py b/lotus/models/e5_model.py index d354e7f5..194d3a5e 100644 --- a/lotus/models/e5_model.py +++ b/lotus/models/e5_model.py @@ -118,35 +118,23 @@ def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.fl return vecs[ids] - def load_vecs(self, index_dir: str, ids: list[int]) -> list: - """loads vectors to the rm and returns them - Args: - index_dir (str): Directory of the index. - ids (list[int]): The ids of the vectors to retrieve - - Returns: - The vectors matching the specified ids. - """ - - if self.vecs is None: - with open(f"{index_dir}/vecs", "rb") as fp: - self.vecs = pickle.load(fp) - - return self.vecs[ids] - def __call__( self, - queries: str | list[str] | list[list[float]], + queries: str | list[str] | NDArray[np.float_], k: int, **kwargs: dict[str, Any], - ) -> tuple[list[float], list[int]]: + ) -> tuple[list[list[float]], list[list[int]]]: if isinstance(queries, str): queries = [queries] if isinstance(queries[0], str): - embedded_queries = self.embed(queries, **kwargs) + str_queries: list[str] = [str(q) for q in queries] + embedded_queries = self.embed(str_queries, **kwargs) else: - embedded_queries = queries + embedded_queries = np.asarray(queries, dtype=np.float32) + + if self.faiss_index is None: + raise ValueError("Index not loaded") distances, indicies = self.faiss_index.search(embedded_queries, k) diff --git a/lotus/models/rm.py b/lotus/models/rm.py index e4a158da..4ab13a9b 100644 --- a/lotus/models/rm.py +++ b/lotus/models/rm.py @@ -8,7 +8,7 @@ class RM(ABC): """Abstract class for retriever models.""" - def __init__(self) -> None: + def __init__(self): pass @abstractmethod @@ -47,18 +47,18 @@ def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.fl @abstractmethod def __call__( self, - queries: str | list[str] | list[list[float]], + queries: str | list[str] | NDArray[np.float_], k: int, **kwargs: dict[str, Any], - ) -> tuple[list[float], list[int]]: + ) -> tuple[list[list[float]], list[list[int]]]: """Run top-k search on the index. Args: - queries (str | list[str] | list[list[float]]): Either a query or a list of queries or a 2D FP32 array. + queries (str | list[str] | NDArray[np.float_]): Either a query or a list of queries or a 2D FP32 array. k (int): The k to use for top-k search. **kwargs (dict[str, Any]): Additional keyword arguments. Returns: - tuple[list[float], list[int]]: A tuple of (distances, indices) of the top-k vectors + tuple[list[list[float]], list[list[int]]]: A tuple of (distances, indices) of the top-k vectors """ pass From 90c6b3e587f0db2e5235d4a90cb85f5e6ef81637 Mon Sep 17 00:00:00 2001 From: sidjha1 Date: Sat, 2 Nov 2024 18:27:51 -0700 Subject: [PATCH 10/28] Add mypy and pre-commit --- .pre-commit-config.yaml | 13 +++++++++++++ lotus/models/rm.py | 2 +- mypy.ini | 5 +++++ 3 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 .pre-commit-config.yaml create mode 100644 mypy.ini diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..38997e65 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,13 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.7.2 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.13.0 + hooks: + - id: mypy + args: ["--config-file", "mypy.ini"] diff --git a/lotus/models/rm.py b/lotus/models/rm.py index 4ab13a9b..85cf1af2 100644 --- a/lotus/models/rm.py +++ b/lotus/models/rm.py @@ -8,7 +8,7 @@ class RM(ABC): """Abstract class for retriever models.""" - def __init__(self): + def __init__(self) -> None: pass @abstractmethod diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..825f940b --- /dev/null +++ b/mypy.ini @@ -0,0 +1,5 @@ +[mypy] +python_version = 3.10 +ignore_missing_imports = True +strict_optional = True +show_error_codes = True From b59728526a170f17cc46e587fdb42457421d1ea7 Mon Sep 17 00:00:00 2001 From: sidjha1 Date: Sat, 2 Nov 2024 18:37:06 -0700 Subject: [PATCH 11/28] Add mypy to CI --- .github/workflows/tests.yml | 25 ++++++++++++++++++++++++- .gitignore | 4 +++- mypy.ini | 1 + requirements-dev.txt | 4 +++- 4 files changed, 31 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 35b6d58d..1991af16 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -26,11 +26,34 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ruff==0.5.2 + pip install ruff==0.7.2 - name: Run ruff run: ruff check . + mypy: + name: Type Check + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install mypy==1.13.0 + pip install -r requirements.txt + pip install -e . + + - name: Run mypy + run: mypy lotus/ + test: name: Python Tests runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index f2118286..1a45a8d6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,6 @@ __pycache__/ *.log dist/ docs/_build -.ruff_cache \ No newline at end of file +.ruff_cache +.mypy_cache +.pytest_cache \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 825f940b..0d73c326 100644 --- a/mypy.ini +++ b/mypy.ini @@ -3,3 +3,4 @@ python_version = 3.10 ignore_missing_imports = True strict_optional = True show_error_codes = True +files = lotus/**/*.py diff --git a/requirements-dev.txt b/requirements-dev.txt index 883ede67..a5701f62 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,6 @@ -r requirements.txt # Additional development dependencies -ruff==0.5.2 +ruff==0.7.2 +mypy==1.13.0 +pytest==8.3.3 \ No newline at end of file From b093b6cc99551af2f5cac9e206d57b1aabf393fe Mon Sep 17 00:00:00 2001 From: sidjha1 Date: Sat, 2 Nov 2024 18:44:08 -0700 Subject: [PATCH 12/28] Add pre-commit to dev req --- requirements-dev.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index a5701f62..9fc528f5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,4 +3,5 @@ # Additional development dependencies ruff==0.7.2 mypy==1.13.0 -pytest==8.3.3 \ No newline at end of file +pytest==8.3.3 +pre-commit==4.0.1 \ No newline at end of file From 0c86fc88082c387746b944bb211a8f374b995bb6 Mon Sep 17 00:00:00 2001 From: sidjha1 Date: Sat, 2 Nov 2024 19:18:05 -0700 Subject: [PATCH 13/28] Add usage tracking --- .github/tests/lm_tests.py | 16 +++++++++++++++- lotus/models/lm.py | 26 ++++++++++++++++++++++++-- lotus/types.py | 10 ++++++++++ 3 files changed, 49 insertions(+), 3 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index e6a4c38c..62bb3581 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -9,7 +9,7 @@ lotus.logger.setLevel("DEBUG") -@pytest.fixture +@pytest.fixture(scope="session") def setup_gpt_models(): # Setup GPT models gpt_4o_mini = LM(model="gpt-4o-mini") @@ -17,6 +17,20 @@ def setup_gpt_models(): return gpt_4o_mini, gpt_4o +@pytest.fixture(autouse=True) +def print_usage_after_each_test(setup_gpt_models): + yield # this runs the test + gpt_4o_mini, gpt_4o = setup_gpt_models + print("\nUsage stats for gpt-4o-mini after test:") + gpt_4o_mini.print_total_usage() + print("\nUsage stats for gpt-4o after test:") + gpt_4o.print_total_usage() + + # Reset stats + gpt_4o_mini.reset_stats() + gpt_4o.reset_stats() + + def test_filter_operation(setup_gpt_models): gpt_4o_mini, _ = setup_gpt_models lotus.settings.configure(lm=gpt_4o_mini) diff --git a/lotus/models/lm.py b/lotus/models/lm.py index 149c0714..f4e5bdb2 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -1,12 +1,12 @@ from typing import Any import numpy as np -from litellm import batch_completion +from litellm import batch_completion, completion_cost from litellm.types.utils import ChatCompletionTokenLogprob, Choices, ModelResponse from litellm.utils import token_counter from tokenizers import Tokenizer -from lotus.types import LMOutput, LogprobsForCascade, LogprobsForFilterCascade +from lotus.types import LMOutput, LMStats, LogprobsForCascade, LogprobsForFilterCascade class LM: @@ -25,6 +25,8 @@ def __init__( self.tokenizer = tokenizer self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) + self.stats: LMStats = LMStats() + def __call__( self, messages: list[dict[str, str]] | list[list[dict[str, str]]], **kwargs: dict[str, Any] ) -> LMOutput: @@ -42,8 +44,17 @@ def __call__( [self._get_top_choice_logprobs(resp) for resp in responses] if kwargs_for_batch.get("logprobs") else None ) + for resp in responses: + self._update_stats(resp) + return LMOutput(outputs=outputs, logprobs=logprobs) + def _update_stats(self, response: ModelResponse): + self.stats.total_usage.prompt_tokens += response.usage.prompt_tokens + self.stats.total_usage.completion_tokens += response.usage.completion_tokens + self.stats.total_usage.total_tokens += response.usage.total_tokens + self.stats.total_usage.total_cost += completion_cost(completion_response=response) + def _format_batch_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: all_kwargs = {**self.kwargs, **kwargs} if all_kwargs.get("logprobs", False): @@ -121,3 +132,14 @@ def count_tokens(self, messages: list[dict[str, str]] | str) -> int: model=self.model, messages=messages, ) + + def print_total_usage(self): + print(f"Total cost: ${self.stats.total_usage.total_cost:.6f}") + print(f"Total prompt tokens: {self.stats.total_usage.prompt_tokens}") + print(f"Total completion tokens: {self.stats.total_usage.completion_tokens}") + print(f"Total tokens: {self.stats.total_usage.total_tokens}") + + def reset_stats(self): + self.stats = LMStats( + total_usage=LMStats.TotalUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0, total_cost=0.0) + ) diff --git a/lotus/types.py b/lotus/types.py index 6e11f93a..d6b3443e 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -69,3 +69,13 @@ class LogprobsForCascade(BaseModel): class LogprobsForFilterCascade(LogprobsForCascade): true_probs: list[float] + + +class LMStats(BaseModel): + class TotalUsage(BaseModel): + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + total_cost: float = 0.0 + + total_usage: TotalUsage = TotalUsage() From b05784a97cc9c3bcbad0d62e8ec81f7c3a04fb46 Mon Sep 17 00:00:00 2001 From: sidjha1 Date: Sat, 2 Nov 2024 19:23:13 -0700 Subject: [PATCH 14/28] Fix mypy --- lotus/models/lm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lotus/models/lm.py b/lotus/models/lm.py index f4e5bdb2..7f9fe5d0 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -50,6 +50,9 @@ def __call__( return LMOutput(outputs=outputs, logprobs=logprobs) def _update_stats(self, response: ModelResponse): + if not hasattr(response, "usage"): + return + self.stats.total_usage.prompt_tokens += response.usage.prompt_tokens self.stats.total_usage.completion_tokens += response.usage.completion_tokens self.stats.total_usage.total_tokens += response.usage.total_tokens From 77271fcdfc9ac63cb445935d76465a1f390fca57 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sat, 2 Nov 2024 22:48:49 -0700 Subject: [PATCH 15/28] Add ollama tests --- .github/tests/lm_tests.py | 230 ++++++++++++++++++++---------------- .github/workflows/tests.yml | 77 +++++++++++- lotus/models/lm.py | 11 +- 3 files changed, 211 insertions(+), 107 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 62bb3581..336b2009 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -1,3 +1,4 @@ +import os import pandas as pd import pytest from tokenizers import Tokenizer @@ -5,35 +6,54 @@ import lotus from lotus.models import LM +################################################################################ +# Setup +################################################################################ # Set logger level to DEBUG lotus.logger.setLevel("DEBUG") +# Environment flags to enable/disable tests +ENABLE_OPENAI_TESTS = os.getenv("ENABLE_OPENAI_TESTS", "false").lower() == "true" +ENABLE_OLLAMA_TESTS = os.getenv("ENABLE_OLLAMA_TESTS", "false").lower() == "true" + +MODEL_NAME_TO_ENABLED = { + "gpt-4o-mini": ENABLE_OPENAI_TESTS, + "gpt-4o": ENABLE_OPENAI_TESTS, + "ollama/llama3.2": ENABLE_OLLAMA_TESTS +} +ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled]) + +def get_enabled(*candidate_models: tuple) -> list[str]: + return [model for model in candidate_models if model in ENABLED_MODEL_NAMES] + @pytest.fixture(scope="session") -def setup_gpt_models(): - # Setup GPT models - gpt_4o_mini = LM(model="gpt-4o-mini") - gpt_4o = LM(model="gpt-4o") - return gpt_4o_mini, gpt_4o +def setup_models(): + models = {} + + for model_path in ENABLED_MODEL_NAMES: + models[model_path] = LM(model=model_path) + + return models @pytest.fixture(autouse=True) -def print_usage_after_each_test(setup_gpt_models): +def print_usage_after_each_test(setup_models): yield # this runs the test - gpt_4o_mini, gpt_4o = setup_gpt_models - print("\nUsage stats for gpt-4o-mini after test:") - gpt_4o_mini.print_total_usage() - print("\nUsage stats for gpt-4o after test:") - gpt_4o.print_total_usage() - - # Reset stats - gpt_4o_mini.reset_stats() - gpt_4o.reset_stats() + models = setup_models + for model_name, model in models.items(): + print(f"\nUsage stats for {model_name} after test:") + model.print_total_usage() + model.reset_stats() -def test_filter_operation(setup_gpt_models): - gpt_4o_mini, _ = setup_gpt_models - lotus.settings.configure(lm=gpt_4o_mini) +################################################################################ +# Standard tests +################################################################################ +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.2")) +def test_filter_operation(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) # Test filter operation on an easy dataframe data = {"Text": ["I am really excited to go to class today!", "I am very sad"]} @@ -44,10 +64,84 @@ def test_filter_operation(setup_gpt_models): expected_df = pd.DataFrame({"Text": ["I am really excited to go to class today!"]}) assert filtered_df.equals(expected_df) +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_top_k(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + data = { + "Text": [ + "Lionel Messi is a good soccer player", + "Michael Jordan is a good basketball player", + "Steph Curry is a good basketball player", + "Tom Brady is a good football player", + ] + } + df = pd.DataFrame(data) + user_instruction = "Which {Text} is most related to basketball?" + top_2_expected = set(["Michael Jordan is a good basketball player", "Steph Curry is a good basketball player"]) + + strategies = ["quick", "heap", "naive"] + for strategy in strategies: + sorted_df = df.sem_topk(user_instruction, K=2, strategy=strategy) + + top_2_actual = set(sorted_df["Text"].values) + assert top_2_expected == top_2_actual + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.2")) +def test_join(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + data1 = {"School": ["UC Berkeley", "Stanford"]} + data2 = {"School Type": ["Public School", "Private School"]} + + df1 = pd.DataFrame(data1) + df2 = pd.DataFrame(data2) + join_instruction = "{School} is a {School Type}" + joined_df = df1.sem_join(df2, join_instruction) + joined_pairs = set(zip(joined_df["School"], joined_df["School Type"])) + expected_pairs = set([("UC Berkeley", "Public School"), ("Stanford", "Private School")]) + assert joined_pairs == expected_pairs + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.2")) +def test_map_fewshot(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + data = {"School": ["UC Berkeley", "Carnegie Mellon"]} + df = pd.DataFrame(data) + examples = {"School": ["Stanford", "MIT"], "Answer": ["CA", "MA"]} + examples_df = pd.DataFrame(examples) + user_instruction = "What state is {School} in? Respond only with the two-letter abbreviation." + df = df.sem_map(user_instruction, examples=examples_df, suffix="State") + + pairs = set(zip(df["School"], df["State"])) + expected_pairs = set([("UC Berkeley", "CA"), ("Carnegie Mellon", "PA")]) + assert pairs == expected_pairs + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_agg_then_map(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + data = {"Text": ["My name is John", "My name is Jane", "My name is John"]} + df = pd.DataFrame(data) + agg_instruction = "What is the most common name in {Text}?" + agg_df = df.sem_agg(agg_instruction, suffix="draft_output") + map_instruction = "{draft_output} is a draft answer to the question 'What is the most common name?'. Clean up the draft answer so that there is just a single name. Your answer MUST be on word" + cleaned_df = agg_df.sem_map(map_instruction, suffix="final_output") + assert cleaned_df["final_output"].values[0].lower().strip(".,!?\"'") == "john" -def test_filter_cascade(setup_gpt_models): - gpt_4o_mini, gpt_4o = setup_gpt_models - lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini) +################################################################################ +# Cascade tests +################################################################################ +@pytest.mark.skipif(not ENABLE_OPENAI_TESTS, reason="Skipping test because OpenAI tests are not enabled") +def test_filter_cascade(setup_models): + models = setup_models + lotus.settings.configure(lm=models["gpt-4o"], helper_lm=models["gpt-4o-mini"]) data = { "Text": [ @@ -113,50 +207,10 @@ def test_filter_cascade(setup_gpt_models): assert "I am very sad" not in filtered_df["Text"].values assert stats["filters_resolved_by_helper_model"] > 0, stats - -def test_top_k(setup_gpt_models): - gpt_4o_mini, _ = setup_gpt_models - lotus.settings.configure(lm=gpt_4o_mini) - - data = { - "Text": [ - "Lionel Messi is a good soccer player", - "Michael Jordan is a good basketball player", - "Steph Curry is a good basketball player", - "Tom Brady is a good football player", - ] - } - df = pd.DataFrame(data) - user_instruction = "Which {Text} is most related to basketball?" - top_2_expected = set(["Michael Jordan is a good basketball player", "Steph Curry is a good basketball player"]) - - strategies = ["quick", "heap", "naive"] - for strategy in strategies: - sorted_df = df.sem_topk(user_instruction, K=2, strategy=strategy) - - top_2_actual = set(sorted_df["Text"].values) - assert top_2_expected == top_2_actual - - -def test_join(setup_gpt_models): - gpt_4o_mini, _ = setup_gpt_models - lotus.settings.configure(lm=gpt_4o_mini) - - data1 = {"School": ["UC Berkeley", "Stanford"]} - data2 = {"School Type": ["Public School", "Private School"]} - - df1 = pd.DataFrame(data1) - df2 = pd.DataFrame(data2) - join_instruction = "{School} is a {School Type}" - joined_df = df1.sem_join(df2, join_instruction) - joined_pairs = set(zip(joined_df["School"], joined_df["School Type"])) - expected_pairs = set([("UC Berkeley", "Public School"), ("Stanford", "Private School")]) - assert joined_pairs == expected_pairs - - -def test_join_cascade(setup_gpt_models): - gpt_4o_mini, gpt_4o = setup_gpt_models - lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini) +@pytest.mark.skipif(not ENABLE_OPENAI_TESTS, reason="Skipping test because OpenAI tests are not enabled") +def test_join_cascade(setup_models): + models = setup_models + lotus.settings.configure(lm=models["gpt-4o"], helper_lm=models["gpt-4o-mini"]) data1 = {"School": ["UC Berkeley", "Stanford"]} data2 = {"School Type": ["Public School", "Private School"]} @@ -180,44 +234,20 @@ def test_join_cascade(setup_gpt_models): assert stats["filters_resolved_by_large_model"] == 4, stats assert stats["filters_resolved_by_helper_model"] == 0, stats +################################################################################ +# Token counting tests +################################################################################ +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.2")) +def test_count_tokens(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) -def test_map_fewshot(setup_gpt_models): - gpt_4o_mini, _ = setup_gpt_models - lotus.settings.configure(lm=gpt_4o_mini) - - data = {"School": ["UC Berkeley", "Carnegie Mellon"]} - df = pd.DataFrame(data) - examples = {"School": ["Stanford", "MIT"], "Answer": ["CA", "MA"]} - examples_df = pd.DataFrame(examples) - user_instruction = "What state is {School} in? Respond only with the two-letter abbreviation." - df = df.sem_map(user_instruction, examples=examples_df, suffix="State") - - pairs = set(zip(df["School"], df["State"])) - expected_pairs = set([("UC Berkeley", "CA"), ("Carnegie Mellon", "PA")]) - assert pairs == expected_pairs - - -def test_agg_then_map(setup_gpt_models): - gpt_4o_mini, _ = setup_gpt_models - lotus.settings.configure(lm=gpt_4o_mini) - - data = {"Text": ["My name is John", "My name is Jane", "My name is John"]} - df = pd.DataFrame(data) - agg_instruction = "What is the most common name in {Text}?" - agg_df = df.sem_agg(agg_instruction, suffix="draft_output") - map_instruction = "{draft_output} is a draft answer to the question 'What is the most common name?'. Clean up the draft answer so that there is just a single name. Your answer MUST be on word" - cleaned_df = agg_df.sem_map(map_instruction, suffix="final_output") - assert cleaned_df["final_output"].values[0] == "John" - - -def test_count_tokens(setup_gpt_models): - gpt_4o_mini, _ = setup_gpt_models - lotus.settings.configure(lm=gpt_4o_mini) - - tokens = gpt_4o_mini.count_tokens("Hello, world!") - assert gpt_4o_mini.count_tokens([{"role": "user", "content": "Hello, world!"}]) == tokens + tokens = lm.count_tokens("Hello, world!") + assert lm.count_tokens([{"role": "user", "content": "Hello, world!"}]) == tokens assert tokens < 100 + +def test_custom_tokenizer(): custom_tokenizer = Tokenizer.from_pretrained("gpt2") custom_lm = LM(model="doesn't matter", tokenizer=custom_tokenizer) tokens = custom_lm.count_tokens("Hello, world!") diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1991af16..2527c886 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -54,8 +54,8 @@ jobs: - name: Run mypy run: mypy lotus/ - test: - name: Python Tests + openai_lm_test: + name: OpenAI Language Model Tests runs-on: ubuntu-latest timeout-minutes: 5 @@ -78,9 +78,76 @@ jobs: - name: Set OpenAI API Key run: echo "OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> $GITHUB_ENV - - name: Run Python tests + - name: Run LM tests env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + ENABLE_OPENAI_TESTS: true + run: pytest .github/tests/lm_tests.py + + ollama_lm_test: + name: Ollama Language Model Tests + runs-on: ubuntu-latest + timeout-minutes: 10 + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -e . + pip install pytest + + - name: Start Ollama container run: | - pytest .github/tests/lm_tests.py - pytest .github/tests/rm_tests.py \ No newline at end of file + docker pull ollama/ollama:latest + docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama + # Wait for Ollama server to be ready + timeout=30 + while ! curl -s http://localhost:11434/ >/dev/null; do + if [ $timeout -le 0 ]; then + echo "Timed out waiting for Ollama server" + exit 1 + fi + echo "Waiting for Ollama server to be ready..." + sleep 1 + timeout=$((timeout - 1)) + done + docker exec $(docker ps -q) ollama run llama3.2 + + - name: Run LM tests + env: + ENABLE_OLLAMA_TESTS: true + run: pytest .github/tests/lm_tests.py + + + rm_test: + name: Retrieval Model Tests + runs-on: ubuntu-latest + timeout-minutes: 5 + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -e . + pip install pytest + + - name: Run RM tests + run: pytest .github/tests/rm_tests.py \ No newline at end of file diff --git a/lotus/models/lm.py b/lotus/models/lm.py index 7f9fe5d0..429206c4 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -1,11 +1,13 @@ from typing import Any import numpy as np +import litellm from litellm import batch_completion, completion_cost from litellm.types.utils import ChatCompletionTokenLogprob, Choices, ModelResponse from litellm.utils import token_counter from tokenizers import Tokenizer +import lotus from lotus.types import LMOutput, LMStats, LogprobsForCascade, LogprobsForFilterCascade @@ -38,6 +40,7 @@ def __call__( max_tokens=kwargs_for_batch.get("max_tokens"), top_logprobs=kwargs_for_batch.get("top_logprobs"), logprobs=kwargs_for_batch.get("logprobs"), + drop_params=True, ) outputs = [self._get_top_choice(resp) for resp in responses] logprobs = ( @@ -56,7 +59,12 @@ def _update_stats(self, response: ModelResponse): self.stats.total_usage.prompt_tokens += response.usage.prompt_tokens self.stats.total_usage.completion_tokens += response.usage.completion_tokens self.stats.total_usage.total_tokens += response.usage.total_tokens - self.stats.total_usage.total_cost += completion_cost(completion_response=response) + + try: + self.stats.total_usage.total_cost += completion_cost(completion_response=response) + except litellm.exceptions.NotFoundError as e: + # Sometimes the model's pricing information is not available + lotus.logger.debug(f"Error updating completion cost: {e}") def _format_batch_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: all_kwargs = {**self.kwargs, **kwargs} @@ -129,7 +137,6 @@ def count_tokens(self, messages: list[dict[str, str]] | str) -> int: if self.tokenizer: custom_tokenizer = dict(type="huggingface_tokenizer", tokenizer=self.tokenizer) - # Pass values directly rather than using kwargs dict to preserve typing return token_counter( custom_tokenizer=custom_tokenizer, model=self.model, From 1d8aa5d3c5b5d070197139a676d0d0f84a743af2 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sat, 2 Nov 2024 22:53:05 -0700 Subject: [PATCH 16/28] Reformat --- .github/tests/lm_tests.py | 9 ++++++++- lotus/models/lm.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 336b2009..662dc94f 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -1,4 +1,5 @@ import os + import pandas as pd import pytest from tokenizers import Tokenizer @@ -19,10 +20,11 @@ MODEL_NAME_TO_ENABLED = { "gpt-4o-mini": ENABLE_OPENAI_TESTS, "gpt-4o": ENABLE_OPENAI_TESTS, - "ollama/llama3.2": ENABLE_OLLAMA_TESTS + "ollama/llama3.2": ENABLE_OLLAMA_TESTS, } ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled]) + def get_enabled(*candidate_models: tuple) -> list[str]: return [model for model in candidate_models if model in ENABLED_MODEL_NAMES] @@ -64,6 +66,7 @@ def test_filter_operation(setup_models, model): expected_df = pd.DataFrame({"Text": ["I am really excited to go to class today!"]}) assert filtered_df.equals(expected_df) + @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) def test_top_k(setup_models, model): lm = setup_models[model] @@ -105,6 +108,7 @@ def test_join(setup_models, model): expected_pairs = set([("UC Berkeley", "Public School"), ("Stanford", "Private School")]) assert joined_pairs == expected_pairs + @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.2")) def test_map_fewshot(setup_models, model): lm = setup_models[model] @@ -135,6 +139,7 @@ def test_agg_then_map(setup_models, model): cleaned_df = agg_df.sem_map(map_instruction, suffix="final_output") assert cleaned_df["final_output"].values[0].lower().strip(".,!?\"'") == "john" + ################################################################################ # Cascade tests ################################################################################ @@ -207,6 +212,7 @@ def test_filter_cascade(setup_models): assert "I am very sad" not in filtered_df["Text"].values assert stats["filters_resolved_by_helper_model"] > 0, stats + @pytest.mark.skipif(not ENABLE_OPENAI_TESTS, reason="Skipping test because OpenAI tests are not enabled") def test_join_cascade(setup_models): models = setup_models @@ -234,6 +240,7 @@ def test_join_cascade(setup_models): assert stats["filters_resolved_by_large_model"] == 4, stats assert stats["filters_resolved_by_helper_model"] == 0, stats + ################################################################################ # Token counting tests ################################################################################ diff --git a/lotus/models/lm.py b/lotus/models/lm.py index 429206c4..f707e1ec 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -1,7 +1,7 @@ from typing import Any -import numpy as np import litellm +import numpy as np from litellm import batch_completion, completion_cost from litellm.types.utils import ChatCompletionTokenLogprob, Choices, ModelResponse from litellm.utils import token_counter From 84ff6e5b012db6282dc4879c5ade766da3a1f4d4 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sat, 2 Nov 2024 23:19:12 -0700 Subject: [PATCH 17/28] Update docs --- CONTRIBUTING.md | 20 +++++++++----------- README.md | 4 ++-- docs/quickstart.rst | 8 ++++---- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 02cac94f..a92dcf3c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -14,22 +14,20 @@ pip install -r requirements-dev.txt ## Dev Flow After making your changes, please make a PR to get your changes merged upstream. -## Running vLLM Models -To use vLLM for model serving, you just need to make an OpenAI compatible vLLM server. Then, the `OpenAIModel` class can be used to point to the server. See an example below. +## Running Models +To run a model, you can use the `LM` class in `lotus.models.LM`. We use the `litellm` library to interface with the model. +This allows you to use any model provider that is supported by `litellm`. -Create the server +Here's an example of creating an `LM` object for `gpt-4o` ``` -python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-70B-Instruct --port 8000 --tensor-parallel-size 8 +from lotus.models import LM +lm = LM(model="gpt-4o") ``` -In LOTUS, you should instantiate your model as follows +Here's an example of creating an `LM` object to use `llama3.2` on Ollama ``` -from lotus.models import OpenAIModel -lm = OpenAIModel( - model="meta-llama/Meta-Llama-3.1-70B-Instruct", - api_base="http://localhost:8000/v1", - provider="vllm", -) +from lotus.models import LM +lm = LM(model="ollama/llama3.2") ``` ## Helpful Examples diff --git a/README.md b/README.md index acfd61e4..2c57c51c 100644 --- a/README.md +++ b/README.md @@ -45,10 +45,10 @@ If you're already familiar with Pandas, getting started will be a breeze! Below ```python import pandas as pd import lotus -from lotus.models import OpenAIModel +from lotus.models import LM # configure the LM, and remember to export your API key -lm = OpenAIModel() +lm = LM() lotus.settings.configure(lm=lm) # create dataframes with course names and skills diff --git a/docs/quickstart.rst b/docs/quickstart.rst index fe7e99d2..e194177f 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -50,10 +50,10 @@ This can be achieved by applying a semantic filter followed by a semantic aggreg import pandas as pd import lotus - from lotus.models import E5Model, OpenAIModel + from lotus.models import E5Model, LM # Configure models for LOTUS - lm = OpenAIModel(max_tokens=512) + lm = LM() rm = E5Model() lotus.settings.configure(lm=lm, rm=rm) @@ -90,7 +90,7 @@ If we wanted the challenge of taking courses with a high workload, we can also u .. code-block:: python - top_2_hardest = df.sem_topk("What {Description} indicates the highest workload?", 2) + top_2_hardest = df.sem_topk("What {Description} indicates the highest workload?", K=2) LOTUS's semantic join operator can be used to join two dataframes based on a predicate. Suppose we had a second dataframe containing skills we wanted to get better at (SQL and Chip Design in our case). @@ -113,7 +113,7 @@ Let's create a semantic index on the course description column and then search f # Create a semantic index on the description column and save it to the index_dir directory df = df.sem_index("Description", "index_dir") - top_conv_df = df.sem_search("Description", "Convolutional Neural Network", 1) + top_conv_df = df.sem_search("Description", "Convolutional Neural Network", K=1) Another useful operator is the semantic map operator. Let's see how it can be used to get some next topics to explore for each class. Additionally, let's provide some examples to the model that can be used for demonstrations. From 23e291bca798b82a840d247ccf368611172db8f0 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sat, 2 Nov 2024 23:22:44 -0700 Subject: [PATCH 18/28] Bump version --- docs/conf.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 3b4508c1..a9bf8b8f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,7 +14,7 @@ project = "LOTUS" copyright = "2024, Liana Patel, Siddharth Jha, Carlos Guestrin, Matei Zaharia" author = "Liana Patel, Siddharth Jha, Carlos Guestrin, Matei Zaharia" -release = "0.2.2" +release = "0.3.0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index 3db30d79..6a170ce0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "lotus-ai" -version = "0.2.2" +version = "0.3.0" description = "lotus" readme = "README.md" authors = [ From fc67fa746c997d9ad05e4f16032654250de699e0 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sat, 2 Nov 2024 23:52:12 -0700 Subject: [PATCH 19/28] Fix pre-commit --- .pre-commit-config.yaml | 8 ++++++++ lotus/models/colbertv2_model.py | 4 ++-- lotus/models/e5_model.py | 10 +++++----- lotus/models/rm.py | 8 ++++---- lotus/sem_ops/cascade_utils.py | 4 ++-- lotus/sem_ops/sem_filter.py | 2 +- 6 files changed, 22 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 38997e65..0964e87f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,3 +11,11 @@ repos: hooks: - id: mypy args: ["--config-file", "mypy.ini"] + additional_dependencies: + - types-setuptools + - litellm>=1.51.0 + - numpy>=1.25.0 + - pandas>=2.0.0 + - sentence-transformers>=3.0.1 + - tiktoken>=0.7.0 + - tqdm>=4.66.4 diff --git a/lotus/models/colbertv2_model.py b/lotus/models/colbertv2_model.py index 595b970c..51bcd7cb 100644 --- a/lotus/models/colbertv2_model.py +++ b/lotus/models/colbertv2_model.py @@ -44,12 +44,12 @@ def load_index(self, index_dir: str) -> None: with open(f"experiments/lotus/indexes/{index_dir}/index/docs", "rb") as fp: self.docs = pickle.load(fp) - def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float_]: + def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]: raise NotImplementedError("This method is not implemented for ColBERTv2Model") def __call__( self, - queries: str | list[str] | NDArray[np.float_], + queries: str | list[str] | NDArray[np.float64], k: int, **kwargs: dict[str, Any], ) -> tuple[list[list[float]], list[list[int]]]: diff --git a/lotus/models/e5_model.py b/lotus/models/e5_model.py index 194d3a5e..d29c7ddf 100644 --- a/lotus/models/e5_model.py +++ b/lotus/models/e5_model.py @@ -26,7 +26,7 @@ def __init__(self, model: str = "intfloat/e5-base-v2", device: str | None = None self.docs: list[str] | None = None self.kwargs: dict[str, Any] = {"normalize": True, "index_type": "Flat", **kwargs} self.batch_size: int = 100 - self.vecs: NDArray[np.float_] | None = None + self.vecs: NDArray[np.float64] | None = None import faiss @@ -46,7 +46,7 @@ def average_pool(self, last_hidden_states: torch.Tensor, attention_mask: torch.T last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] - def embed(self, docs: list[str], **kwargs: dict[str, Any]) -> NDArray[np.float_]: + def embed(self, docs: list[str], **kwargs: dict[str, Any]) -> NDArray[np.float64]: """Run the embedding model. Args: @@ -112,15 +112,15 @@ def load_index(self, index_dir: str) -> None: self.vecs = pickle.load(fp) @classmethod - def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.float_]: + def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.float64]: with open(f"{index_dir}/vecs", "rb") as fp: - vecs: NDArray[np.float_] = pickle.load(fp) + vecs: NDArray[np.float64] = pickle.load(fp) return vecs[ids] def __call__( self, - queries: str | list[str] | NDArray[np.float_], + queries: str | list[str] | NDArray[np.float64], k: int, **kwargs: dict[str, Any], ) -> tuple[list[list[float]], list[list[int]]]: diff --git a/lotus/models/rm.py b/lotus/models/rm.py index 85cf1af2..e7cb8ba7 100644 --- a/lotus/models/rm.py +++ b/lotus/models/rm.py @@ -31,7 +31,7 @@ def load_index(self, index_dir: str) -> None: pass @abstractmethod - def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.float_]: + def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.float64]: """Get the vectors from the index. Args: @@ -39,7 +39,7 @@ def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.fl ids (list[int]): The ids of the vectors to retrieve Returns: - NDArray[np.float_]: The vectors matching the specified ids. + NDArray[np.float64]: The vectors matching the specified ids. """ pass @@ -47,14 +47,14 @@ def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.fl @abstractmethod def __call__( self, - queries: str | list[str] | NDArray[np.float_], + queries: str | list[str] | NDArray[np.float64], k: int, **kwargs: dict[str, Any], ) -> tuple[list[list[float]], list[list[int]]]: """Run top-k search on the index. Args: - queries (str | list[str] | NDArray[np.float_]): Either a query or a list of queries or a 2D FP32 array. + queries (str | list[str] | NDArray[np.float64]): Either a query or a list of queries or a 2D FP32 array. k (int): The k to use for top-k search. **kwargs (dict[str, Any]): Additional keyword arguments. diff --git a/lotus/sem_ops/cascade_utils.py b/lotus/sem_ops/cascade_utils.py index 2bee2c77..a588f0a6 100644 --- a/lotus/sem_ops/cascade_utils.py +++ b/lotus/sem_ops/cascade_utils.py @@ -7,7 +7,7 @@ def importance_sampling( proxy_scores: list[float], sample_percentage: float, -) -> tuple[NDArray[np.int_], NDArray[np.float_]]: +) -> tuple[NDArray[np.int_], NDArray[np.float64]]: """Uses importance sampling and returns the list of indices from which to learn cascade thresholds.""" w = np.sqrt(proxy_scores) @@ -33,7 +33,7 @@ def calibrate_llm_logprobs(true_probs: list[float]) -> list[float]: def learn_cascade_thresholds( proxy_scores: list[float], oracle_outputs: list[bool], - sample_correction_factors: NDArray[np.float_], + sample_correction_factors: NDArray[np.float64], recall_target: float, precision_target: float, delta: float, diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 606b8da7..00be6789 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -68,7 +68,7 @@ def learn_filter_cascade_thresholds( precision_target: float, delta: float, helper_true_probs: list[float], - sample_correction_factors: NDArray[np.float_], + sample_correction_factors: NDArray[np.float64], examples_df_txt: list[str] | None = None, examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, From eeff8bf819eb5f7fab2331b04ed1cf81b56b2ce7 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sat, 2 Nov 2024 23:54:13 -0700 Subject: [PATCH 20/28] Use int64 instead of int_ --- lotus/sem_ops/cascade_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lotus/sem_ops/cascade_utils.py b/lotus/sem_ops/cascade_utils.py index a588f0a6..088ba9be 100644 --- a/lotus/sem_ops/cascade_utils.py +++ b/lotus/sem_ops/cascade_utils.py @@ -7,7 +7,7 @@ def importance_sampling( proxy_scores: list[float], sample_percentage: float, -) -> tuple[NDArray[np.int_], NDArray[np.float64]]: +) -> tuple[NDArray[np.int64], NDArray[np.float64]]: """Uses importance sampling and returns the list of indices from which to learn cascade thresholds.""" w = np.sqrt(proxy_scores) From 7159dd64e2f1ad736f6b5a30f195172e0bcfd16b Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sun, 3 Nov 2024 00:16:56 -0700 Subject: [PATCH 21/28] Fix bug --- CONTRIBUTING.md | 7 +++++++ lotus/models/lm.py | 23 +++++++++-------------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a92dcf3c..0a08df16 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -9,6 +9,7 @@ conda activate lotus git clone git@github.com:stanford-futuredata/lotus.git pip install -e . pip install -r requirements-dev.txt +pre-commit install ``` ## Dev Flow @@ -30,5 +31,11 @@ from lotus.models import LM lm = LM(model="ollama/llama3.2") ``` +Here's an example of creating an `LM` object to use `Meta-Llama-3-8B-Instruct` on vLLM +``` +from lotus.models import LM +lm = LM(model='hosted_vllm/meta-llama/Meta-Llama-3-8B-Instruct', api_base='http://localhost:8000/v1') +``` + ## Helpful Examples For helpful examples of LOTUS operators, please refer to the `examples` folder, as well as the documentation. \ No newline at end of file diff --git a/lotus/models/lm.py b/lotus/models/lm.py index f707e1ec..7566e111 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -32,20 +32,21 @@ def __init__( def __call__( self, messages: list[dict[str, str]] | list[list[dict[str, str]]], **kwargs: dict[str, Any] ) -> LMOutput: - kwargs_for_batch = self._format_batch_kwargs(kwargs) + all_kwargs = {**self.kwargs, **kwargs} + + # Set top_logprobs if logprobs requested + if all_kwargs.get("logprobs", False): + all_kwargs["top_logprobs"] = all_kwargs.get("top_logprobs", 10) + responses: list[ModelResponse] = batch_completion( self.model, messages, - temperature=kwargs_for_batch.get("temperature"), - max_tokens=kwargs_for_batch.get("max_tokens"), - top_logprobs=kwargs_for_batch.get("top_logprobs"), - logprobs=kwargs_for_batch.get("logprobs"), drop_params=True, + **all_kwargs, # type: ignore ) + outputs = [self._get_top_choice(resp) for resp in responses] - logprobs = ( - [self._get_top_choice_logprobs(resp) for resp in responses] if kwargs_for_batch.get("logprobs") else None - ) + logprobs = [self._get_top_choice_logprobs(resp) for resp in responses] if all_kwargs.get("logprobs") else None for resp in responses: self._update_stats(resp) @@ -66,12 +67,6 @@ def _update_stats(self, response: ModelResponse): # Sometimes the model's pricing information is not available lotus.logger.debug(f"Error updating completion cost: {e}") - def _format_batch_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: - all_kwargs = {**self.kwargs, **kwargs} - if all_kwargs.get("logprobs", False): - all_kwargs["top_logprobs"] = all_kwargs.get("top_logprobs", 10) - return {k: v for k, v in all_kwargs.items() if k in ["temperature", "max_tokens", "top_logprobs", "logprobs"]} - def _get_top_choice(self, response: ModelResponse) -> str: choice = response.choices[0] assert isinstance(choice, Choices) From 70f55103f95d351c0bca5adabdefddbef0efb1fe Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sun, 3 Nov 2024 00:28:28 -0700 Subject: [PATCH 22/28] Better error handling --- lotus/models/lm.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lotus/models/lm.py b/lotus/models/lm.py index 7566e111..b525a63a 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -5,6 +5,7 @@ from litellm import batch_completion, completion_cost from litellm.types.utils import ChatCompletionTokenLogprob, Choices, ModelResponse from litellm.utils import token_counter +from openai import OpenAIError from tokenizers import Tokenizer import lotus @@ -45,6 +46,11 @@ def __call__( **all_kwargs, # type: ignore ) + # throw errors, if any + for resp in responses: + if isinstance(resp, OpenAIError): + raise resp + outputs = [self._get_top_choice(resp) for resp in responses] logprobs = [self._get_top_choice_logprobs(resp) for resp in responses] if all_kwargs.get("logprobs") else None From 1e1eb4cbca9110ca2ffff934c5fcc95376c25d0c Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sun, 3 Nov 2024 00:40:52 -0700 Subject: [PATCH 23/28] Minor fix --- lotus/models/lm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lotus/models/lm.py b/lotus/models/lm.py index b525a63a..42a1dfb0 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -30,9 +30,7 @@ def __init__( self.stats: LMStats = LMStats() - def __call__( - self, messages: list[dict[str, str]] | list[list[dict[str, str]]], **kwargs: dict[str, Any] - ) -> LMOutput: + def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any]) -> LMOutput: all_kwargs = {**self.kwargs, **kwargs} # Set top_logprobs if logprobs requested From 51071f22a081938c68ee76bc4dd29980012fa106 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sun, 3 Nov 2024 13:11:12 -0800 Subject: [PATCH 24/28] Add max batch size --- lotus/models/lm.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/lotus/models/lm.py b/lotus/models/lm.py index 42a1dfb0..ee03a444 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -19,12 +19,14 @@ def __init__( temperature: float = 0.0, max_ctx_len: int = 128000, max_tokens: int = 512, + max_batch_size: int = 64, tokenizer: Tokenizer | None = None, **kwargs: dict[str, Any], ): self.model = model self.max_ctx_len = max_ctx_len self.max_tokens = max_tokens + self.max_batch_size = max_batch_size self.tokenizer = tokenizer self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) @@ -37,22 +39,28 @@ def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any if all_kwargs.get("logprobs", False): all_kwargs["top_logprobs"] = all_kwargs.get("top_logprobs", 10) - responses: list[ModelResponse] = batch_completion( - self.model, - messages, - drop_params=True, - **all_kwargs, # type: ignore - ) + all_responses: list[ModelResponse] = [] + for i in range(0, len(messages), self.max_batch_size): + batch = messages[i : i + self.max_batch_size] + responses: list[ModelResponse] = batch_completion( + self.model, + batch, + drop_params=True, + **all_kwargs, # type: ignore + ) + all_responses.extend(responses) # throw errors, if any - for resp in responses: + for resp in all_responses: if isinstance(resp, OpenAIError): raise resp - outputs = [self._get_top_choice(resp) for resp in responses] - logprobs = [self._get_top_choice_logprobs(resp) for resp in responses] if all_kwargs.get("logprobs") else None + outputs = [self._get_top_choice(resp) for resp in all_responses] + logprobs = ( + [self._get_top_choice_logprobs(resp) for resp in all_responses] if all_kwargs.get("logprobs") else None + ) - for resp in responses: + for resp in all_responses: self._update_stats(resp) return LMOutput(outputs=outputs, logprobs=logprobs) From c79c7b44365f419efefa625ab9df6e6ada0c9267 Mon Sep 17 00:00:00 2001 From: sidjha1 Date: Sun, 3 Nov 2024 15:40:37 -0800 Subject: [PATCH 25/28] Small type change --- .github/tests/lm_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 662dc94f..4ebc02a1 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -25,7 +25,7 @@ ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled]) -def get_enabled(*candidate_models: tuple) -> list[str]: +def get_enabled(*candidate_models: str) -> list[str]: return [model for model in candidate_models if model in ENABLED_MODEL_NAMES] From d182df86f0afede14c9363674450925e17c835bf Mon Sep 17 00:00:00 2001 From: Sid Jha <45739834+sidjha1@users.noreply.github.com> Date: Tue, 5 Nov 2024 17:30:41 -0800 Subject: [PATCH 26/28] Apply RM and Reranker refactor --- .github/tests/rm_tests.py | 152 ++++++++++++------ .github/workflows/tests.yml | 9 +- docs/quickstart.rst | 4 +- examples/op_examples/agg.py | 4 +- examples/op_examples/cluster.py | 4 +- examples/op_examples/dedup.py | 4 +- examples/op_examples/partition.py | 4 +- examples/op_examples/search.py | 6 +- examples/op_examples/sim_join.py | 5 +- lotus/models/__init__.py | 14 +- .../{colbertv2_model.py => colbertv2_rm.py} | 38 ++--- lotus/models/cross_encoder_model.py | 29 ---- lotus/models/cross_encoder_reranker.py | 28 ++++ lotus/models/e5_model.py | 141 ---------------- lotus/models/faiss_rm.py | 62 +++++++ lotus/models/litellm_rm.py | 29 ++++ lotus/models/reranker.py | 8 +- lotus/models/rm.py | 14 +- lotus/models/sentence_transformers_rm.py | 36 +++++ lotus/sem_ops/sem_search.py | 10 +- lotus/sem_ops/sem_sim_join.py | 11 +- lotus/sem_ops/sem_topk.py | 28 ++-- lotus/types.py | 62 ++++--- 23 files changed, 390 insertions(+), 312 deletions(-) rename lotus/models/{colbertv2_model.py => colbertv2_rm.py} (63%) delete mode 100644 lotus/models/cross_encoder_model.py create mode 100644 lotus/models/cross_encoder_reranker.py delete mode 100644 lotus/models/e5_model.py create mode 100644 lotus/models/faiss_rm.py create mode 100644 lotus/models/litellm_rm.py create mode 100644 lotus/models/sentence_transformers_rm.py diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py index 3940944a..2c00e116 100644 --- a/.github/tests/rm_tests.py +++ b/.github/tests/rm_tests.py @@ -1,23 +1,55 @@ +import os + import pandas as pd import pytest import lotus -from lotus.models import CrossEncoderModel, E5Model +from lotus.models import CrossEncoderReranker, LiteLLMRM, SentenceTransformersRM +################################################################################ +# Setup +################################################################################ # Set logger level to DEBUG lotus.logger.setLevel("DEBUG") +# Environment flags to enable/disable tests +ENABLE_OPENAI_TESTS = os.getenv("ENABLE_OPENAI_TESTS", "false").lower() == "true" +ENABLE_LOCAL_TESTS = os.getenv("ENABLE_LOCAL_TESTS", "false").lower() == "true" + +# TODO: Add colbertv2 tests +MODEL_NAME_TO_ENABLED = { + "intfloat/e5-small-v2": ENABLE_LOCAL_TESTS, + "mixedbread-ai/mxbai-rerank-xsmall-v1": ENABLE_LOCAL_TESTS, + "text-embedding-3-small": ENABLE_OPENAI_TESTS, +} +ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled]) + +MODEL_NAME_TO_CLS = { + "intfloat/e5-small-v2": SentenceTransformersRM, + "mixedbread-ai/mxbai-rerank-xsmall-v1": CrossEncoderReranker, + "text-embedding-3-small": LiteLLMRM, +} + + +def get_enabled(*candidate_models: str) -> list[str]: + return [model for model in candidate_models if model in ENABLED_MODEL_NAMES] -@pytest.fixture + +@pytest.fixture(scope="session") def setup_models(): - # Set up embedder and reranker model - rm = E5Model(model="intfloat/e5-small-v2") - reranker = CrossEncoderModel(model="mixedbread-ai/mxbai-rerank-xsmall-v1") - return rm, reranker + models = {} + + for model_name in ENABLED_MODEL_NAMES: + models[model_name] = MODEL_NAME_TO_CLS[model_name](model=model_name) + return models -def test_cluster_by(setup_models): - rm, _ = setup_models +################################################################################ +# RM Only Tests +################################################################################ +@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) +def test_cluster_by(setup_models, model): + rm = setup_models[model] lotus.settings.configure(rm=rm) data = { @@ -44,8 +76,9 @@ def test_cluster_by(setup_models): assert probability_group == {"Probability and Random Processes", "Optimization Methods in Engineering"}, groups -def test_search_rm_only(setup_models): - rm, _ = setup_models +@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) +def test_search_rm_only(setup_models, model): + rm = setup_models[model] lotus.settings.configure(rm=rm) data = { @@ -62,43 +95,35 @@ def test_search_rm_only(setup_models): assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"] -def test_search_reranker_only(setup_models): - _, reranker = setup_models - lotus.settings.configure(reranker=reranker) +@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) +def test_sim_join(setup_models, model): + rm = setup_models[model] + lotus.settings.configure(rm=rm) - data = { + data1 = { "Course Name": [ - "Probability and Random Processes", - "Cooking", - "Food Sciences", - "Optimization Methods in Engineering", + "History of the Atlantic World", + "Riemannian Geometry", ] } - df = pd.DataFrame(data) - df = df.sem_search("Course Name", "Optimization", n_rerank=2) - assert df["Course Name"].tolist() == ["Optimization Methods in Engineering", "Probability and Random Processes"] + data2 = {"Skill": ["Math", "History"]} -def test_search(setup_models): - rm, reranker = setup_models - lotus.settings.configure(rm=rm, reranker=reranker) - - data = { - "Course Name": [ - "Probability and Random Processes", - "Cooking", - "Food Sciences", - "Optimization Methods in Engineering", - ] - } - df = pd.DataFrame(data) - df = df.sem_index("Course Name", "index_dir") - df = df.sem_search("Course Name", "Optimization", K=2, n_rerank=1) - assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"] + df1 = pd.DataFrame(data1) + df2 = pd.DataFrame(data2).sem_index("Skill", "index_dir") + joined_df = df1.sem_sim_join(df2, left_on="Course Name", right_on="Skill", K=1) + joined_pairs = set(zip(joined_df["Course Name"], joined_df["Skill"])) + expected_pairs = {("History of the Atlantic World", "History"), ("Riemannian Geometry", "Math")} + assert joined_pairs == expected_pairs, joined_pairs +# TODO: threshold is hardcoded for intfloat/e5-small-v2 +@pytest.mark.skipif( + "intfloat/e5-small-v2" not in ENABLED_MODEL_NAMES, + reason="Skipping test because intfloat/e5-small-v2 is not enabled", +) def test_dedup(setup_models): - rm, _ = setup_models + rm = setup_models["intfloat/e5-small-v2"] lotus.settings.configure(rm=rm) data = { "Text": [ @@ -117,22 +142,47 @@ def test_dedup(setup_models): assert "Probability" in kept[1], kept -def test_sim_join(setup_models): - rm, _ = setup_models - lotus.settings.configure(rm=rm) +################################################################################ +# Reranker Only Tests +################################################################################ +@pytest.mark.parametrize("model", get_enabled("mixedbread-ai/mxbai-rerank-xsmall-v1")) +def test_search_reranker_only(setup_models, model): + reranker = setup_models[model] + lotus.settings.configure(reranker=reranker) - data1 = { + data = { "Course Name": [ - "History of the Atlantic World", - "Riemannian Geometry", + "Probability and Random Processes", + "Cooking", + "Food Sciences", + "Optimization Methods in Engineering", ] } + df = pd.DataFrame(data) + df = df.sem_search("Course Name", "Optimization", n_rerank=2) + assert df["Course Name"].tolist() == ["Optimization Methods in Engineering", "Probability and Random Processes"] - data2 = {"Skill": ["Math", "History"]} - df1 = pd.DataFrame(data1) - df2 = pd.DataFrame(data2).sem_index("Skill", "index_dir") - joined_df = df1.sem_sim_join(df2, left_on="Course Name", right_on="Skill", K=1) - joined_pairs = set(zip(joined_df["Course Name"], joined_df["Skill"])) - expected_pairs = {("History of the Atlantic World", "History"), ("Riemannian Geometry", "Math")} - assert joined_pairs == expected_pairs, joined_pairs +################################################################################ +# Combined Tests +################################################################################ +# TODO: Figure out how to parameterize pairs of models +@pytest.mark.skipif(not ENABLE_LOCAL_TESTS, reason="Skipping test because local tests are not enabled") +def test_search(setup_models): + models = setup_models + rm = models["intfloat/e5-small-v2"] + reranker = models["mixedbread-ai/mxbai-rerank-xsmall-v1"] + lotus.settings.configure(rm=rm, reranker=reranker) + + data = { + "Course Name": [ + "Probability and Random Processes", + "Cooking", + "Food Sciences", + "Optimization Methods in Engineering", + ] + } + df = pd.DataFrame(data) + df = df.sem_index("Course Name", "index_dir") + df = df.sem_search("Course Name", "Optimization", K=2, n_rerank=1) + assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"] diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2527c886..07a9f3ea 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -149,5 +149,12 @@ jobs: pip install -e . pip install pytest + - name: Set OpenAI API Key + run: echo "OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> $GITHUB_ENV + - name: Run RM tests - run: pytest .github/tests/rm_tests.py \ No newline at end of file + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + ENABLE_OPENAI_TESTS: true + ENABLE_LOCAL_TESTS: true + run: pytest .github/tests/rm_tests.py diff --git a/docs/quickstart.rst b/docs/quickstart.rst index e194177f..2a9f2761 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -50,11 +50,11 @@ This can be achieved by applying a semantic filter followed by a semantic aggreg import pandas as pd import lotus - from lotus.models import E5Model, LM + from lotus.models import SentenceTransformersRM, LM # Configure models for LOTUS lm = LM() - rm = E5Model() + rm = SentenceTransformersRM() lotus.settings.configure(lm=lm, rm=rm) diff --git a/examples/op_examples/agg.py b/examples/op_examples/agg.py index 6f6e14b0..6ad9356f 100644 --- a/examples/op_examples/agg.py +++ b/examples/op_examples/agg.py @@ -1,10 +1,10 @@ import pandas as pd import lotus -from lotus.models import LM, E5Model +from lotus.models import LM, SentenceTransformersRM lm = LM() -rm = E5Model() +rm = SentenceTransformersRM() lotus.settings.configure(lm=lm, rm=rm) data = { diff --git a/examples/op_examples/cluster.py b/examples/op_examples/cluster.py index 2d7af6f1..9c6697ad 100644 --- a/examples/op_examples/cluster.py +++ b/examples/op_examples/cluster.py @@ -1,10 +1,10 @@ import pandas as pd import lotus -from lotus.models import LM, E5Model +from lotus.models import LM, SentenceTransformersRM lm = LM() -rm = E5Model() +rm = SentenceTransformersRM() lotus.settings.configure(lm=lm, rm=rm) data = { diff --git a/examples/op_examples/dedup.py b/examples/op_examples/dedup.py index 5d21087f..8c89aebd 100644 --- a/examples/op_examples/dedup.py +++ b/examples/op_examples/dedup.py @@ -1,9 +1,9 @@ import pandas as pd import lotus -from lotus.models import E5Model +from lotus.models import SentenceTransformersRM -rm = E5Model() +rm = SentenceTransformersRM() lotus.settings.configure(rm=rm) data = { diff --git a/examples/op_examples/partition.py b/examples/op_examples/partition.py index 91fa185b..c1c7174e 100644 --- a/examples/op_examples/partition.py +++ b/examples/op_examples/partition.py @@ -1,10 +1,10 @@ import pandas as pd import lotus -from lotus.models import LM, E5Model +from lotus.models import LM, SentenceTransformersRM lm = LM(max_tokens=2048) -rm = E5Model() +rm = SentenceTransformersRM() lotus.settings.configure(lm=lm, rm=rm) data = { diff --git a/examples/op_examples/search.py b/examples/op_examples/search.py index 21c7fb5e..60f04190 100644 --- a/examples/op_examples/search.py +++ b/examples/op_examples/search.py @@ -1,11 +1,11 @@ import pandas as pd import lotus -from lotus.models import LM, CrossEncoderModel, E5Model +from lotus.models import LM, CrossEncoderReranker, SentenceTransformersRM lm = LM() -rm = E5Model() -reranker = CrossEncoderModel() +rm = SentenceTransformersRM() +reranker = CrossEncoderReranker() lotus.settings.configure(lm=lm, rm=rm, reranker=reranker) data = { diff --git a/examples/op_examples/sim_join.py b/examples/op_examples/sim_join.py index beaea582..200a7c43 100644 --- a/examples/op_examples/sim_join.py +++ b/examples/op_examples/sim_join.py @@ -1,10 +1,11 @@ import pandas as pd import lotus -from lotus.models import LM, E5Model +from lotus.models import LM, LiteLLMRM lm = LM() -rm = E5Model() +# rm = SentenceTransformersRM() +rm = LiteLLMRM() lotus.settings.configure(lm=lm, rm=rm) data = { diff --git a/lotus/models/__init__.py b/lotus/models/__init__.py index 4477c6e2..f88f1dd4 100644 --- a/lotus/models/__init__.py +++ b/lotus/models/__init__.py @@ -1,15 +1,17 @@ -from lotus.models.colbertv2_model import ColBERTv2Model -from lotus.models.cross_encoder_model import CrossEncoderModel -from lotus.models.e5_model import E5Model +from lotus.models.cross_encoder_reranker import CrossEncoderReranker from lotus.models.lm import LM from lotus.models.reranker import Reranker from lotus.models.rm import RM +from lotus.models.litellm_rm import LiteLLMRM +from lotus.models.sentence_transformers_rm import SentenceTransformersRM +from lotus.models.colbertv2_rm import ColBERTv2RM __all__ = [ - "E5Model", - "ColBERTv2Model", - "CrossEncoderModel", + "CrossEncoderReranker", "LM", "RM", "Reranker", + "LiteLLMRM", + "SentenceTransformersRM", + "ColBERTv2RM", ] diff --git a/lotus/models/colbertv2_model.py b/lotus/models/colbertv2_rm.py similarity index 63% rename from lotus/models/colbertv2_model.py rename to lotus/models/colbertv2_rm.py index 51bcd7cb..018af594 100644 --- a/lotus/models/colbertv2_model.py +++ b/lotus/models/colbertv2_rm.py @@ -5,32 +5,28 @@ from numpy.typing import NDArray from lotus.models.rm import RM +from lotus.types import RMOutput +try: + from colbert import Indexer, Searcher + from colbert.infra import ColBERTConfig, Run, RunConfig +except ImportError: + pass -class ColBERTv2Model(RM): - """ColBERTv2 Model""" +class ColBERTv2RM(RM): def __init__(self) -> None: self.docs: list[str] | None = None self.kwargs: dict[str, Any] = {"doc_maxlen": 300, "nbits": 2} self.index_dir: str | None = None - from colbert import Indexer, Searcher - from colbert.infra import ColBERTConfig, Run, RunConfig - - self.Indexer = Indexer - self.Searcher = Searcher - self.ColBERTConfig = ColBERTConfig - self.Run = Run - self.RunConfig = RunConfig - def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None: kwargs = {**self.kwargs, **kwargs} checkpoint = "colbert-ir/colbertv2.0" - with self.Run().context(self.RunConfig(nranks=1, experiment="lotus")): - config = self.ColBERTConfig(doc_maxlen=kwargs["doc_maxlen"], nbits=kwargs["nbits"], kmeans_niters=4) - indexer = self.Indexer(checkpoint=checkpoint, config=config) + with Run().context(RunConfig(nranks=1, experiment="lotus")): + config = ColBERTConfig(doc_maxlen=kwargs["doc_maxlen"], nbits=kwargs["nbits"], kmeans_niters=4) + indexer = Indexer(checkpoint=checkpoint, config=config) indexer.index(name=f"{index_dir}/index", collection=docs, overwrite=True) with open(f"experiments/lotus/indexes/{index_dir}/index/docs", "wb") as fp: @@ -45,25 +41,25 @@ def load_index(self, index_dir: str) -> None: self.docs = pickle.load(fp) def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]: - raise NotImplementedError("This method is not implemented for ColBERTv2Model") + raise NotImplementedError("This method is not implemented for ColBERTv2RM") def __call__( self, queries: str | list[str] | NDArray[np.float64], - k: int, + K: int, **kwargs: dict[str, Any], - ) -> tuple[list[list[float]], list[list[int]]]: + ) -> RMOutput: if isinstance(queries, str): queries = [queries] - with self.Run().context(self.RunConfig(experiment="lotus")): - searcher = self.Searcher(index=f"{self.index_dir}/index", collection=self.docs) + with Run().context(RunConfig(experiment="lotus")): + searcher = Searcher(index=f"{self.index_dir}/index", collection=self.docs) # make queries a dict with keys as query ids queries_dict = {i: q for i, q in enumerate(queries)} - all_results = searcher.search_all(queries_dict, k=k).todict() + all_results = searcher.search_all(queries_dict, k=K).todict() indices = [[result[0] for result in all_results[qid]] for qid in all_results.keys()] distances = [[result[2] for result in all_results[qid]] for qid in all_results.keys()] - return distances, indices + return RMOutput(distances=distances, indices=indices) diff --git a/lotus/models/cross_encoder_model.py b/lotus/models/cross_encoder_model.py deleted file mode 100644 index f49aa59a..00000000 --- a/lotus/models/cross_encoder_model.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch -from sentence_transformers import CrossEncoder - -from lotus.models.reranker import Reranker - - -class CrossEncoderModel(Reranker): - """CrossEncoder reranker model. - - Args: - model (str): The name of the reranker model to use. - device (str): What device to keep the model on. - """ - - def __init__( - self, - model: str = "mixedbread-ai/mxbai-rerank-large-v1", - device: str | None = None, - batch_size: int = 32, - ): - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.device: str = device - self.batch_size: int = batch_size - self.model = CrossEncoder(model, device=device) - - def __call__(self, query: str, docs: list[str], k: int) -> list[int]: - results = self.model.rank(query, docs, top_k=k, batch_size=self.batch_size) - return [int(result["corpus_id"]) for result in results] diff --git a/lotus/models/cross_encoder_reranker.py b/lotus/models/cross_encoder_reranker.py new file mode 100644 index 00000000..65827ce2 --- /dev/null +++ b/lotus/models/cross_encoder_reranker.py @@ -0,0 +1,28 @@ +from sentence_transformers import CrossEncoder + +from lotus.models.reranker import Reranker +from lotus.types import RerankerOutput + + +class CrossEncoderReranker(Reranker): + """CrossEncoder reranker model. + + Args: + model (str): The name of the reranker model to use. + device (str): What device to keep the model on. + max_batch_size (int): The maximum batch size to use for the model. + """ + + def __init__( + self, + model: str = "mixedbread-ai/mxbai-rerank-large-v1", + device: str | None = None, + max_batch_size: int = 64, + ): + self.max_batch_size: int = max_batch_size + self.model = CrossEncoder(model, device=device) # type: ignore # CrossEncoder has wrong type stubs + + def __call__(self, query: str, docs: list[str], K: int) -> RerankerOutput: + results = self.model.rank(query, docs, top_k=K, batch_size=self.max_batch_size) + indices = [int(result["corpus_id"]) for result in results] + return RerankerOutput(indices=indices) diff --git a/lotus/models/e5_model.py b/lotus/models/e5_model.py deleted file mode 100644 index d29c7ddf..00000000 --- a/lotus/models/e5_model.py +++ /dev/null @@ -1,141 +0,0 @@ -import os -import pickle -from typing import Any - -import numpy as np -import torch -import torch.nn.functional as F -from numpy.typing import NDArray -from tqdm import tqdm -from transformers import AutoModel, AutoTokenizer - -from lotus.models.rm import RM - - -class E5Model(RM): - """E5 retriever model""" - - def __init__(self, model: str = "intfloat/e5-base-v2", device: str | None = None, **kwargs: dict[str, Any]) -> None: - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.device = device - self.tokenizer = AutoTokenizer.from_pretrained(model) - self.model = AutoModel.from_pretrained(model).to(self.device) - self.faiss_index = None - self.index_dir: str | None = None - self.docs: list[str] | None = None - self.kwargs: dict[str, Any] = {"normalize": True, "index_type": "Flat", **kwargs} - self.batch_size: int = 100 - self.vecs: NDArray[np.float64] | None = None - - import faiss - - self.faiss = faiss - - def average_pool(self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: - """Perform average pooling over the last hidden state. - - Args: - last_hidden_states: Hidden states from the model's last layer - attention_mask: Attention mask. - - Returns: - Average pool over the last hidden state. - """ - - last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) - return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] - - def embed(self, docs: list[str], **kwargs: dict[str, Any]) -> NDArray[np.float64]: - """Run the embedding model. - - Args: - docs: A list of documents to embed. - - Returns: - Embeddings of the documents. - """ - - kwargs = {**self.kwargs, **dict(kwargs)} - - batch_size = kwargs.get("batch_size", self.batch_size) - assert isinstance(batch_size, int), "batch_size must be an integer" - - # Calculating the embedding dimension - total_docs = len(docs) - first_batch = self.tokenizer(docs[:1], return_tensors="pt", padding=True, truncation=True).to(self.device) - embed_dim = self.model(**first_batch).last_hidden_state.size(-1) - - # Pre-allocate a tensor for all embeddings - embeddings = torch.empty((total_docs, embed_dim), device=self.device) - # Processing batches - with torch.inference_mode(): # Slightly faster than torch.no_grad() for inference - for i, batch_start in enumerate(tqdm(range(0, total_docs, batch_size))): - batch = docs[batch_start : batch_start + batch_size] - batch_dict = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(self.device) - outputs = self.model(**batch_dict) - batch_embeddings = self.average_pool(outputs.last_hidden_state, batch_dict["attention_mask"]) - embeddings[batch_start : batch_start + batch_size] = batch_embeddings - if kwargs["normalize"]: - embeddings = F.normalize(embeddings, p=2, dim=1) - - return embeddings.numpy(force=True) - - def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None: - # Make index directory - os.makedirs(index_dir, exist_ok=True) - - # Get document embeddings - kwargs = {**self.kwargs, **kwargs} - embeddings = self.embed(docs, **kwargs) - d = embeddings.shape[1] - index = self.faiss.index_factory(d, kwargs["index_type"], self.faiss.METRIC_INNER_PRODUCT) - index.add(embeddings) - - # Store index and documents - self.faiss.write_index(index, f"{index_dir}/index") - with open(f"{index_dir}/docs", "wb") as fp: - pickle.dump(docs, fp) - with open(f"{index_dir}/vecs", "wb") as fp: - pickle.dump(embeddings, fp) - self.faiss_index = index - self.docs = docs - self.index_dir = index_dir - self.vecs = embeddings - - def load_index(self, index_dir: str) -> None: - self.index_dir = index_dir - self.faiss_index = self.faiss.read_index(f"{index_dir}/index") - with open(f"{index_dir}/docs", "rb") as fp: - self.docs = pickle.load(fp) - with open(f"{index_dir}/vecs", "rb") as fp: - self.vecs = pickle.load(fp) - - @classmethod - def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.float64]: - with open(f"{index_dir}/vecs", "rb") as fp: - vecs: NDArray[np.float64] = pickle.load(fp) - - return vecs[ids] - - def __call__( - self, - queries: str | list[str] | NDArray[np.float64], - k: int, - **kwargs: dict[str, Any], - ) -> tuple[list[list[float]], list[list[int]]]: - if isinstance(queries, str): - queries = [queries] - - if isinstance(queries[0], str): - str_queries: list[str] = [str(q) for q in queries] - embedded_queries = self.embed(str_queries, **kwargs) - else: - embedded_queries = np.asarray(queries, dtype=np.float32) - - if self.faiss_index is None: - raise ValueError("Index not loaded") - - distances, indicies = self.faiss_index.search(embedded_queries, k) - - return distances, indicies diff --git a/lotus/models/faiss_rm.py b/lotus/models/faiss_rm.py new file mode 100644 index 00000000..205129df --- /dev/null +++ b/lotus/models/faiss_rm.py @@ -0,0 +1,62 @@ +import os +import pickle +from abc import abstractmethod +from typing import Any + +import faiss +import numpy as np +from numpy.typing import NDArray + +from lotus.models.rm import RM +from lotus.types import RMOutput + + +class FaissRM(RM): + def __init__(self, factory_string: str = "Flat", metric=faiss.METRIC_INNER_PRODUCT): + super().__init__() + self.factory_string = factory_string + self.metric = metric + self.index_dir: str | None = None + self.faiss_index: faiss.Index | None = None + self.vecs: NDArray[np.float64] | None = None + + def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None: + vecs = self._embed(docs) + self.faiss_index = faiss.index_factory(vecs.shape[1], self.factory_string, self.metric) + self.faiss_index.add(vecs) + self.index_dir = index_dir + + os.makedirs(index_dir, exist_ok=True) + with open(f"{index_dir}/vecs", "wb") as fp: + pickle.dump(vecs, fp) + faiss.write_index(self.faiss_index, f"{index_dir}/index") + + def load_index(self, index_dir: str) -> None: + self.index_dir = index_dir + self.faiss_index = faiss.read_index(f"{index_dir}/index") + with open(f"{index_dir}/vecs", "rb") as fp: + self.vecs = pickle.load(fp) + + def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]: + with open(f"{index_dir}/vecs", "rb") as fp: + vecs: NDArray[np.float64] = pickle.load(fp) + return vecs[ids] + + def __call__(self, queries: str | list[str] | NDArray[np.float64], K: int, **kwargs: dict[str, Any]) -> RMOutput: + if isinstance(queries, str): + queries = [queries] + + if isinstance(queries[0], str): + embedded_queries = self._embed([str(q) for q in queries]) + else: + embedded_queries = np.asarray(queries, dtype=np.float32) + + if self.faiss_index is None: + raise ValueError("Index not loaded") + + distances, indices = self.faiss_index.search(embedded_queries, K) + return RMOutput(distances=distances, indices=indices) + + @abstractmethod + def _embed(self, docs: list[str]) -> NDArray[np.float64]: + pass diff --git a/lotus/models/litellm_rm.py b/lotus/models/litellm_rm.py new file mode 100644 index 00000000..cadb4cf5 --- /dev/null +++ b/lotus/models/litellm_rm.py @@ -0,0 +1,29 @@ +import faiss +import numpy as np +from litellm import embedding +from litellm.types.utils import EmbeddingResponse +from numpy.typing import NDArray + +from lotus.models.faiss_rm import FaissRM + + +class LiteLLMRM(FaissRM): + def __init__( + self, + model: str = "text-embedding-3-small", + max_batch_size: int = 64, + factory_string: str = "Flat", + metric=faiss.METRIC_INNER_PRODUCT, + ): + super().__init__(factory_string, metric) + self.model: str = model + self.max_batch_size: int = max_batch_size + + def _embed(self, docs: list[str]) -> NDArray[np.float64]: + all_embeddings = [] + for i in range(0, len(docs), self.max_batch_size): + batch = docs[i : i + self.max_batch_size] + response: EmbeddingResponse = embedding(model=self.model, input=batch) + embeddings = np.array([d["embedding"] for d in response.data]) + all_embeddings.append(embeddings) + return np.vstack(all_embeddings) diff --git a/lotus/models/reranker.py b/lotus/models/reranker.py index 4e2f54ee..a7fd5996 100644 --- a/lotus/models/reranker.py +++ b/lotus/models/reranker.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod +from lotus.types import RerankerOutput + class Reranker(ABC): """Abstract class for reranker models.""" @@ -8,15 +10,15 @@ def __init__(self) -> None: pass @abstractmethod - def __call__(self, query: str, docs: list[str], k: int) -> list[int]: + def __call__(self, query: str, docs: list[str], K: int) -> RerankerOutput: """Invoke the reranker. Args: query (str): The query to use for reranking. docs (list[str]): A list of documents to rerank. - k (int): The number of documents to keep after reranking. + K (int): The number of documents to keep after reranking. Returns: - list[int]: The indicies of the reranked documents. + RerankerOutput: The indicies of the reranked documents. """ pass diff --git a/lotus/models/rm.py b/lotus/models/rm.py index e7cb8ba7..330d7cd5 100644 --- a/lotus/models/rm.py +++ b/lotus/models/rm.py @@ -4,12 +4,14 @@ import numpy as np from numpy.typing import NDArray +from lotus.types import RMOutput + class RM(ABC): """Abstract class for retriever models.""" def __init__(self) -> None: - pass + self.index_dir: str | None = None @abstractmethod def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None: @@ -31,7 +33,7 @@ def load_index(self, index_dir: str) -> None: pass @abstractmethod - def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.float64]: + def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]: """Get the vectors from the index. Args: @@ -48,17 +50,17 @@ def get_vectors_from_index(cls, index_dir: str, ids: list[int]) -> NDArray[np.fl def __call__( self, queries: str | list[str] | NDArray[np.float64], - k: int, + K: int, **kwargs: dict[str, Any], - ) -> tuple[list[list[float]], list[list[int]]]: + ) -> RMOutput: """Run top-k search on the index. Args: queries (str | list[str] | NDArray[np.float64]): Either a query or a list of queries or a 2D FP32 array. - k (int): The k to use for top-k search. + K (int): The k to use for top-k search. **kwargs (dict[str, Any]): Additional keyword arguments. Returns: - tuple[list[list[float]], list[list[int]]]: A tuple of (distances, indices) of the top-k vectors + RMOutput: An RMOutput object containing the distances and indices of the top-k vectors. """ pass diff --git a/lotus/models/sentence_transformers_rm.py b/lotus/models/sentence_transformers_rm.py new file mode 100644 index 00000000..bbcd36f9 --- /dev/null +++ b/lotus/models/sentence_transformers_rm.py @@ -0,0 +1,36 @@ +import faiss +import numpy as np +import torch +from numpy.typing import NDArray +from sentence_transformers import SentenceTransformer + +from lotus.models.faiss_rm import FaissRM + + +class SentenceTransformersRM(FaissRM): + def __init__( + self, + model: str = "intfloat/e5-base-v2", + max_batch_size: int = 64, + normalize_embeddings: bool = True, + device: str | None = None, + factory_string: str = "Flat", + metric=faiss.METRIC_INNER_PRODUCT, + ): + super().__init__(factory_string, metric) + self.model: str = model + self.max_batch_size: int = max_batch_size + self.normalize_embeddings: bool = normalize_embeddings + self.transformer: SentenceTransformer = SentenceTransformer(model, device=device) + + def _embed(self, docs: list[str]) -> NDArray[np.float64]: + all_embeddings = [] + for i in range(0, len(docs), self.max_batch_size): + batch = docs[i : i + self.max_batch_size] + torch_embeddings = self.transformer.encode( + batch, convert_to_tensor=True, normalize_embeddings=self.normalize_embeddings + ) + assert isinstance(torch_embeddings, torch.Tensor) + cpu_embeddings = torch_embeddings.cpu().numpy() + all_embeddings.append(cpu_embeddings) + return np.vstack(all_embeddings) diff --git a/lotus/sem_ops/sem_search.py b/lotus/sem_ops/sem_search.py index 49da6a57..d9feb20f 100644 --- a/lotus/sem_ops/sem_search.py +++ b/lotus/sem_ops/sem_search.py @@ -3,6 +3,7 @@ import pandas as pd import lotus +from lotus.types import RerankerOutput, RMOutput @pd.api.extensions.register_dataframe_accessor("sem_search") @@ -55,9 +56,9 @@ def __call__( search_K = K while True: - scores, doc_idxs = rm(query, search_K) - doc_idxs = doc_idxs[0] - scores = scores[0] + rm_output: RMOutput = rm(query, search_K) + doc_idxs = rm_output.indices[0] + scores = rm_output.distances[0] assert len(doc_idxs) == len(scores) postfiltered_doc_idxs = [] @@ -83,7 +84,8 @@ def __call__( if n_rerank is not None: docs = new_df[col_name].tolist() - reranked_idxs = lotus.settings.reranker(query, docs, n_rerank) + reranked_output: RerankerOutput = lotus.settings.reranker(query, docs, n_rerank) + reranked_idxs = reranked_output.indices new_df = new_df.iloc[reranked_idxs] return new_df diff --git a/lotus/sem_ops/sem_sim_join.py b/lotus/sem_ops/sem_sim_join.py index d0094f74..04be885f 100644 --- a/lotus/sem_ops/sem_sim_join.py +++ b/lotus/sem_ops/sem_sim_join.py @@ -3,6 +3,8 @@ import pandas as pd import lotus +from lotus.models import RM +from lotus.types import RMOutput @pd.api.extensions.register_dataframe_accessor("sem_sim_join") @@ -46,8 +48,11 @@ def __call__( raise ValueError("Other Series must have a name") other = pd.DataFrame({other.name: other}) - # get rmodel and index rm = lotus.settings.rm + if not isinstance(rm, RM): + raise ValueError( + "The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()" + ) # load query embeddings from index if they exist if left_on in self._obj.attrs.get("index_dirs", []): @@ -71,7 +76,9 @@ def __call__( rm.load_index(col_index_dir) assert rm.index_dir == col_index_dir - distances, indices = rm(queries, K) + rm_output: RMOutput = rm(queries, K) + distances = rm_output.distances + indices = rm_output.indices other_index_set = set(other.index) join_results = [] diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index 43190e9a..1db8b514 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -159,7 +159,7 @@ def llm_naive_sort( def llm_quicksort( docs: list[str], user_instruction: str, - k: int, + K: int, embedding: bool = False, strategy: str | None = None, cascade_threshold: float | None = None, @@ -170,7 +170,7 @@ def llm_quicksort( Args: docs (list[str]): The list of documents to sort. user_instruction (str): The user instruction for sorting. - k (int): The number of documents to return. + K (int): The number of documents to return. embedding (bool): Whether to use embedding optimization. cascade_threshold (float | None): The confidence threshold for cascading to a larger model. @@ -187,14 +187,14 @@ def llm_quicksort( stats["total_small_calls"] = 0 stats["total_large_calls"] = 0 - def partition(indexes: list[int], low: int, high: int, k: int) -> int: + def partition(indexes: list[int], low: int, high: int, K: int) -> int: nonlocal stats i = low - 1 if embedding: # With embedding optimization - if k <= high - low: - pivot_value = heapq.nsmallest(k, indexes[low : high + 1])[-1] + if K <= high - low: + pivot_value = heapq.nsmallest(K, indexes[low : high + 1])[-1] else: pivot_value = heapq.nsmallest(int((high - low + 1) / 2), indexes[low : high + 1])[-1] pivot_index = indexes.index(pivot_value) @@ -231,21 +231,21 @@ def partition(indexes: list[int], low: int, high: int, k: int) -> int: indexes[i + 1], indexes[high] = indexes[high], indexes[i + 1] return i + 1 - def quicksort_recursive(indexes: list[int], low: int, high: int, k: int) -> None: + def quicksort_recursive(indexes: list[int], low: int, high: int, K: int) -> None: if high <= low: return if low < high: - pi = partition(indexes, low, high, k) + pi = partition(indexes, low, high, K) left_size = pi - low - if left_size + 1 >= k: - quicksort_recursive(indexes, low, pi - 1, k) + if left_size + 1 >= K: + quicksort_recursive(indexes, low, pi - 1, K) else: quicksort_recursive(indexes, low, pi - 1, left_size) - quicksort_recursive(indexes, pi + 1, high, k - left_size - 1) + quicksort_recursive(indexes, pi + 1, high, K - left_size - 1) indexes = list(range(len(docs))) - quicksort_recursive(indexes, 0, len(indexes) - 1, k) + quicksort_recursive(indexes, 0, len(indexes) - 1, K) return SemanticTopKOutput(indexes=indexes, stats=stats) @@ -273,7 +273,7 @@ def __lt__(self, other: "HeapDoc") -> bool: def llm_heapsort( docs: list[str], user_instruction: str, - k: int, + K: int, strategy: str | None = None, ) -> SemanticTopKOutput: """ @@ -282,7 +282,7 @@ def llm_heapsort( Args: docs (list[str]): The list of documents to sort. user_instruction (str): The user instruction for sorting. - k (int): The number of documents to return. + K (int): The number of documents to return. Returns: SemanticTopKOutput: The indexes of the top k documents and stats. @@ -292,7 +292,7 @@ def llm_heapsort( HeapDoc.strategy = strategy N = len(docs) heap = [HeapDoc(docs[idx], user_instruction, idx) for idx in range(N)] - heap = heapq.nsmallest(k, heap) + heap = heapq.nsmallest(K, heap) indexes = [heapq.heappop(heap).idx for _ in range(len(heap))] stats = {"total_tokens": HeapDoc.total_tokens, "total_llm_calls": HeapDoc.num_calls} diff --git a/lotus/types.py b/lotus/types.py index d6b3443e..28cbcfe9 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -4,6 +4,9 @@ from pydantic import BaseModel +################################################################################ +# Mixins +################################################################################ class StatsMixin(BaseModel): stats: dict[str, Any] | None = None @@ -13,6 +16,35 @@ class LogprobsMixin(BaseModel): logprobs: list[list[ChatCompletionTokenLogprob]] | None = None +################################################################################ +# LM related +################################################################################ +class LMOutput(LogprobsMixin): + outputs: list[str] + + +class LMStats(BaseModel): + class TotalUsage(BaseModel): + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + total_cost: float = 0.0 + + total_usage: TotalUsage = TotalUsage() + + +class LogprobsForCascade(BaseModel): + tokens: list[list[str]] + confidences: list[list[float]] + + +class LogprobsForFilterCascade(LogprobsForCascade): + true_probs: list[float] + + +################################################################################ +# Semantic operation outputs +################################################################################ class SemanticMapPostprocessOutput(BaseModel): raw_outputs: list[str] outputs: list[str] @@ -58,24 +90,16 @@ class SemanticTopKOutput(StatsMixin): indexes: list[int] -class LMOutput(LogprobsMixin): - outputs: list[str] +################################################################################ +# RM related +################################################################################ +class RMOutput(BaseModel): + distances: list[list[float]] + indices: list[list[int]] -class LogprobsForCascade(BaseModel): - tokens: list[list[str]] - confidences: list[list[float]] - - -class LogprobsForFilterCascade(LogprobsForCascade): - true_probs: list[float] - - -class LMStats(BaseModel): - class TotalUsage(BaseModel): - prompt_tokens: int = 0 - completion_tokens: int = 0 - total_tokens: int = 0 - total_cost: float = 0.0 - - total_usage: TotalUsage = TotalUsage() +################################################################################ +# Reranker related +################################################################################ +class RerankerOutput(BaseModel): + indices: list[int] From ab9d0fd4fe07c4199b1f020940b61b6d63fce34e Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Tue, 5 Nov 2024 18:15:09 -0800 Subject: [PATCH 27/28] Make examples explicitly state model --- README.md | 2 +- docs/quickstart.rst | 4 ++-- examples/op_examples/agg.py | 4 ++-- examples/op_examples/cluster.py | 4 ++-- examples/op_examples/dedup.py | 2 +- examples/op_examples/filter.py | 2 +- examples/op_examples/join.py | 2 +- examples/op_examples/map.py | 2 +- examples/op_examples/map_fewshot.py | 2 +- examples/op_examples/partition.py | 2 +- examples/op_examples/search.py | 6 +++--- examples/op_examples/sim_join.py | 4 ++-- examples/op_examples/top_k.py | 2 +- 13 files changed, 19 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 2c57c51c..cdd56c62 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ import lotus from lotus.models import LM # configure the LM, and remember to export your API key -lm = LM() +lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) # create dataframes with course names and skills diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 2a9f2761..b2fcd059 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -53,8 +53,8 @@ This can be achieved by applying a semantic filter followed by a semantic aggreg from lotus.models import SentenceTransformersRM, LM # Configure models for LOTUS - lm = LM() - rm = SentenceTransformersRM() + lm = LM(model="gpt-4o-mini") + rm = SentenceTransformersRM(model="intfloat/e5-base-v2") lotus.settings.configure(lm=lm, rm=rm) diff --git a/examples/op_examples/agg.py b/examples/op_examples/agg.py index 6ad9356f..206e3cc8 100644 --- a/examples/op_examples/agg.py +++ b/examples/op_examples/agg.py @@ -3,8 +3,8 @@ import lotus from lotus.models import LM, SentenceTransformersRM -lm = LM() -rm = SentenceTransformersRM() +lm = LM(model="gpt-4o-mini") +rm = SentenceTransformersRM(model="intfloat/e5-base-v2") lotus.settings.configure(lm=lm, rm=rm) data = { diff --git a/examples/op_examples/cluster.py b/examples/op_examples/cluster.py index 9c6697ad..e117b249 100644 --- a/examples/op_examples/cluster.py +++ b/examples/op_examples/cluster.py @@ -3,8 +3,8 @@ import lotus from lotus.models import LM, SentenceTransformersRM -lm = LM() -rm = SentenceTransformersRM() +lm = LM(model="gpt-4o-mini") +rm = SentenceTransformersRM(model="intfloat/e5-base-v2") lotus.settings.configure(lm=lm, rm=rm) data = { diff --git a/examples/op_examples/dedup.py b/examples/op_examples/dedup.py index 8c89aebd..1494df95 100644 --- a/examples/op_examples/dedup.py +++ b/examples/op_examples/dedup.py @@ -3,7 +3,7 @@ import lotus from lotus.models import SentenceTransformersRM -rm = SentenceTransformersRM() +rm = SentenceTransformersRM(model="intfloat/e5-base-v2") lotus.settings.configure(rm=rm) data = { diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index ee96e876..a1acc00d 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -3,7 +3,7 @@ import lotus from lotus.models import LM -lm = LM() +lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/join.py b/examples/op_examples/join.py index 3b8fb30f..7291c575 100644 --- a/examples/op_examples/join.py +++ b/examples/op_examples/join.py @@ -3,7 +3,7 @@ import lotus from lotus.models import LM -lm = LM() +lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/map.py b/examples/op_examples/map.py index a3ea765b..4fb163f2 100644 --- a/examples/op_examples/map.py +++ b/examples/op_examples/map.py @@ -3,7 +3,7 @@ import lotus from lotus.models import LM -lm = LM() +lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/map_fewshot.py b/examples/op_examples/map_fewshot.py index b3bf07fb..365f7c9a 100644 --- a/examples/op_examples/map_fewshot.py +++ b/examples/op_examples/map_fewshot.py @@ -3,7 +3,7 @@ import lotus from lotus.models import LM -lm = LM() +lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) data = { diff --git a/examples/op_examples/partition.py b/examples/op_examples/partition.py index c1c7174e..932b170b 100644 --- a/examples/op_examples/partition.py +++ b/examples/op_examples/partition.py @@ -4,7 +4,7 @@ from lotus.models import LM, SentenceTransformersRM lm = LM(max_tokens=2048) -rm = SentenceTransformersRM() +rm = SentenceTransformersRM(model="intfloat/e5-base-v2") lotus.settings.configure(lm=lm, rm=rm) data = { diff --git a/examples/op_examples/search.py b/examples/op_examples/search.py index 60f04190..c9382aae 100644 --- a/examples/op_examples/search.py +++ b/examples/op_examples/search.py @@ -3,9 +3,9 @@ import lotus from lotus.models import LM, CrossEncoderReranker, SentenceTransformersRM -lm = LM() -rm = SentenceTransformersRM() -reranker = CrossEncoderReranker() +lm = LM(model="gpt-4o-mini") +rm = SentenceTransformersRM(model="intfloat/e5-base-v2") +reranker = CrossEncoderReranker(model="mixedbread-ai/mxbai-rerank-large-v1") lotus.settings.configure(lm=lm, rm=rm, reranker=reranker) data = { diff --git a/examples/op_examples/sim_join.py b/examples/op_examples/sim_join.py index 200a7c43..efc97427 100644 --- a/examples/op_examples/sim_join.py +++ b/examples/op_examples/sim_join.py @@ -3,8 +3,8 @@ import lotus from lotus.models import LM, LiteLLMRM -lm = LM() -# rm = SentenceTransformersRM() +lm = LM(model="gpt-4o-mini") +# rm = SentenceTransformersRM(model="intfloat/e5-base-v2") rm = LiteLLMRM() lotus.settings.configure(lm=lm, rm=rm) diff --git a/examples/op_examples/top_k.py b/examples/op_examples/top_k.py index 8ffaf7b3..8654ea18 100644 --- a/examples/op_examples/top_k.py +++ b/examples/op_examples/top_k.py @@ -3,7 +3,7 @@ import lotus from lotus.models import LM -lm = LM() +lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) data = { From 298618a151ed8c0a8e2f16bbffd86fa841fb2318 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Tue, 5 Nov 2024 18:26:28 -0800 Subject: [PATCH 28/28] Add test for format logprobs for filter cascade --- .github/tests/lm_tests.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 4ebc02a1..ae68c109 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -241,6 +241,22 @@ def test_join_cascade(setup_models): assert stats["filters_resolved_by_helper_model"] == 0, stats +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_format_logprobs_for_filter_cascade(setup_models, model): + lm = setup_models[model] + messages = [ + [{"role": "user", "content": "True or False: The sky is blue?"}], + ] + response = lm(messages, logprobs=True) + formatted_logprobs = lm.format_logprobs_for_filter_cascade(response.logprobs) + true_probs = formatted_logprobs.true_probs + assert len(true_probs) == 1 + + # Very safe (in practice its ~1) + assert true_probs[0] > 0.8 + assert len(formatted_logprobs.tokens) == len(formatted_logprobs.confidences) + + ################################################################################ # Token counting tests ################################################################################