Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

f/generic model support #24

Merged
merged 17 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ COZO_ROCKSDB_DIR=cozo.db
DTYPE=bfloat16
EMBEDDING_SERVICE_URL=http://text-embeddings-inference/embed
GATEWAY_PORT=80
GENERATION_AUTH_TOKEN=myauthkey
GENERATION_URL=http://model-serving:8000/v1
OPENAI_API_KEY=""
GPU_MEMORY_UTILIZATION=0.95
HF_TOKEN=""
HUGGING_FACE_HUB_TOKEN=""
Expand All @@ -21,17 +20,17 @@ GF_SECURITY_ADMIN_PASSWORD=changethis
MODEL_API_KEY=myauthkey
MODEL_API_KEY_HEADER_NAME=Authorization
MODEL_API_URL=http://model-serving:8000
MODEL_INFERENCE_URL=http://model-serving:8000/v1
MODEL_ID=BAAI/llm-embedder
MODEL_NAME=julep-ai/samantha-1-turbo
# MODEL_NAME = "julep-ai/samantha-1-turbo-awq"
MODEL_NAME = "julep-ai/samantha-1-turbo"
SKIP_CHECK_DEVELOPER_HEADERS=true
SUMMARIZATION_TOKENS_THRESHOLD=2048
TEMPERATURE_SCALING_FACTOR=0.9
TEMPERATURE_SCALING_POWER=0.9
TEMPORAL_ENDPOINT=temporal:7233
TEMPORAL_NAMESPACE=default
TEMPORAL_WORKER_URL=temporal:7233
TP_SIZE=2
TP_SIZE=1
TRUNCATE_EMBED_TEXT=true
TRAEFIK_LOG_LEVEL=DEBUG
WORKER_URL=temporal:7233
WORKER_URL=temporal:7233
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@ ngrok*
*.env
*.pyc
*/node_modules/
.devcontainer
node_modules/
package-lock.json
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are these gitignored? be careful

package.json
.aider*
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/co_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from temporalio import activity

from ..clients.openai import client as openai_client
from ..clients.model import julep_client
from .types import MemoryDensityTaskArgs


Expand Down Expand Up @@ -63,7 +63,7 @@ async def run_prompt(
) -> str:
prompt = make_prompt(MemoryDensityTaskArgs(memory=memory))

response = await openai_client.chat.completions.create(
response = await julep_client.chat.completions.create(
model=model,
messages=[
{
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/dialog_insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from temporalio import activity

from ..clients.openai import client as openai_client
from ..clients.model import julep_client
from .types import ChatML, DialogInsightsTaskArgs


Expand Down Expand Up @@ -66,7 +66,7 @@ async def run_prompt(
DialogInsightsTaskArgs(dialog=dialog, person1=person1, person2=person2)
)

response = await openai_client.chat.completions.create(
response = await julep_client.chat.completions.create(
model=model,
messages=[
{
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/mem_mgmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from temporalio import activity

from ..clients.openai import client as openai_client
from ..clients.model import julep_client
from .types import ChatML, MemoryManagementTaskArgs


Expand Down Expand Up @@ -135,7 +135,7 @@ async def run_prompt(
)
)

response = await openai_client.chat.completions.create(
response = await julep_client.chat.completions.create(
model=model,
messages=[
{
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/mem_rating.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from temporalio import activity

from ..clients.openai import client as openai_client
from ..clients.model import julep_client
from .types import MemoryRatingTaskArgs


Expand Down Expand Up @@ -47,7 +47,7 @@ async def run_prompt(
) -> str:
prompt = make_prompt(MemoryRatingTaskArgs(memory=memory))

response = await openai_client.chat.completions.create(
response = await julep_client.chat.completions.create(
model=model,
messages=[
{
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/relationship_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from temporalio import activity

from ..clients.openai import client as openai_client
from ..clients.model import julep_client
from .types import RelationshipSummaryTaskArgs


Expand Down Expand Up @@ -49,7 +49,7 @@ async def run_prompt(
)
)

response = await openai_client.chat.completions.create(
response = await julep_client.chat.completions.create(
model=model,
messages=[
{
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/salient_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from temporalio import activity

from ..clients.openai import client as openai_client
from ..clients.model import julep_client
from .types import SalientQuestionsTaskArgs


Expand Down Expand Up @@ -40,7 +40,7 @@ async def run_prompt(
) -> str:
prompt = make_prompt(SalientQuestionsTaskArgs(statements=statements, num=num))

response = await openai_client.chat.completions.create(
response = await julep_client.chat.completions.create(
model=model,
messages=[
{
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
entries_summarization_query,
)
from agents_api.common.protocol.entries import Entry
from agents_api.clients.openai import client as openai_client
from agents_api.clients.model import julep_client


example_previous_memory = """
Expand Down Expand Up @@ -130,7 +130,7 @@ async def run_prompt(
) -> str:
prompt = make_prompt(dialog, previous_memories, **kwargs)

response = await openai_client.chat.completions.create(
response = await julep_client.chat.completions.create(
model=model,
messages=[
{
Expand Down
12 changes: 12 additions & 0 deletions agents-api/agents_api/clients/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from openai import AsyncOpenAI
from ..env import model_inference_url, model_api_key, openai_api_key


openai_client = AsyncOpenAI(
api_key=openai_api_key
)

julep_client = AsyncOpenAI(
base_url=model_inference_url,
api_key=model_api_key,
)
8 changes: 0 additions & 8 deletions agents-api/agents_api/clients/openai.py

This file was deleted.

9 changes: 8 additions & 1 deletion agents-api/agents_api/common/exceptions/agents.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from uuid import UUID
from . import BaseCommonException

from agents_api.model_registry import ALL_AVAILABLE_MODELS

class BaseAgentException(BaseCommonException):
pass
Expand All @@ -26,3 +26,10 @@ def __init__(self, agent_id: UUID | str, doc_id: UUID | str):
super().__init__(
f"Doc {str(doc_id)} not found for agent {str(agent_id)}", http_code=404
)

class AgentModelNotValid(BaseAgentException):
def __init__(self, model: str):
super().__init__(
f"Unknown model: {model}. Please provide a valid model name."
"Known models are: " + ", ".join(ALL_AVAILABLE_MODELS.keys()), http_code=400
)
3 changes: 0 additions & 3 deletions agents-api/agents_api/common/protocol/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,3 @@ class AgentDefaultSettings(BaseModel):
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
min_p: float = 0.01


ModelType = Literal["julep-ai/samantha-1", "julep-ai/samantha-1-turbo"]
2 changes: 1 addition & 1 deletion agents-api/agents_api/common/protocol/entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
from typing import Literal
from uuid import UUID, uuid4

from pydantic import BaseModel, Field, computed_field, validator
from agents_api.autogen.openapi_model import Role

Expand All @@ -21,6 +20,7 @@ class Entry(BaseModel):
created_at: float = Field(default_factory=lambda: datetime.utcnow().timestamp())
timestamp: float = Field(default_factory=lambda: datetime.utcnow().timestamp())


@computed_field
@property
def token_count(self) -> int:
Expand Down
17 changes: 14 additions & 3 deletions agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from uuid import UUID

from pydantic import BaseModel
from pydantic import BaseModel, validator

from .agents import ModelType, AgentDefaultSettings
from .agents import AgentDefaultSettings

from model_registry import ALL_AVAILABLE_MODELS

class SessionSettings(AgentDefaultSettings):
pass
Expand All @@ -21,5 +22,15 @@ class SessionData(BaseModel):
agent_about: str
updated_at: float
created_at: float
model: ModelType
model: str
default_settings: SessionSettings

@validator('model')
def validate_model_type(cls, model):
if model not in ALL_AVAILABLE_MODELS.keys():
raise ValueError(
f"Unknown model: {model}. Please provide a valid model name."
"Known models are: " + ", ".join(ALL_AVAILABLE_MODELS.keys())
)

return model
7 changes: 3 additions & 4 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
prediction_api_endpoint: str = env.str(
"PREDICTION_API_ENDPOINT", default="us-central1-aiplatform.googleapis.com"
)
generation_url: str = env.str("GENERATION_URL", default=None)
generation_auth_token: str = env.str("GENERATION_AUTH_TOKEN", default=None)
model_api_key: str = env.str("MODEL_API_KEY", default=None)
model_inference_url: str = env.str("MODEL_INFERENCE_URL", default=None)
openai_api_key: str = env.str("OPENAI_API_KEY", default=None)
summarization_ratio_threshold: float = env.float(
"MAX_TOKENS_RATIO_TO_SUMMARIZE", default=0.5
)
Expand Down Expand Up @@ -63,8 +64,6 @@
debug=debug,
cozo_host=cozo_host,
cozo_auth=cozo_auth,
generation_url=generation_url,
generation_auth_token=generation_auth_token,
summarization_ratio_threshold=summarization_ratio_threshold,
summarization_tokens_threshold=summarization_tokens_threshold,
worker_url=worker_url,
Expand Down
Loading