Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: litellm for multiple model support #304

Merged
merged 8 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions agents-api/agents_api/activities/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from typing import Callable
from textwrap import dedent
from temporalio import activity
from litellm import acompletion
from agents_api.models.entry.entries_summarization import (
get_toplevel_entries_query,
entries_summarization_query,
)
from agents_api.common.protocol.entries import Entry
from agents_api.model_registry import get_model_client
from ..env import summarization_model_name


Expand Down Expand Up @@ -129,9 +129,7 @@ async def run_prompt(
**kwargs,
) -> str:
prompt = make_prompt(dialog, previous_memories, **kwargs)
client = get_model_client(model)

response = await client.chat.completions.create(
response = await acompletion(
model=model,
messages=[
{
Expand Down
38 changes: 11 additions & 27 deletions agents-api/agents_api/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
"""

from typing import Dict
from agents_api.clients.model import julep_client, openai_client
from agents_api.clients.worker.types import ChatML
from agents_api.common.exceptions.agents import (
AgentModelNotValid,
MissingAgentModelAPIKeyError,
)
from openai import AsyncOpenAI

import litellm
from litellm.utils import get_valid_models

GPT4_MODELS: Dict[str, int] = {
# stable model names:
Expand Down Expand Up @@ -101,42 +100,26 @@

CHAT_MODELS = {**GPT4_MODELS, **TURBO_MODELS, **CLAUDE_MODELS}

ALL_AVAILABLE_MODELS = {
**JULEP_MODELS,
**GPT4_MODELS,
**TURBO_MODELS,
**GPT3_5_MODELS,
**GPT3_MODELS,
# **CLAUDE_MODELS,
}

ALL_AVAILABLE_MODELS = litellm.model_list + list(JULEP_MODELS.keys())
VALID_MODELS = get_valid_models() + list(JULEP_MODELS.keys())


def validate_configuration(model: str):
    """
    Validates the model specified in the request
    """
    # NOTE(review): this span appears to interleave the removed and the added
    # lines of the PR diff (the +/- markers were lost in extraction) — confirm
    # against the merged file before trusting the control flow below.
    if model not in ALL_AVAILABLE_MODELS:
        # Reject models not present in the registry at all.
        raise AgentModelNotValid(model, list(ALL_AVAILABLE_MODELS.keys()))
    # Presumably the old, client-based API-key check (removed by this PR):
    model_client = get_model_client(model)
    if model_client.api_key == "":
        raise AgentModelNotValid(model, ALL_AVAILABLE_MODELS)
    # Presumably the new litellm-based check (added by this PR): the model is
    # known, but no provider API key is configured for it in the environment.
    elif model not in get_valid_models():
        raise MissingAgentModelAPIKeyError(model)


def get_model_client(model: str) -> AsyncOpenAI:
    """Return the async client responsible for serving the given model name.

    Julep-hosted models are routed to the Julep client; OpenAI model names
    are routed to the OpenAI client. Any other name falls through and the
    function returns None implicitly.
    """
    if model in JULEP_MODELS:
        return julep_client
    if model in OPENAI_MODELS:
        return openai_client


def load_context(init_context: list[ChatML], model: str):
"""
Converts the message history into a format supported by the model.
"""
if model in OPENAI_MODELS:
if model in litellm.utils.get_valid_models():
init_context = [
{
"role": "assistant" if msg.role == "function_call" else msg.role,
Expand All @@ -149,10 +132,11 @@ def load_context(init_context: list[ChatML], model: str):
{"name": msg.name, "role": msg.role, "content": msg.content}
for msg in init_context
]
else:
raise AgentModelNotValid(model, ALL_AVAILABLE_MODELS)
return init_context


# TODO: add type hint for Settings
def get_extra_settings(settings):
extra_settings = (
dict(
Expand All @@ -164,7 +148,7 @@ def get_extra_settings(settings):
preset=settings.preset.name if settings.preset else None,
)
if settings.model in JULEP_MODELS
else None
else {}
)

return extra_settings
Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/routers/agents/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ async def update_agent(
if isinstance(request.instructions, str):
request.instructions = [request.instructions]

validate_configuration(request.model)
try:
resp = update_agent_query(
agent_id=agent_id,
Expand Down
28 changes: 22 additions & 6 deletions agents-api/agents_api/routers/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from openai.types.chat.chat_completion import ChatCompletion
from pydantic import UUID4

import litellm
from litellm import acompletion

from ...autogen.openapi_model import InputChatMLMessage, Tool
from ...clients.embed import embed
from ...clients.temporal import run_summarization_task
Expand All @@ -18,8 +21,8 @@
from ...common.utils.template import render_template
from ...env import summarization_tokens_threshold
from ...model_registry import (
JULEP_MODELS,
get_extra_settings,
get_model_client,
load_context,
)
from ...models.entry.add_entries import add_entries_query
Expand All @@ -28,7 +31,7 @@

from .exceptions import InputTooBigError
from .protocol import Settings

from ...env import model_inference_url, model_api_key

THOUGHTS_STRIP_LEN = 2
MESSAGES_STRIP_LEN = 4
Expand Down Expand Up @@ -319,24 +322,37 @@ async def generate(
) -> ChatCompletion:
init_context = load_context(init_context, settings.model)
tools = None
api_base = None
api_key = None
model = settings.model
if model in JULEP_MODELS:
api_base = model_inference_url
api_key = model_api_key
model = f"openai/{model}"

if settings.tools:
tools = [(tool.model_dump(exclude="id")) for tool in settings.tools]
model_client = get_model_client(settings.model)

extra_body = get_extra_settings(settings)

res = await model_client.chat.completions.create(
model=settings.model,
litellm.drop_params = True
litellm.add_function_to_prompt = True

res = await acompletion(
model=model,
messages=init_context,
max_tokens=settings.max_tokens,
stop=settings.stop,
temperature=settings.temperature,
frequency_penalty=settings.frequency_penalty,
extra_body=extra_body,
top_p=settings.top_p,
presence_penalty=settings.presence_penalty,
stream=settings.stream,
tools=tools,
response_format=settings.response_format,
api_base=api_base,
api_key=api_key,
**extra_body,
)
return res

Expand Down
Loading
Loading