Port to LiteLLM
sidjha1 committed Oct 29, 2024
1 parent 9dab975 · commit 070d01f
Showing 26 changed files with 201 additions and 508 deletions.
7 changes: 3 additions & 4 deletions .github/tests/lm_tests.py
@@ -2,7 +2,7 @@
 import pytest
 
 import lotus
-from lotus.models import OpenAIModel
+from lotus.models import LM
 
 # Set logger level to DEBUG
 lotus.logger.setLevel("DEBUG")
@@ -11,8 +11,8 @@
 @pytest.fixture
 def setup_models():
     # Setup GPT models
-    gpt_4o_mini = OpenAIModel(model="gpt-4o-mini")
-    gpt_4o = OpenAIModel(model="gpt-4o")
+    gpt_4o_mini = LM(model="gpt-4o-mini")
+    gpt_4o = LM(model="gpt-4o")
     return gpt_4o_mini, gpt_4o
@@ -57,7 +57,6 @@ def test_filter_cascade(setup_models):
         "Everything is going as planned, couldn't be happier.",
         "Feeling super motivated and ready to take on challenges!",
         "I appreciate all the small things that bring me joy.",
-
         # Negative examples
         "I am very sad.",
         "Today has been really tough; I feel exhausted.",
4 changes: 2 additions & 2 deletions examples/op_examples/agg.py
@@ -1,9 +1,9 @@
 import pandas as pd
 
 import lotus
-from lotus.models import E5Model, OpenAIModel
+from lotus.models import LM, E5Model
 
-lm = OpenAIModel()
+lm = LM()
 rm = E5Model()
 
 lotus.settings.configure(lm=lm, rm=rm)
4 changes: 2 additions & 2 deletions examples/op_examples/cluster.py
@@ -1,9 +1,9 @@
 import pandas as pd
 
 import lotus
-from lotus.models import E5Model, OpenAIModel
+from lotus.models import LM, E5Model
 
-lm = OpenAIModel()
+lm = LM()
 rm = E5Model()
 
 lotus.settings.configure(lm=lm, rm=rm)
4 changes: 2 additions & 2 deletions examples/op_examples/filter.py
@@ -1,9 +1,9 @@
 import pandas as pd
 
 import lotus
-from lotus.models import OpenAIModel
+from lotus.models import LM
 
-lm = OpenAIModel()
+lm = LM()
 
 lotus.settings.configure(lm=lm)
 data = {
6 changes: 3 additions & 3 deletions examples/op_examples/filter_cascade.py
@@ -1,10 +1,10 @@
 import pandas as pd
 
 import lotus
-from lotus.models import OpenAIModel
+from lotus.models import LM
 
-gpt_35_turbo = OpenAIModel("gpt-3.5-turbo")
-gpt_4o = OpenAIModel("gpt-4o")
+gpt_35_turbo = LM("gpt-3.5-turbo")
+gpt_4o = LM("gpt-4o")
 
 lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_35_turbo)
 data = {
4 changes: 2 additions & 2 deletions examples/op_examples/join.py
@@ -1,9 +1,9 @@
 import pandas as pd
 
 import lotus
-from lotus.models import OpenAIModel
+from lotus.models import LM
 
-lm = OpenAIModel()
+lm = LM()
 
 lotus.settings.configure(lm=lm)
 data = {
4 changes: 2 additions & 2 deletions examples/op_examples/map.py
@@ -1,9 +1,9 @@
 import pandas as pd
 
 import lotus
-from lotus.models import OpenAIModel
+from lotus.models import LM
 
-lm = OpenAIModel()
+lm = LM()
 
 lotus.settings.configure(lm=lm)
 data = {
4 changes: 2 additions & 2 deletions examples/op_examples/map_fewshot.py
@@ -1,9 +1,9 @@
 import pandas as pd
 
 import lotus
-from lotus.models import OpenAIModel
+from lotus.models import LM
 
-lm = OpenAIModel()
+lm = LM()
 
 lotus.settings.configure(lm=lm)
 data = {
4 changes: 2 additions & 2 deletions examples/op_examples/partition.py
@@ -1,9 +1,9 @@
 import pandas as pd
 
 import lotus
-from lotus.models import E5Model, OpenAIModel
+from lotus.models import LM, E5Model
 
-lm = OpenAIModel(max_tokens=2048)
+lm = LM(max_tokens=2048)
 rm = E5Model()
 
 lotus.settings.configure(lm=lm, rm=rm)
4 changes: 2 additions & 2 deletions examples/op_examples/search.py
@@ -1,9 +1,9 @@
 import pandas as pd
 
 import lotus
-from lotus.models import CrossEncoderModel, E5Model, OpenAIModel
+from lotus.models import LM, CrossEncoderModel, E5Model
 
-lm = OpenAIModel()
+lm = LM()
 rm = E5Model()
 reranker = CrossEncoderModel()
4 changes: 2 additions & 2 deletions examples/op_examples/sim_join.py
@@ -1,9 +1,9 @@
 import pandas as pd
 
 import lotus
-from lotus.models import E5Model, OpenAIModel
+from lotus.models import LM, E5Model
 
-lm = OpenAIModel()
+lm = LM()
 rm = E5Model()
 
 lotus.settings.configure(lm=lm, rm=rm)
4 changes: 2 additions & 2 deletions examples/op_examples/top_k.py
@@ -1,9 +1,9 @@
 import pandas as pd
 
 import lotus
-from lotus.models import OpenAIModel
+from lotus.models import LM
 
-lm = OpenAIModel()
+lm = LM()
 
 lotus.settings.configure(lm=lm)
 data = {
4 changes: 2 additions & 2 deletions examples/provider_examples/oai.py
@@ -1,9 +1,9 @@
 import pandas as pd
 
 import lotus
-from lotus.models import OpenAIModel
+from lotus.models import LM
 
-lm = OpenAIModel()
+lm = LM()
 
 lotus.settings.configure(lm=lm)
 data = {
4 changes: 2 additions & 2 deletions examples/provider_examples/ollama.py
@@ -1,9 +1,9 @@
 import pandas as pd
 
 import lotus
-from lotus.models import OpenAIModel
+from lotus.models import LM
 
-lm = OpenAIModel(
+lm = LM(
     api_base="http://localhost:11434/v1",
     model="llama3.2",
     hf_name="meta-llama/Llama-3.2-3B-Instruct",
4 changes: 2 additions & 2 deletions examples/provider_examples/vllm.py
@@ -1,9 +1,9 @@
 import pandas as pd
 
 import lotus
-from lotus.models import OpenAIModel
+from lotus.models import LM
 
-lm = OpenAIModel(
+lm = LM(
     model="meta-llama/Meta-Llama-3.1-70B-Instruct",
     api_base="http://localhost:8000/v1",
     provider="vllm",
2 changes: 0 additions & 2 deletions lotus/models/__init__.py
@@ -2,12 +2,10 @@
 from lotus.models.cross_encoder_model import CrossEncoderModel
 from lotus.models.e5_model import E5Model
 from lotus.models.lm import LM
-from lotus.models.openai_model import OpenAIModel
 from lotus.models.reranker import Reranker
 from lotus.models.rm import RM
 
 __all__ = [
-    "OpenAIModel",
     "E5Model",
     "ColBERTv2Model",
     "CrossEncoderModel",
129 changes: 67 additions & 62 deletions lotus/models/lm.py
@@ -1,62 +1,67 @@
-from abc import ABC, abstractmethod
-from typing import Any
-
-
-class LM(ABC):
-    """Abstract class for language models."""
-
-    def _init__(self):
-        pass
-
-    @abstractmethod
-    def count_tokens(self, prompt: str | list) -> int:
-        """
-        Counts the number of tokens in the given prompt.
-
-        Args:
-            prompt (str | list): The prompt to count tokens for. This can be a string or a list of messages.
-
-        Returns:
-            int: The number of tokens in the prompt.
-        """
-        pass
-
-    def format_logprobs_for_cascade(self, logprobs: list) -> tuple[list[list[str]], list[list[float]]]:
-        """
-        Formats the logprobs for the cascade.
-
-        Args:
-            logprobs (list): The logprobs to format.
-
-        Returns:
-            tuple[list[list[str]], list[list[float]]]: A tuple containing the tokens and their corresponding confidences.
-        """
-        pass
-
-    @abstractmethod
-    def __call__(
-        self, messages_batch: list | list[list], **kwargs: dict[str, Any]
-    ) -> list[str] | tuple[list[str], list[dict[str, Any]]]:
-        """Invoke the LLM.
-
-        Args:
-            messages_batch (list | list[list]): Either one prompt or a list of prompts in message format.
-            kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters.
-
-        Returns:
-            list[str] | tuple[list[str], list[dict[str, Any]]]: A list of outputs for each prompt in the batch. If logprobs is specified in the keyword arguments,
-                then a list of logprobs is also returned.
-        """
-        pass
-
-    @property
-    @abstractmethod
-    def max_ctx_len(self) -> int:
-        """The maximum context length of the LLM."""
-        pass
-
-    @property
-    @abstractmethod
-    def max_tokens(self) -> int:
-        """The maximum number of tokens that can be generated by the LLM."""
-        pass
+import numpy as np
+from litellm import batch_completion, token_counter
+from litellm.types.utils import ChatCompletionTokenLogprob, ModelResponse
+
+from lotus.types import LMOutput, LogprobsForCascade, LogprobsForFilterCascade
+
+
+class LM:
+    def __init__(self, model="gpt-4o-mini", temperature=0.0, max_ctx_len=128000, max_tokens=512, **kwargs):
+        self.model = model
+        self.max_ctx_len = max_ctx_len
+        self.max_tokens = max_tokens
+        self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
+        self.history = []
+
+    def __call__(self, messages=None, **kwargs) -> LMOutput:
+        kwargs = {**self.kwargs, **kwargs}
+        if kwargs.get("logprobs", False):
+            kwargs["top_logprobs"] = kwargs.get("top_logprobs", 10)
+
+        responses: list[ModelResponse] = batch_completion(model=self.model, messages=messages, **kwargs)
+        outputs = [self._get_top_choice(resp) for resp in responses]
+        logprobs = [self._get_top_choice_logprobs(resp) for resp in responses] if kwargs.get("logprobs") else None
+
+        return LMOutput(outputs=outputs, logprobs=logprobs)
+
+    def _get_top_choice(self, response: ModelResponse) -> str:
+        return response.choices[0].message.content
+
+    def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompletionTokenLogprob]:
+        logprobs = response.choices[0].logprobs["content"]
+        return [ChatCompletionTokenLogprob(**logprob) for logprob in logprobs]
+
+    def format_logprobs_for_cascade(self, logprobs: list[list[ChatCompletionTokenLogprob]]) -> LogprobsForCascade:
+        all_tokens = []
+        all_confidences = []
+        for resp in range(len(logprobs)):
+            tokens = [logprob.token for logprob in logprobs[resp]]
+            confidences = [np.exp(logprob.logprob) for logprob in logprobs[resp]]
+            all_tokens.append(tokens)
+            all_confidences.append(confidences)
+        return LogprobsForCascade(tokens=all_tokens, confidences=all_confidences)
+
+    def format_logprobs_for_filter_cascade(
+        self, logprobs: list[list[ChatCompletionTokenLogprob]]
+    ) -> LogprobsForFilterCascade:
+        all_tokens = []
+        all_confidences = []
+        all_true_probs = []
+
+        for resp in range(len(logprobs)):
+            all_tokens.append([logprob.token for logprob in logprobs[resp]])
+            all_confidences.append([np.exp(logprob.logprob) for logprob in logprobs[resp]])
+            top_logprobs = {x.token: np.exp(x.logprob) for x in logprobs[resp]}
+            true_prob, false_prob = 0, 0
+            if top_logprobs and "True" in top_logprobs and "False" in top_logprobs:
+                true_prob = np.exp(top_logprobs["True"])
+                false_prob = np.exp(top_logprobs["False"])
+                all_true_probs.append(true_prob / (true_prob + false_prob))
+            else:
+                all_true_probs.append(1 if "True" in top_logprobs else 0)
+        return LogprobsForFilterCascade(tokens=all_tokens, confidences=all_confidences, true_probs=all_true_probs)
+
+    def count_tokens(self, messages: list[dict[str, str]] | str) -> int:
+        if isinstance(messages, str):
+            messages = [{"role": "user", "content": messages}]
+        return token_counter(model=self.model, messages=messages)
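For reference, a minimal usage sketch of the new LiteLLM-backed LM class after this port. It is not part of the commit; it assumes an OPENAI_API_KEY is set in the environment and uses the batch message format that litellm.batch_completion expects (a list of conversations, each a list of chat messages).

# Minimal usage sketch (not part of this commit); assumes OPENAI_API_KEY is set.
import lotus
from lotus.models import LM

lm = LM(model="gpt-4o-mini")  # defaults: temperature=0.0, max_tokens=512, max_ctx_len=128000
lotus.settings.configure(lm=lm)

# __call__ forwards a batch of conversations to litellm.batch_completion,
# so a single prompt is wrapped as a one-element batch.
messages = [[{"role": "user", "content": "Say hello in one word."}]]
result = lm(messages)     # LMOutput(outputs=[...], logprobs=None)
print(result.outputs[0])  # top choice for the first prompt

# count_tokens wraps a bare string as a single user message and
# delegates to litellm.token_counter.
print(lm.count_tokens("How many tokens is this?"))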
