Added the MultipleCompletionLLMModel class used in LDP (#3)
* Added the MultipleCompletionLLMModel class used in LDP

---------

Co-authored-by: Mayk Caldas <[email protected]>
maykcaldas authored Nov 27, 2024
1 parent 4e95024 commit 5453a6d
Showing 13 changed files with 1,023 additions and 125 deletions.
30 changes: 28 additions & 2 deletions llmclient/__init__.py
@@ -1,7 +1,33 @@
from llmclient.llms import LLMModel
from llmclient.types import LLMResult
from llmclient.embeddings import (
EmbeddingModel,
HybridEmbeddingModel,
LiteLLMEmbeddingModel,
SentenceTransformerEmbeddingModel,
SparseEmbeddingModel,
embedding_model_factory,
)
from llmclient.llms import LiteLLMModel, LLMModel, MultipleCompletionLLMModel
from llmclient.types import (
Chunk,
Embeddable,
LLMResult,
)
from llmclient.version import __version__

__all__ = [
"Chunk",
"Embeddable",
"EmbeddingModel",
"HybridEmbeddingModel",
"LLMModel",
"LLMResult",
"LLMResult",
"LiteLLMEmbeddingModel",
"LiteLLMModel",
"MultipleCompletionLLMModel",
"SentenceTransformerEmbeddingModel",
"SparseEmbeddingModel",
"__version__",
"embedding_model_factory",
"embedding_model_factory",
]
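
For orientation (not part of the commit), the re-exported names can now be pulled straight from the package root. A minimal import sketch, using the default model from the config shown below in llms.py:

# Illustrative sketch; "n": 2 simply requests two completions per call.
from llmclient import LiteLLMEmbeddingModel, LiteLLMModel, MultipleCompletionLLMModel

# MultipleCompletionLLMModel is configured with a plain dict (see llms.py below).
llm = MultipleCompletionLLMModel(config={"model": "gpt-3.5-turbo", "n": 2})
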
5 changes: 5 additions & 0 deletions llmclient/constants.py
@@ -2,7 +2,12 @@

import litellm

# Estimate from OpenAI's FAQ
# https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
CHARACTERS_PER_TOKEN_ASSUMPTION: float = 4.0
# Extra tokens added by the user/role message framing.
# These must be accounted for when applying rate limits.
# Taken from empirical counts in tests.
EXTRA_TOKENS_FROM_USER_ROLE: int = 7

MODEL_COST_MAP = litellm.get_model_cost_map("")
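
For context on how these constants are meant to be combined for rate-limit bookkeeping, here is a hypothetical helper (not code from this commit): a message's token footprint is approximated from its character count plus the fixed per-message overhead.

# Hypothetical helper illustrating the intended use of the new constants.
from llmclient.constants import (
    CHARACTERS_PER_TOKEN_ASSUMPTION,
    EXTRA_TOKENS_FROM_USER_ROLE,
)


def estimate_message_tokens(content: str) -> int:
    # ~4 characters per token (OpenAI FAQ estimate) plus the empirically
    # measured overhead that the user/role framing adds to each message.
    return int(len(content) / CHARACTERS_PER_TOKEN_ASSUMPTION) + EXTRA_TOKENS_FROM_USER_ROLE


print(estimate_message_tokens("How many tokens is this message?"))  # -> 15
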
6 changes: 1 addition & 5 deletions llmclient/embeddings.py
@@ -16,11 +16,7 @@

from llmclient.constants import CHARACTERS_PER_TOKEN_ASSUMPTION, MODEL_COST_MAP
from llmclient.rate_limiter import GLOBAL_LIMITER


def get_litellm_retrying_config(timeout: float = 60.0) -> dict[str, Any]:
"""Get retrying configuration for litellm.acompletion and litellm.aembedding."""
return {"num_retries": 3, "timeout": timeout}
from llmclient.utils import get_litellm_retrying_config


class EmbeddingModes(StrEnum):
243 changes: 237 additions & 6 deletions llmclient/llms.py
@@ -1,6 +1,7 @@
import asyncio
import contextlib
import functools
import json
from abc import ABC
from collections.abc import (
AsyncGenerator,
@@ -13,13 +14,18 @@
from inspect import isasyncgenfunction, signature
from typing import (
Any,
ClassVar,
Self,
TypeVar,
cast,
)

import litellm
from aviary.core import (
Message,
Tool,
ToolRequestMessage,
ToolsAdapter,
ToolSelector,
)
from pydantic import (
@@ -41,7 +47,7 @@
from llmclient.prompts import default_system_prompt
from llmclient.rate_limiter import GLOBAL_LIMITER
from llmclient.types import Chunk, LLMResult
from llmclient.utils import is_coroutine_callable
from llmclient.utils import get_litellm_retrying_config, is_coroutine_callable

if not IS_PYTHON_BELOW_312:
_DeploymentTypedDictValidator = TypeAdapter(
@@ -120,11 +126,6 @@ async def do_callbacks(
f(*args, **kwargs)


def get_litellm_retrying_config(timeout: float = 60.0) -> dict[str, Any]:
"""Get retrying configuration for litellm.acompletion and litellm.aembedding."""
return {"num_retries": 3, "timeout": timeout}


class LLMModel(ABC, BaseModel):
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

@@ -582,3 +583,233 @@ async def select_tool(
model_name=self.name, acompletion=self.router.acompletion
)
return await tool_selector(*selection_args, **selection_kwargs)


class MultipleCompletionLLMModel(BaseModel):
"""Run n completions at once, all starting from the same messages."""

model_config = ConfigDict(extra="forbid")

# this should keep the original model
# if fine-tuned, this should still refer to the base model
name: str = "unknown"
config: dict = Field(
default={
"model": "gpt-3.5-turbo", # Default model should have cheap input/output for testing
"temperature": 0.1,
}
)
encoding: Any | None = None

def __str__(self) -> str:
return f"{type(self).__name__} {self.name}"

@model_validator(mode="after")
def set_model_name(self) -> Self:
if (
self.config.get("model") in {"gpt-3.5-turbo", None}
and self.name != "unknown"
) or (self.name != "unknown" and "model" not in self.config):
self.config["model"] = self.name
elif "model" in self.config and self.name == "unknown":
self.name = self.config["model"]
# note we do not consider case where both are set
# because that could be true if the model is fine-tuned
return self

async def achat(
self, messages: Iterable[Message], **kwargs
) -> litellm.ModelResponse:
return await litellm.acompletion(
messages=[m.model_dump(by_alias=True) for m in messages],
**(self.config | kwargs),
)

async def achat_iter(self, messages: Iterable[Message], **kwargs) -> AsyncGenerator:
return cast(
AsyncGenerator,
await litellm.acompletion(
messages=[m.model_dump(by_alias=True) for m in messages],
stream=True,
stream_options={
"include_usage": True, # Included to get prompt token counts
},
**(self.config | kwargs),
),
)

# SEE: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# > `required` means the model must call one or more tools.
TOOL_CHOICE_REQUIRED: ClassVar[str] = "required"

async def call( # noqa: C901, PLR0915
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
output_type: type[BaseModel] | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
**chat_kwargs,
) -> list[LLMResult]:
start_clock = asyncio.get_running_loop().time()

# Deal with tools. Note OpenAI throws a 400 response if tools is empty:
# > Invalid 'tools': empty array. Expected an array with minimum length 1,
# > but got an empty array instead.
# So, circumvent this behavior if tools in (None, [])
if tools:
chat_kwargs["tools"] = ToolsAdapter.dump_python(
tools, exclude_none=True, by_alias=True
)
if tool_choice is not None:
chat_kwargs["tool_choice"] = (
{
"type": "function",
"function": {"name": tool_choice.info.name},
}
if isinstance(tool_choice, Tool)
else tool_choice
)

# deal with specifying output type
if output_type is not None:
schema = json.dumps(output_type.model_json_schema(mode="serialization"))
schema_msg = f"Respond following this JSON schema:\n\n{schema}"
# Get the system prompt and its index, or the index to add it
i, system_prompt = next(
((i, m) for i, m in enumerate(messages) if m.role == "system"),
(0, None),
)
messages = [
*messages[:i],
(
system_prompt.append_text(schema_msg, inplace=False)
if system_prompt
else Message(role="system", content=schema_msg)
),
*messages[i + 1 if system_prompt else i :],
]
chat_kwargs["response_format"] = {"type": "json_object"}

# add static configuration to kwargs
chat_kwargs = self.config | chat_kwargs
n = chat_kwargs.get("n", 1) # number of completions
if n < 1:
raise ValueError("Number of completions (n) must be >= 1.")

prompt = [
(
m
if not isinstance(m, ToolRequestMessage) or m.tool_calls
# OpenAI doesn't allow for empty tool_calls lists, so downcast empty
# ToolRequestMessage to Message here
else Message(role=m.role, content=m.content)
)
for m in messages
]
results: list[LLMResult] = []

if callbacks is None:
completion: litellm.ModelResponse = await self.achat(prompt, **chat_kwargs)
if output_type is not None:
validate_json_completion(completion, output_type)

for choice in completion.choices:
if isinstance(choice, litellm.utils.StreamingChoices):
raise NotImplementedError("Streaming is not yet supported.")

if (
tools is not None # Allows for empty tools list
or choice.finish_reason == "tool_calls"
or (getattr(choice.message, "tool_calls", None) is not None)
):
serialized_choice_message = choice.message.model_dump()
serialized_choice_message["tool_calls"] = (
serialized_choice_message.get("tool_calls") or []
)
output_messages: list[Message | ToolRequestMessage] = [
ToolRequestMessage(**serialized_choice_message)
]
else:
output_messages = [Message(**choice.message.model_dump())]

results.append(
LLMResult(
model=self.name,
config=chat_kwargs,
prompt=prompt,
messages=output_messages,
logprob=sum_logprobs(choice),
system_fingerprint=completion.system_fingerprint,
# Note that these counts are aggregated over all choices
completion_count=completion.usage.completion_tokens, # type: ignore[attr-defined,unused-ignore]
prompt_count=completion.usage.prompt_tokens, # type: ignore[attr-defined,unused-ignore]
)
)
else:
if tools:
raise NotImplementedError("Using tools with callbacks is not supported")
if n > 1:
raise NotImplementedError(
"Multiple completions with callbacks is not supported"
)
result = LLMResult(model=self.name, config=chat_kwargs, prompt=prompt)

sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)]
async_callbacks = [f for f in callbacks if is_coroutine_callable(f)]
stream_completion = await self.achat_iter(messages, **chat_kwargs)
text_result = []
role = "assistant"

async for chunk in stream_completion:
delta = chunk.choices[0].delta
role = delta.role or role
if delta.content:
s = delta.content
if result.seconds_to_first_token == 0:
result.seconds_to_first_token = (
asyncio.get_running_loop().time() - start_clock
)
text_result.append(s)
[await f(s) for f in async_callbacks]
[f(s) for f in sync_callbacks]
if hasattr(chunk, "usage"):
result.prompt_count = chunk.usage.prompt_tokens

output = "".join(text_result)
result.completion_count = litellm.token_counter(
model=self.name,
text=output,
)
# TODO: figure out how tools stream, and log probs
result.messages = [Message(role=role, content=output)]
results.append(result)

if not results:
# This happens in unit tests. We should probably not keep this block around
# long-term. Previously, we would emit an empty ToolRequestMessage if
# completion.choices were empty, so I am replicating that here.
results.append(
LLMResult(
model=self.name,
config=chat_kwargs,
prompt=prompt,
messages=[ToolRequestMessage(tool_calls=[])],
)
)

end_clock = asyncio.get_running_loop().time()

for result in results:
# Manually update prompt count if not set, which can
# happen if the target model doesn't support 'include_usage'
if not result.prompt_count and result.messages:
result.prompt_count = litellm.token_counter(
model=self.name,
messages=[m.model_dump() for m in result.messages],
)

# update with server-side counts
result.seconds_to_last_token = end_clock - start_clock

return results
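
To make the new class concrete, here is a minimal usage sketch. Only call(), config, n, and output_type come from the code above; the example schema and the API-key setup are assumptions for illustration.

# Usage sketch (assumes a LiteLLM-supported provider is configured, e.g. OPENAI_API_KEY).
import asyncio

from aviary.core import Message
from pydantic import BaseModel

from llmclient import MultipleCompletionLLMModel


class Answer(BaseModel):  # hypothetical schema, used only for this example
    answer: str
    confidence: float


async def main() -> None:
    model = MultipleCompletionLLMModel(config={"model": "gpt-3.5-turbo", "n": 2})
    results = await model.call(
        messages=[Message(role="user", content="Is water wet? Give an answer and a confidence.")],
        output_type=Answer,  # call() appends the JSON schema to a system message
    )
    # One LLMResult per completion choice (n=2 here), each carrying token counts and logprob.
    for result in results:
        print(Answer.model_validate_json(result.messages[0].content))


asyncio.run(main())
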
19 changes: 18 additions & 1 deletion llmclient/types.py
@@ -5,6 +5,7 @@
from uuid import UUID, uuid4

import litellm
from aviary.core import Message
from pydantic import (
BaseModel,
ConfigDict,
@@ -67,11 +68,14 @@ class LLMResult(BaseModel):
alias="answer_id",
)
name: str | None = None
prompt: str | list[dict] | None = Field(
prompt: str | list[dict] | Message | list[Message] | None = Field(
default=None,
description="Optional prompt: a string, a list of serialized prompts (list[dict]), or Message object(s).",
)
text: str = ""
messages: list[Message] | None = Field(
default=None, description="Messages received from the LLM."
)
prompt_count: int = 0
completion_count: int = 0
model: str
@@ -82,6 +86,9 @@
seconds_to_last_token: float = Field(
default=0.0, description="Delta time (sec) to last response token's arrival."
)
logprob: float | None = Field(
default=None, description="Sum of logprobs in the completion."
)

def __str__(self) -> str:
return self.text
@@ -98,3 +105,13 @@ def cost(self) -> float:
except KeyError:
logger.warning(f"Could not find cost for model {self.model}.")
return 0.0

# These two methods were implemented in ldp, but not in pqa. Check if they're necessary
# @property
# def provider(self) -> str:
# """Get the model provider's name (e.g. "openai", "mistral")."""
# return litellm.get_llm_provider(self.model)[1]

# def get_supported_openai_params(self) -> list[str] | None:
# """Get the supported OpenAI parameters for the model."""
# return litellm.get_supported_openai_params(self.model)
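
The new LLMResult fields appear on every result returned by MultipleCompletionLLMModel.call(). A short sketch of constructing and reading them; all field values below are invented for illustration.

# Illustrative construction of an LLMResult using the fields added in this commit.
from aviary.core import Message

from llmclient import LLMResult

result = LLMResult(
    model="gpt-3.5-turbo",                                   # model name
    prompt=[Message(role="user", content="Hi")],             # prompt may now be Message(s)
    messages=[Message(role="assistant", content="Hello!")],  # raw messages from the LLM
    logprob=-0.12,                                           # sum of completion logprobs
    prompt_count=8,
    completion_count=2,
)
print(result.messages[0].content, result.logprob)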