Skip to content

Commit

Permalink
feat(agents-api): samantha model through litellm
Browse files Browse the repository at this point in the history
  • Loading branch information
alt-glitch committed May 1, 2024
1 parent fc0d80c commit 3bef712
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
11 changes: 1 addition & 10 deletions agents-api/agents_api/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

from typing import Dict
from agents_api.clients.model import julep_client, openai_client
from agents_api.clients.worker.types import ChatML
from agents_api.common.exceptions.agents import (
AgentModelNotValid,
Expand All @@ -12,6 +11,7 @@
from openai import AsyncOpenAI
import litellm
from litellm.utils import get_valid_models
from agents_api.env import model_inference_url, model_api_key

GPT4_MODELS: Dict[str, int] = {
# stable model names:
Expand Down Expand Up @@ -116,15 +116,6 @@ def validate_configuration(model: str):
raise MissingAgentModelAPIKeyError(model)


def get_model_client(model: str) -> AsyncOpenAI:
    """
    Return the model serving client for the given model name.

    Args:
        model: Name of the model to resolve to a serving client.

    Returns:
        The OpenAI-compatible async client that serves the model.

    Raises:
        AgentModelNotValid: If the model is in neither JULEP_MODELS nor
            OPENAI_MODELS.
    """
    if model in JULEP_MODELS:
        return julep_client
    if model in OPENAI_MODELS:
        return openai_client
    # Previously this fell through and implicitly returned None, violating
    # the annotated return type and deferring the failure to the caller.
    # Fail fast with the registry's standard exception instead.
    raise AgentModelNotValid(model)


def load_context(init_context: list[ChatML], model: str):
"""
Expand Down
15 changes: 13 additions & 2 deletions agents-api/agents_api/routers/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ...common.utils.template import render_template
from ...env import summarization_tokens_threshold
from ...model_registry import (
JULEP_MODELS,
get_extra_settings,
load_context,
)
Expand All @@ -30,7 +31,7 @@

from .exceptions import InputTooBigError
from .protocol import Settings

from ...env import model_inference_url, model_api_key

THOUGHTS_STRIP_LEN = 2
MESSAGES_STRIP_LEN = 4
Expand Down Expand Up @@ -321,15 +322,23 @@ async def generate(
) -> ChatCompletion:
init_context = load_context(init_context, settings.model)
tools = None
api_base = None
api_key = None
if settings.model in JULEP_MODELS:
api_base = model_inference_url
api_key = model_api_key
model = f"openai/{settings.model}"

if settings.tools:
tools = [(tool.model_dump(exclude="id")) for tool in settings.tools]

extra_body = get_extra_settings(settings)

litellm.drop_params = True
litellm.add_function_to_prompt = True

res = await acompletion(
model=settings.model,
model=model,
messages=init_context,
max_tokens=settings.max_tokens,
stop=settings.stop,
Expand All @@ -340,6 +349,8 @@ async def generate(
stream=settings.stream,
tools=tools,
response_format=settings.response_format,
api_base=api_base,
api_key=api_key,
**extra_body
)
return res
Expand Down

0 comments on commit 3bef712

Please sign in to comment.