-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
26 changed files
with
201 additions
and
508 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,62 +1,67 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Any | ||
|
||
|
||
class LM(ABC): | ||
"""Abstract class for language models.""" | ||
|
||
def _init__(self): | ||
pass | ||
|
||
@abstractmethod | ||
def count_tokens(self, prompt: str | list) -> int: | ||
""" | ||
Counts the number of tokens in the given prompt. | ||
Args: | ||
prompt (str | list): The prompt to count tokens for. This can be a string or a list of messages. | ||
Returns: | ||
int: The number of tokens in the prompt. | ||
""" | ||
pass | ||
|
||
def format_logprobs_for_cascade(self, logprobs: list) -> tuple[list[list[str]], list[list[float]]]: | ||
""" | ||
Formats the logprobs for the cascade. | ||
Args: | ||
logprobs (list): The logprobs to format. | ||
Returns: | ||
tuple[list[list[str]], list[list[float]]]: A tuple containing the tokens and their corresponding confidences. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def __call__( | ||
self, messages_batch: list | list[list], **kwargs: dict[str, Any] | ||
) -> list[str] | tuple[list[str], list[dict[str, Any]]]: | ||
"""Invoke the LLM. | ||
Args: | ||
messages_batch (list | list[list]): Either one prompt or a list of prompts in message format. | ||
kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters. | ||
Returns: | ||
list[str] | tuple[list[str], list[dict[str, Any]]]: A list of outputs for each prompt in the batch. If logprobs is specified in the keyword arguments, | ||
then a list of logprobs is also returned. | ||
""" | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def max_ctx_len(self) -> int: | ||
"""The maximum context length of the LLM.""" | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def max_tokens(self) -> int: | ||
"""The maximum number of tokens that can be generated by the LLM.""" | ||
pass | ||
import numpy as np | ||
from litellm import batch_completion, token_counter | ||
from litellm.types.utils import ChatCompletionTokenLogprob, ModelResponse | ||
|
||
from lotus.types import LMOutput, LogprobsForCascade, LogprobsForFilterCascade | ||
|
||
|
||
class LM: | ||
def __init__(self, model="gpt-4o-mini", temperature=0.0, max_ctx_len=128000, max_tokens=512, **kwargs): | ||
self.model = model | ||
self.max_ctx_len = max_ctx_len | ||
self.max_tokens = max_tokens | ||
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) | ||
self.history = [] | ||
|
||
def __call__(self, messages=None, **kwargs) -> LMOutput: | ||
kwargs = {**self.kwargs, **kwargs} | ||
if kwargs.get("logprobs", False): | ||
kwargs["top_logprobs"] = kwargs.get("top_logprobs", 10) | ||
|
||
responses: list[ModelResponse] = batch_completion(model=self.model, messages=messages, **kwargs) | ||
outputs = [self._get_top_choice(resp) for resp in responses] | ||
logprobs = [self._get_top_choice_logprobs(resp) for resp in responses] if kwargs.get("logprobs") else None | ||
|
||
return LMOutput(outputs=outputs, logprobs=logprobs) | ||
|
||
def _get_top_choice(self, response: ModelResponse) -> str: | ||
return response.choices[0].message.content | ||
|
||
def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompletionTokenLogprob]: | ||
logprobs = response.choices[0].logprobs["content"] | ||
return [ChatCompletionTokenLogprob(**logprob) for logprob in logprobs] | ||
|
||
def format_logprobs_for_cascade(self, logprobs: list[list[ChatCompletionTokenLogprob]]) -> LogprobsForCascade: | ||
all_tokens = [] | ||
all_confidences = [] | ||
for resp in range(len(logprobs)): | ||
tokens = [logprob.token for logprob in logprobs[resp]] | ||
confidences = [np.exp(logprob.logprob) for logprob in logprobs[resp]] | ||
all_tokens.append(tokens) | ||
all_confidences.append(confidences) | ||
return LogprobsForCascade(tokens=all_tokens, confidences=all_confidences) | ||
|
||
def format_logprobs_for_filter_cascade( | ||
self, logprobs: list[list[ChatCompletionTokenLogprob]] | ||
) -> LogprobsForFilterCascade: | ||
all_tokens = [] | ||
all_confidences = [] | ||
all_true_probs = [] | ||
|
||
for resp in range(len(logprobs)): | ||
all_tokens.append([logprob.token for logprob in logprobs[resp]]) | ||
all_confidences.append([np.exp(logprob.logprob) for logprob in logprobs[resp]]) | ||
top_logprobs = {x.token: np.exp(x.logprob) for x in logprobs[resp]} | ||
true_prob, false_prob = 0, 0 | ||
if top_logprobs and "True" in top_logprobs and "False" in top_logprobs: | ||
true_prob = np.exp(top_logprobs["True"]) | ||
false_prob = np.exp(top_logprobs["False"]) | ||
all_true_probs.append(true_prob / (true_prob + false_prob)) | ||
else: | ||
all_true_probs.append(1 if "True" in top_logprobs else 0) | ||
return LogprobsForFilterCascade(tokens=all_tokens, confidences=all_confidences, true_probs=all_true_probs) | ||
|
||
def count_tokens(self, messages: list[dict[str, str]] | str) -> int: | ||
if isinstance(messages, str): | ||
messages = [{"role": "user", "content": messages}] | ||
return token_counter(model=self.model, messages=messages) |
Oops, something went wrong.