Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

f/generic model support #24

Merged
merged 17 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ COZO_ROCKSDB_DIR=cozo.db
DTYPE=bfloat16
EMBEDDING_SERVICE_URL=http://text-embeddings-inference/embed
GATEWAY_PORT=80
# switch openai model here
GENERATION_AUTH_TOKEN=myauthkey
GENERATION_URL=http://model-api:8000/v1
GPU_MEMORY_UTILIZATION=0.95
Expand All @@ -21,16 +22,16 @@ MODEL_API_KEY=myauthkey
MODEL_API_KEY_HEADER_NAME=Authorization
MODEL_API_URL=http://model-api:8000
MODEL_ID=BAAI/llm-embedder
MODEL_NAME=julep-ai/samantha-1-turbo
# MODEL_NAME = "julep-ai/samantha-1-turbo-awq"
# switch local model here
MODEL_NAME = "julep-ai/samantha-1-turbo-awq"
SKIP_CHECK_DEVELOPER_HEADERS=true
SUMMARIZATION_TOKENS_THRESHOLD=2048
TEMPERATURE_SCALING_FACTOR=0.9
TEMPERATURE_SCALING_POWER=0.9
TEMPORAL_ENDPOINT=temporal:7233
TEMPORAL_NAMESPACE=default
TEMPORAL_WORKER_URL=temporal:7233
TP_SIZE=2
TP_SIZE=1
TRUNCATE_EMBED_TEXT=true
TRAEFIK_LOG_LEVEL=DEBUG
WORKER_URL=temporal:7233
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ ngrok*
*.env
*.pyc
*/node_modules/
.devcontainer
node_modules/
package-lock.json
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are these gitignored? be careful

package.json
9 changes: 8 additions & 1 deletion agents-api/agents_api/common/exceptions/agents.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from uuid import UUID
from . import BaseCommonException

from agents_api.model_registry import ALL_AVAILABLE_MODELS

class BaseAgentException(BaseCommonException):
    """Common base class for the agent-related exceptions defined in this module."""

    pass
Expand All @@ -26,3 +26,10 @@ def __init__(self, agent_id: UUID | str, doc_id: UUID | str):
super().__init__(
f"Doc {str(doc_id)} not found for agent {str(agent_id)}", http_code=404
)

class AgentModelNotValid(BaseAgentException):
    """Raised when an agent is requested with a model name that is not registered.

    The message lists every key of ``ALL_AVAILABLE_MODELS`` and the exception
    is constructed with ``http_code=400``.
    """

    def __init__(self, model: str):
        # Iterating the dict yields its keys; no need for ``.keys()``.
        known_models = ", ".join(ALL_AVAILABLE_MODELS)
        super().__init__(
            # The trailing space matters: adjacent literals are concatenated
            # verbatim, so without it the two sentences run together.
            f"Unknown model: {model}. Please provide a valid model name. "
            f"Known models are: {known_models}",
            http_code=400,
        )
2 changes: 1 addition & 1 deletion agents-api/agents_api/common/protocol/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ class AgentDefaultSettings(BaseModel):
min_p: float = 0.01


ModelType = Literal["julep-ai/samantha-1", "julep-ai/samantha-1-turbo"]
# ModelType = Literal["julep-ai/samantha-1", "julep-ai/samantha-1-turbo", "gpt-4"]
17 changes: 14 additions & 3 deletions agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from uuid import UUID

from pydantic import BaseModel
from pydantic import BaseModel, validator

from .agents import ModelType, AgentDefaultSettings
from .agents import AgentDefaultSettings

from model_registry import ALL_AVAILABLE_MODELS

class SessionSettings(AgentDefaultSettings):
    """Settings attached to a session; adds nothing to ``AgentDefaultSettings`` yet."""

    pass
Expand All @@ -21,5 +22,15 @@ class SessionData(BaseModel):
agent_about: str
updated_at: float
created_at: float
model: ModelType
model: str
default_settings: SessionSettings

@validator('model')
def validate_model_type(cls, model):
    """Reject ``model`` values that are not keys of ``ALL_AVAILABLE_MODELS``.

    Returns the model name unchanged when it is registered.

    Raises:
        ValueError: if the name is unknown; the message lists the known names.
    """
    if model not in ALL_AVAILABLE_MODELS:
        known_models = ", ".join(ALL_AVAILABLE_MODELS)
        raise ValueError(
            # Trailing space keeps the two sentences from running together
            # when the adjacent literals are concatenated.
            f"Unknown model: {model}. Please provide a valid model name. "
            f"Known models are: {known_models}"
        )

    return model
125 changes: 125 additions & 0 deletions agents-api/agents_api/model_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""
Model Registry maintains a list of supported models and their configs.
"""
from typing import Dict, List, Optional

# Each table below maps a model name to an integer limit.
# NOTE(review): presumably the model's maximum context length in tokens — confirm.
GPT4_MODELS: Dict[str, int] = {
# stable model names:
# resolves to gpt-4-0314 before 2023-06-27,
# resolves to gpt-4-0613 after
"gpt-4": 8192,
"gpt-4-32k": 32768,
# turbo models (Turbo, JSON mode)
"gpt-4-1106-preview": 128000,
"gpt-4-0125-preview": 128000,
"gpt-4-turbo-preview": 128000,
# multimodal model
"gpt-4-vision-preview": 128000,
# 0613 models (function calling):
# https://openai.com/blog/function-calling-and-other-api-updates
"gpt-4-0613": 8192,
"gpt-4-32k-0613": 32768,
# 0314 models
"gpt-4-0314": 8192,
"gpt-4-32k-0314": 32768,
}

TURBO_MODELS: Dict[str, int] = {
# stable model names:
# resolves to gpt-3.5-turbo-0301 before 2023-06-27,
# resolves to gpt-3.5-turbo-0613 until 2023-12-11,
# resolves to gpt-3.5-turbo-1106 after
"gpt-3.5-turbo": 4096,
# resolves to gpt-3.5-turbo-16k-0613 until 2023-12-11
# resolves to gpt-3.5-turbo-1106 after
"gpt-3.5-turbo-16k": 16384,
# 0125 (2024) model (JSON mode)
"gpt-3.5-turbo-0125": 16385,
# 1106 model (JSON mode)
"gpt-3.5-turbo-1106": 16384,
# 0613 models (function calling):
# https://openai.com/blog/function-calling-and-other-api-updates
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-16k-0613": 16384,
# 0301 models
"gpt-3.5-turbo-0301": 4096,
}

GPT3_5_MODELS: Dict[str, int] = {
"text-davinci-003": 4097,
"text-davinci-002": 4097,
# instruct models
"gpt-3.5-turbo-instruct": 4096,
}

GPT3_MODELS: Dict[str, int] = {
"text-ada-001": 2049,
# NOTE(review): every other entry in this table is 2049 — confirm 2040 is not a typo.
"text-babbage-001": 2040,
"text-curie-001": 2049,
"ada": 2049,
"babbage": 2049,
"curie": 2049,
"davinci": 2049,
}


# Kept for reference only: this table is not merged into CHAT_MODELS or
# ALL_AVAILABLE_MODELS, so these names are not treated as available.
DISCONTINUED_MODELS = {
"code-davinci-002": 8001,
"code-davinci-001": 8001,
"code-cushman-002": 2048,
"code-cushman-001": 2048,
}

CLAUDE_MODELS: Dict[str, int] = {
"claude-instant-1": 100000,
"claude-instant-1.2": 100000,
"claude-2": 100000,
"claude-2.0": 100000,
"claude-2.1": 200000,
"claude-3-opus-20240229": 180000,
"claude-3-sonnet-20240229": 180000,
"claude-3-haiku-20240307": 180000,
}

# Union of the GPT-4, GPT-3.5 turbo and Claude tables; the GPT3_5_MODELS and
# GPT3_MODELS tables are left out.
CHAT_MODELS = {
**GPT4_MODELS,
**TURBO_MODELS,
**CLAUDE_MODELS
}

# Union of every model table except DISCONTINUED_MODELS. Its keys are the
# model names that the model-name checks elsewhere in the codebase accept.
ALL_AVAILABLE_MODELS = {
    **GPT4_MODELS,
    **TURBO_MODELS,
    **GPT3_5_MODELS,
    **GPT3_MODELS,
    **CLAUDE_MODELS,
}


def validate_configuration():
    """Validate the configuration based on the model.

    Placeholder: not implemented yet. Takes no arguments, does nothing and
    returns ``None``.
    """
    pass


def validate_request():
    """Validate a request based on the model.

    Placeholder: not implemented yet. Takes no arguments, does nothing and
    returns ``None``.

    NOTE(review): the previous docstring was a copy of
    ``validate_configuration``'s and spoke of "the config"; this function is
    presumably meant to check the request payload instead — confirm.
    """
    pass

def prepare_request():
    """Prepare a request that is given in the extended OpenAI format.

    Placeholder: not implemented yet. Takes no arguments, does nothing and
    returns ``None``.

    NOTE(review): the original sentence was left unfinished; presumably the
    request is to be converted into the target provider's format — confirm.
    """
    pass

def parse_response():
    """Convert the response from the provider back into the OpenAI format.

    Placeholder: not implemented yet. Takes no arguments, does nothing and
    returns ``None``.
    """
    pass



6 changes: 4 additions & 2 deletions agents-api/agents_api/models/agent/create_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@

from ..instructions.create_instructions import create_instructions_query

from ...model_registry import ALL_AVAILABLE_MODELS

def create_agent_query(
agent_id: UUID,
developer_id: UUID,
name: str,
about: str,
model: str,
alt-glitch marked this conversation as resolved.
Show resolved Hide resolved
instructions: list[Instruction] = [],
model: str = "julep-ai/samantha-1-turbo",
metadata: dict = {},
default_settings: dict = {},
):
assert model in ["julep-ai/samantha-1", "julep-ai/samantha-1-turbo"]
# assert model in ["julep-ai/samantha-1", "julep-ai/samantha-1-turbo", "gpt-4"]
assert model in ALL_AVAILABLE_MODELS.keys()
alt-glitch marked this conversation as resolved.
Show resolved Hide resolved

settings_cols, settings_vals = cozo_process_mutate_data(
{
Expand Down
35 changes: 19 additions & 16 deletions agents-api/agents_api/routers/agents/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from agents_api.clients.embed import embed
from agents_api.common.utils.datetime import utcnow
from agents_api.common.exceptions.agents import (
AgentModelNotValid,
AgentNotFoundError,
AgentToolNotFoundError,
AgentDocNotFoundError,
Expand Down Expand Up @@ -119,7 +120,7 @@ async def update_agent(
).model_dump(),
name=request.name,
about=request.about,
model=request.model or "julep-ai/samantha-1-turbo",
model=request.model,
metadata=request.metadata,
instructions=request.instructions,
)
Expand Down Expand Up @@ -211,21 +212,23 @@ async def create_agent(
request: CreateAgentRequest,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
) -> ResourceCreatedResponse:
resp = client.run(
create_agent_query(
agent_id=uuid4(),
developer_id=x_developer_id,
name=request.name,
about=request.about,
instructions=request.instructions,
model=request.model,
default_settings=(
request.default_settings or AgentDefaultSettings()
).model_dump(),
metadata=request.metadata or {},
),
)

try:
resp = client.run(
create_agent_query(
agent_id=uuid4(),
developer_id=x_developer_id,
name=request.name,
about=request.about,
instructions=request.instructions,
model=request.model,
default_settings=(
request.default_settings or AgentDefaultSettings()
).model_dump(),
metadata=request.metadata or {},
),
)
except AssertionError as e:
raise AgentModelNotValid(request.model)
alt-glitch marked this conversation as resolved.
Show resolved Hide resolved
new_agent_id = resp["agent_id"][0]
res = ResourceCreatedResponse(
id=new_agent_id,
Expand Down
20 changes: 11 additions & 9 deletions agents-api/agents_api/routers/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ async def forward(

# FIXME: This sometimes returns "The model `` does not exist."
if session_data is not None:
settings.model = session_data.model or "julep-ai/samantha-1-turbo"
settings.model = session_data.model

# Add tools to settings
if tools:
Expand All @@ -210,21 +210,23 @@ async def generate(
tools = None
if settings.tools:
tools = [tool.model_dump(mode="json") for tool in settings.tools]
from pprint import pprint
pprint(openai_client.base_url)
return await openai_client.chat.completions.create(
model=settings.model,
messages=init_context,
max_tokens=settings.max_tokens,
stop=settings.stop,
temperature=settings.temperature,
frequency_penalty=settings.frequency_penalty,
extra_body=dict(
repetition_penalty=settings.repetition_penalty,
best_of=1,
top_k=1,
length_penalty=settings.length_penalty,
logit_bias=settings.logit_bias,
preset=settings.preset.name if settings.preset else None,
),
# extra_body=dict(
alt-glitch marked this conversation as resolved.
Show resolved Hide resolved
# repetition_penalty=settings.repetition_penalty,
# best_of=1,
# top_k=1,
# length_penalty=settings.length_penalty,
# logit_bias=settings.logit_bias,
# preset=settings.preset.name if settings.preset else None,
# ),
top_p=settings.top_p,
presence_penalty=settings.presence_penalty,
stream=settings.stream,
Expand Down
2 changes: 0 additions & 2 deletions gateway/traefik.yml.template
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,6 @@ http:
X-Developer-Id: sub
X-Developer-Email: email
OpaHttpStatusField: allow_status_code
KeysWhitelist:
-


services:
Expand Down