f/generic model support #24

Merged
merged 17 commits on Apr 15, 2024
Changes from all commits
11 changes: 5 additions & 6 deletions .env.example
@@ -8,8 +8,7 @@ COZO_ROCKSDB_DIR=cozo.db
DTYPE=bfloat16
EMBEDDING_SERVICE_URL=http://text-embeddings-inference/embed
GATEWAY_PORT=80
GENERATION_AUTH_TOKEN=myauthkey
GENERATION_URL=http://model-serving:8000/v1
OPENAI_API_KEY=""
GPU_MEMORY_UTILIZATION=0.95
HF_TOKEN=""
HUGGING_FACE_HUB_TOKEN=""
@@ -21,17 +20,17 @@ GF_SECURITY_ADMIN_PASSWORD=changethis
MODEL_API_KEY=myauthkey
MODEL_API_KEY_HEADER_NAME=Authorization
MODEL_API_URL=http://model-serving:8000
MODEL_INFERENCE_URL=http://model-serving:8000/v1
MODEL_ID=BAAI/llm-embedder
MODEL_NAME=julep-ai/samantha-1-turbo
# MODEL_NAME = "julep-ai/samantha-1-turbo-awq"
MODEL_NAME = "julep-ai/samantha-1-turbo"
SKIP_CHECK_DEVELOPER_HEADERS=true
SUMMARIZATION_TOKENS_THRESHOLD=2048
TEMPERATURE_SCALING_FACTOR=0.9
TEMPERATURE_SCALING_POWER=0.9
TEMPORAL_ENDPOINT=temporal:7233
TEMPORAL_NAMESPACE=default
TEMPORAL_WORKER_URL=temporal:7233
TP_SIZE=2
TP_SIZE=1
TRUNCATE_EMBED_TEXT=true
TRAEFIK_LOG_LEVEL=DEBUG
WORKER_URL=temporal:7233
WORKER_URL=temporal:7233
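
The env changes above swap the old GENERATION_URL / GENERATION_AUTH_TOKEN pair for MODEL_INFERENCE_URL / MODEL_API_KEY and add OPENAI_API_KEY. A minimal sketch (not part of this PR) of a startup check that fails fast if the renamed variables are missing could look like this:

# Hypothetical smoke check, not part of this PR.
import os

required = ["MODEL_INFERENCE_URL", "MODEL_API_KEY", "OPENAI_API_KEY", "MODEL_NAME"]
missing = [name for name in required if not os.environ.get(name)]
if missing:
    raise SystemExit("Missing environment variables: " + ", ".join(missing))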
1 change: 1 addition & 0 deletions .gitignore
@@ -6,3 +6,4 @@ ngrok*
*.pyc
*/node_modules/
.aider*
.vscode/
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/__init__.py
@@ -11,4 +11,4 @@
- `summarization.py`: Summarizes dialogues and updates memory based on the conversation context.

This module plays a crucial role in enhancing the capabilities of agents by providing them with the tools to understand and process information more effectively.
"""
"""
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/co_density.py
@@ -3,7 +3,7 @@

from temporalio import activity

from ..clients.openai import client as openai_client
from ..clients.model import julep_client
from .types import MemoryDensityTaskArgs


@@ -63,7 +63,7 @@ async def run_prompt(
) -> str:
prompt = make_prompt(MemoryDensityTaskArgs(memory=memory))

response = await openai_client.chat.completions.create(
response = await julep_client.chat.completions.create(
model=model,
messages=[
{
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/dialog_insights.py
@@ -3,7 +3,7 @@

from temporalio import activity

from ..clients.openai import client as openai_client
from ..clients.model import julep_client
from .types import ChatML, DialogInsightsTaskArgs


@@ -66,7 +66,7 @@ async def run_prompt(
DialogInsightsTaskArgs(dialog=dialog, person1=person1, person2=person2)
)

response = await openai_client.chat.completions.create(
response = await julep_client.chat.completions.create(
model=model,
messages=[
{
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/mem_mgmt.py
@@ -4,7 +4,7 @@

from temporalio import activity

from ..clients.openai import client as openai_client
from ..clients.model import julep_client
from .types import ChatML, MemoryManagementTaskArgs


@@ -135,7 +135,7 @@ async def run_prompt(
)
)

response = await openai_client.chat.completions.create(
response = await julep_client.chat.completions.create(
model=model,
messages=[
{
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/mem_rating.py
@@ -3,7 +3,7 @@

from temporalio import activity

from ..clients.openai import client as openai_client
from ..clients.model import julep_client
from .types import MemoryRatingTaskArgs


@@ -47,7 +47,7 @@ async def run_prompt(
) -> str:
prompt = make_prompt(MemoryRatingTaskArgs(memory=memory))

response = await openai_client.chat.completions.create(
response = await julep_client.chat.completions.create(
model=model,
messages=[
{
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/relationship_summary.py
@@ -3,7 +3,7 @@

from temporalio import activity

from ..clients.openai import client as openai_client
from ..clients.model import julep_client
from .types import RelationshipSummaryTaskArgs


@@ -49,7 +49,7 @@ async def run_prompt(
)
)

response = await openai_client.chat.completions.create(
response = await julep_client.chat.completions.create(
model=model,
messages=[
{
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/salient_questions.py
@@ -3,7 +3,7 @@

from temporalio import activity

from ..clients.openai import client as openai_client
from ..clients.model import julep_client
from .types import SalientQuestionsTaskArgs


@@ -40,7 +40,7 @@ async def run_prompt(
) -> str:
prompt = make_prompt(SalientQuestionsTaskArgs(statements=statements, num=num))

response = await openai_client.chat.completions.create(
response = await julep_client.chat.completions.create(
model=model,
messages=[
{
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/summarization.py
@@ -10,7 +10,7 @@
entries_summarization_query,
)
from agents_api.common.protocol.entries import Entry
from agents_api.clients.openai import client as openai_client
from agents_api.clients.model import julep_client


example_previous_memory = """
@@ -130,7 +130,7 @@ async def run_prompt(
) -> str:
prompt = make_prompt(dialog, previous_memories, **kwargs)

response = await openai_client.chat.completions.create(
response = await julep_client.chat.completions.create(
model=model,
messages=[
{
2 changes: 1 addition & 1 deletion agents-api/agents_api/autogen/__init__.py
@@ -1,3 +1,3 @@
"""
This module contains automatically generated models based on the OpenAPI specification for the agents-api project. It includes definitions for key entities such as Users, Sessions, Agents, Tools, and their respective interactions. These models play a crucial role in defining the structure and constraints of data exchanged with the API endpoints, ensuring consistency and validation across the service. Generated models cover a wide range of functionalities from user management, session handling, agent configuration, to tool definitions, providing a comprehensive schema for the API's operations.
"""
"""
2 changes: 1 addition & 1 deletion agents-api/agents_api/clients/__init__.py
@@ -6,4 +6,4 @@
- `openai.py`: Facilitates interaction with OpenAI's API for natural language processing tasks.
- `temporal.py`: Provides functionality for connecting to Temporal workflows, enabling asynchronous task execution and management.
- `worker/__init__.py` and related files: Describe the role of the worker service client in sending tasks to be processed by an external worker service, focusing on memory management and other computational tasks.
"""
"""
10 changes: 10 additions & 0 deletions agents-api/agents_api/clients/model.py
@@ -0,0 +1,10 @@
from openai import AsyncOpenAI
from ..env import model_inference_url, model_api_key, openai_api_key


openai_client = AsyncOpenAI(api_key=openai_api_key)

julep_client = AsyncOpenAI(
base_url=model_inference_url,
api_key=model_api_key,
)
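
The new clients/model.py replaces clients/openai.py with two AsyncOpenAI instances: openai_client for OpenAI itself and julep_client pointed at the self-hosted inference endpoint. A minimal usage sketch mirroring the run_prompt changes in the activities above (assumes MODEL_INFERENCE_URL and MODEL_API_KEY are set and the endpoint is reachable; the model name is taken from .env.example):

# Illustrative sketch, not part of this PR.
import asyncio

from agents_api.clients.model import julep_client


async def main() -> None:
    response = await julep_client.chat.completions.create(
        model="julep-ai/samantha-1-turbo",
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
    )
    print(response.choices[0].message.content)


asyncio.run(main())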
8 changes: 0 additions & 8 deletions agents-api/agents_api/clients/openai.py

This file was deleted.

2 changes: 1 addition & 1 deletion agents-api/agents_api/clients/worker/__init__.py
@@ -1,3 +1,3 @@
"""
This module provides functionality for interacting with an external worker service. It includes utilities for creating and sending tasks, such as memory management tasks, to be processed by the service. The module leverages asynchronous HTTP requests via the `httpx` library to communicate with the worker service. Types for structuring task data are defined in `types.py`.
"""
"""
2 changes: 2 additions & 0 deletions agents-api/agents_api/common/exceptions/__init__.py
@@ -8,6 +8,8 @@

All custom exceptions extend from `BaseCommonException`, which encapsulates common attributes and behavior, including the error message and HTTP status code. This structured approach to exception handling facilitates precise and meaningful error feedback to API consumers, thereby improving the overall developer experience.
"""


class BaseCommonException(Exception):
def __init__(self, msg: str, http_code: int):
super().__init__(msg)
10 changes: 10 additions & 0 deletions agents-api/agents_api/common/exceptions/agents.py
@@ -2,6 +2,7 @@

from uuid import UUID
from . import BaseCommonException
from agents_api.model_registry import ALL_AVAILABLE_MODELS


class BaseAgentException(BaseCommonException):
@@ -39,3 +40,12 @@ def __init__(self, agent_id: UUID | str, doc_id: UUID | str):
super().__init__(
f"Doc {str(doc_id)} not found for agent {str(agent_id)}", http_code=404
)


class AgentModelNotValid(BaseAgentException):
def __init__(self, model: str):
super().__init__(
f"Unknown model: {model}. Please provide a valid model name."
"Known models are: " + ", ".join(ALL_AVAILABLE_MODELS.keys()),
http_code=400,
)
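
AgentModelNotValid maps to HTTP 400 and reuses the message wording of the session validator below. A hedged sketch of where it could be raised (the request handler itself is not part of this diff; validate_agent_model is a hypothetical helper):

# Hypothetical helper, not shown in this PR.
from agents_api.common.exceptions.agents import AgentModelNotValid
from agents_api.model_registry import ALL_AVAILABLE_MODELS


def validate_agent_model(model: str) -> str:
    if model not in ALL_AVAILABLE_MODELS:
        raise AgentModelNotValid(model)  # surfaces as a 400 response
    return model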
2 changes: 1 addition & 1 deletion agents-api/agents_api/common/protocol/__init__.py
@@ -6,4 +6,4 @@
- `SessionData`: Represents the data associated with a session, including identifiers and session-specific information such as situation, summary, and timestamps.

These components are crucial for the effective operation and interaction within the agents API.
"""
"""
5 changes: 0 additions & 5 deletions agents-api/agents_api/common/protocol/agents.py
@@ -1,5 +1,3 @@
from typing import Literal

from pydantic import BaseModel


@@ -20,6 +18,3 @@ class AgentDefaultSettings(BaseModel):
frequency_penalty: float = 0.0
"""Minimum probability threshold for including a word in the agent's response."""
min_p: float = 0.01


ModelType = Literal["julep-ai/samantha-1", "julep-ai/samantha-1-turbo"]
18 changes: 15 additions & 3 deletions agents-api/agents_api/common/protocol/sessions.py
@@ -4,9 +4,11 @@
"""
from uuid import UUID

from pydantic import BaseModel
from pydantic import BaseModel, validator

from .agents import ModelType, AgentDefaultSettings
from .agents import AgentDefaultSettings

from agents_api.model_registry import ALL_AVAILABLE_MODELS


class SessionSettings(AgentDefaultSettings):
@@ -35,5 +37,15 @@ class SessionData(BaseModel):
agent_about: str
updated_at: float
created_at: float
model: ModelType
model: str
default_settings: SessionSettings

@validator("model")
def validate_model_type(cls, model):
if model not in ALL_AVAILABLE_MODELS.keys():
raise ValueError(
f"Unknown model: {model}. Please provide a valid model name."
"Known models are: " + ", ".join(ALL_AVAILABLE_MODELS.keys())
)

return model
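
Since SessionData has more fields than this hunk shows, here is a standalone sketch of the same validator pattern with a stand-in registry, so the behaviour can be seen in isolation (KNOWN_MODELS is a placeholder for ALL_AVAILABLE_MODELS):

# Standalone illustration, not part of this PR.
from pydantic import BaseModel, ValidationError, validator

KNOWN_MODELS = {"julep-ai/samantha-1", "julep-ai/samantha-1-turbo"}  # stand-in registry


class ModelField(BaseModel):
    model: str

    @validator("model")
    def validate_model_type(cls, model):
        if model not in KNOWN_MODELS:
            raise ValueError(
                f"Unknown model: {model}. Please provide a valid model name. "
                "Known models are: " + ", ".join(KNOWN_MODELS)
            )
        return model


try:
    ModelField(model="not-a-model")
except ValidationError as exc:
    print(exc)  # prints the "Unknown model" message from the validator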
2 changes: 1 addition & 1 deletion agents-api/agents_api/common/utils/__init__.py
@@ -6,4 +6,4 @@
- `json.py`: Custom JSON utilities, including a custom JSON encoder for handling specific object types like UUIDs, and a utility function for JSON serialization with support for default values for None objects.

These utilities are essential for the internal operations of the `agents-api`, providing common functionalities that are reused across different parts of the application.
"""
"""
2 changes: 1 addition & 1 deletion agents-api/agents_api/dependencies/__init__.py
@@ -4,4 +4,4 @@
- `developer_id.py` for developer identification: Handles developer-specific headers like `X-Developer-Id` and `X-Developer-Email`, facilitating the identification of the developer making the request.
- `exceptions.py` for custom exception handling: Defines custom exceptions that are used throughout the dependencies module to handle errors related to API security and developer identification.

These components collectively ensure the security and proper operation of the agents-api by authenticating requests, identifying developers, and handling errors in a standardized manner."""
These components collectively ensure the security and proper operation of the agents-api by authenticating requests, identifying developers, and handling errors in a standardized manner."""
7 changes: 3 additions & 4 deletions agents-api/agents_api/env.py
@@ -22,8 +22,9 @@
prediction_api_endpoint: str = env.str(
"PREDICTION_API_ENDPOINT", default="us-central1-aiplatform.googleapis.com"
)
generation_url: str = env.str("GENERATION_URL", default=None)
generation_auth_token: str = env.str("GENERATION_AUTH_TOKEN", default=None)
model_api_key: str = env.str("MODEL_API_KEY", default=None)
model_inference_url: str = env.str("MODEL_INFERENCE_URL", default=None)
openai_api_key: str = env.str("OPENAI_API_KEY", default=None)
summarization_ratio_threshold: float = env.float(
"MAX_TOKENS_RATIO_TO_SUMMARIZE", default=0.5
)
@@ -63,8 +64,6 @@
debug=debug,
cozo_host=cozo_host,
cozo_auth=cozo_auth,
generation_url=generation_url,
generation_auth_token=generation_auth_token,
summarization_ratio_threshold=summarization_ratio_threshold,
summarization_tokens_threshold=summarization_tokens_threshold,
worker_url=worker_url,
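
Taken together with clients/model.py, the new settings flow is: .env supplies MODEL_INFERENCE_URL, MODEL_API_KEY, and OPENAI_API_KEY; env.py reads them; and the two clients are built from those values. A minimal end-to-end sketch, assuming the environs package that env.py already uses:

# Illustrative sketch, not part of this PR.
from environs import Env
from openai import AsyncOpenAI

env = Env()
env.read_env()  # pick up a local .env file if present

model_inference_url = env.str("MODEL_INFERENCE_URL", default=None)
model_api_key = env.str("MODEL_API_KEY", default=None)
openai_api_key = env.str("OPENAI_API_KEY", default=None)

# Same wiring as agents_api/clients/model.py: one client per backend.
openai_client = AsyncOpenAI(api_key=openai_api_key)
julep_client = AsyncOpenAI(base_url=model_inference_url, api_key=model_api_key)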