Skip to content

Commit

Permalink
refactor(agents-api): Add more types
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 17, 2024
1 parent 2fbd877 commit 010660b
Show file tree
Hide file tree
Showing 93 changed files with 369 additions and 99 deletions.
4 changes: 3 additions & 1 deletion agents-api/agents_api/activities/demo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable

from temporalio import activity

from ..env import testing
Expand All @@ -12,6 +14,6 @@ async def mock_demo_activity(a: int, b: int) -> int:
return a + b


demo_activity = activity.defn(name="demo_activity")(
demo_activity: Callable[[int, int], int] = activity.defn(name="demo_activity")(
demo_activity if not testing else mock_demo_activity
)
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/embed_docs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from beartype import beartype
from temporalio import activity

from ..clients import cozo
from ..clients import embed as embedder
from ..clients.cozo import get_cozo_client
from ..env import testing
from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query
from .types import EmbedDocsPayload
Expand All @@ -28,7 +28,7 @@ async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
doc_id=payload.doc_id,
snippet_indices=indices,
embeddings=embeddings,
client=cozo_client or get_cozo_client(),
client=cozo_client or cozo.get_cozo_client(),
)


Expand Down
7 changes: 4 additions & 3 deletions agents-api/agents_api/activities/logger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
from typing import TextIO

logger = logging.getLogger(__name__)
h = logging.StreamHandler()
fmt = logging.Formatter("[%(asctime)s/%(levelname)s] - %(message)s")
logger: logging.Logger = logging.getLogger(__name__)
h: logging.StreamHandler[TextIO] = logging.StreamHandler()
fmt: logging.Formatter = logging.Formatter("[%(asctime)s/%(levelname)s] - %(message)s")
h.setFormatter(fmt)
logger.addHandler(h)
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@


# TODO: remove stubs
def entries_summarization_query(*args, **kwargs):
def entries_summarization_query(*args, **kwargs) -> pd.DataFrame:
return pd.DataFrame()


def get_toplevel_entries_query(*args, **kwargs):
def get_toplevel_entries_query(*args, **kwargs) -> pd.DataFrame:
return pd.DataFrame()


Expand Down
6 changes: 4 additions & 2 deletions agents-api/agents_api/activities/task_steps/yield_step.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable

from beartype import beartype
from temporalio import activity

Expand Down Expand Up @@ -33,8 +35,8 @@ async def yield_step(context: StepContext[YieldStep]) -> StepOutcome:

# Note: This is here just for clarity. We could have just imported yield_step directly
# They do the same thing, so we dont need to mock the yield_step function
mock_yield_step = yield_step
mock_yield_step: Callable[[StepContext], StepOutcome] = yield_step

yield_step = activity.defn(name="yield_step")(
yield_step: Callable[[StepContext], StepOutcome] = activity.defn(name="yield_step")(
yield_step if not testing else mock_yield_step
)
6 changes: 4 additions & 2 deletions agents-api/agents_api/clients/cozo.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import Any, Dict

from pycozo.client import Client

from ..env import cozo_auth, cozo_host
from ..web import app

options = {"host": cozo_host}
options: Dict[str, str] = {"host": cozo_host}
if cozo_auth:
options.update({"auth": cozo_auth})


def get_cozo_client():
def get_cozo_client() -> Any:
client = getattr(app.state, "cozo_client", Client("http", options=options))
if not hasattr(app.state, "cozo_client"):
app.state.cozo_client = client
Expand Down
5 changes: 4 additions & 1 deletion agents-api/agents_api/clients/litellm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from functools import wraps
from typing import List, TypeVar

from litellm import acompletion as _acompletion
from litellm.utils import CustomStreamWrapper, ModelResponse

from ..env import litellm_master_key, litellm_url

__all__ = ["acompletion"]
_RWrapped = TypeVar("_RWrapped")

__all__: List[str] = ["acompletion"]


@wraps(_acompletion)
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/common/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@


class BaseCommonException(Exception):
def __init__(self, msg: str, http_code: int):
def __init__(self, msg: str, http_code: int) -> None:
super().__init__(msg)
self.http_code = http_code
4 changes: 2 additions & 2 deletions agents-api/agents_api/common/utils/cozo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pycozo import Client

# Define a mock client for testing purposes, simulating Cozo API client behavior.
_fake_client = SimpleNamespace()
_fake_client: SimpleNamespace = SimpleNamespace()
# Lambda function to process and mutate data dictionaries using the Cozo client's internal method. This is a workaround to access protected member functions for testing.
_fake_client._process_mutate_data_dict = lambda data: (
Client._process_mutate_data_dict(_fake_client, data)
Expand All @@ -20,5 +20,5 @@
)


def uuid_int_list_to_uuid4(data):
def uuid_int_list_to_uuid4(data) -> UUID:
return UUID(bytes=b"".join([i.to_bytes(1, "big") for i in data]))
6 changes: 3 additions & 3 deletions agents-api/agents_api/common/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class CustomJSONEncoder(json.JSONEncoder):
"""A custom JSON encoder subclass that handles None values and UUIDs for JSON serialization. It allows specifying a default value for None objects during initialization."""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
"""Initializes the custom JSON encoder.
Parameters:
*args: Variable length argument list.
Expand All @@ -19,15 +19,15 @@ def __init__(self, *args, **kwargs):
self._default_empty_value = kwargs.pop("default_empty_value")
super().__init__(*args, **kwargs)

def encode(self, o):
def encode(self, o) -> str:
"""Encodes the given object into a JSON formatted string.
Parameters:
o: The object to encode.
Returns: A JSON formatted string representing 'o'."""
# Use the overridden default method for serialization before encoding
return super().encode(self.default(o))

def default(self, obj):
def default(self, obj) -> Any:
"""Provides a default serialization for objects that the standard JSON encoder cannot serialize.
Parameters:
obj: The object to serialize.
Expand Down
6 changes: 4 additions & 2 deletions agents-api/agents_api/common/utils/template.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import List

import arrow
from jinja2.sandbox import ImmutableSandboxedEnvironment
from jinja2schema import infer, to_json_schema
from jsonschema import validate

__all__ = [
__all__: List[str] = [
"render_template",
]

# jinja environment
jinja_env = ImmutableSandboxedEnvironment(
jinja_env: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
autoescape=False,
trim_blocks=True,
lstrip_blocks=True,
Expand Down
9 changes: 7 additions & 2 deletions agents-api/agents_api/dependencies/auth.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from typing import Any

from fastapi import HTTPException, Security
from fastapi.security.api_key import APIKeyHeader
from starlette.status import HTTP_403_FORBIDDEN

from ..env import api_key, api_key_header_name

api_key_header = APIKeyHeader(name=api_key_header_name, auto_error=False)
api_key_header: Any = APIKeyHeader(name=api_key_header_name, auto_error=False)


async def get_api_key(user_api_key: str = Security(api_key_header)):
async def get_api_key(
user_api_key: str = Security(api_key_header),
) -> str:
user_api_key = str(user_api_key)
user_api_key = (user_api_key or "").replace("Bearer ", "").strip()

if user_api_key != api_key:
Expand Down
11 changes: 6 additions & 5 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

import random
from pprint import pprint
from typing import Any, Dict

from environs import Env

# Initialize the Env object for environment variable parsing.
env = Env()
env: Any = Env()


# Debug
Expand All @@ -30,7 +31,7 @@

# Auth
# ----
_random_generated_key = "".join(str(random.randint(0, 9)) for _ in range(32))
_random_generated_key: str = "".join(str(random.randint(0, 9)) for _ in range(32))
api_key: str = env.str("AGENTS_API_KEY", _random_generated_key)

if api_key == _random_generated_key:
Expand Down Expand Up @@ -65,12 +66,12 @@
temporal_namespace: str = env.str("TEMPORAL_NAMESPACE", default="default")
temporal_client_cert: str = env.str("TEMPORAL_CLIENT_CERT", default=None)
temporal_private_key: str = env.str("TEMPORAL_PRIVATE_KEY", default=None)
temporal_endpoint = env.str("TEMPORAL_ENDPOINT", default="localhost:7233")
temporal_task_queue = env.str("TEMPORAL_TASK_QUEUE", default="julep-task-queue")
temporal_endpoint: Any = env.str("TEMPORAL_ENDPOINT", default="localhost:7233")
temporal_task_queue: Any = env.str("TEMPORAL_TASK_QUEUE", default="julep-task-queue")


# Consolidate environment variables
environment = dict(
environment: Dict[str, Any] = dict(
debug=debug,
cozo_host=cozo_host,
cozo_auth=cozo_auth,
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@ class AgentsBaseException(Exception):


class ModelNotSupportedError(AgentsBaseException):
def __init__(self, model_name):
def __init__(self, model_name) -> None:
super().__init__(f"model {model_name} is not supported")


class PromptTooBigError(AgentsBaseException):
def __init__(self, token_count, max_tokens):
def __init__(self, token_count, max_tokens) -> None:
super().__init__(
f"prompt is too big, {token_count} tokens provided, exceeds maximum of {max_tokens}"
)


class UnknownTokenizerError(AgentsBaseException):
def __init__(self):
def __init__(self) -> None:
super().__init__("unknown tokenizer")
17 changes: 11 additions & 6 deletions agents-api/agents_api/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
}


DISCONTINUED_MODELS = {
DISCONTINUED_MODELS: Dict[str, int] = {
"code-davinci-002": 8001,
"code-davinci-001": 8001,
"code-cushman-002": 2048,
Expand All @@ -84,9 +84,14 @@
"claude-3-haiku-20240307": 180000,
}

OPENAI_MODELS = {**GPT4_MODELS, **TURBO_MODELS, **GPT3_5_MODELS, **GPT3_MODELS}
OPENAI_MODELS: Dict[str, int] = {
**GPT4_MODELS,
**TURBO_MODELS,
**GPT3_5_MODELS,
**GPT3_MODELS,
}

LOCAL_MODELS = {
LOCAL_MODELS: Dict[str, int] = {
"gpt-4o": 32768,
"gpt-4o-awq": 32768,
"TinyLlama/TinyLlama_v1.1": 2048,
Expand All @@ -95,13 +100,13 @@
"OpenPipe/Hermes-2-Theta-Llama-3-8B-32k": 32768,
}

LOCAL_MODELS_WITH_TOOL_CALLS = {
LOCAL_MODELS_WITH_TOOL_CALLS: Dict[str, int] = {
"OpenPipe/Hermes-2-Theta-Llama-3-8B-32k": 32768,
"julep-ai/Hermes-2-Theta-Llama-3-8B": 8192,
}

OLLAMA_MODELS = {
OLLAMA_MODELS: Dict[str, int] = {
"llama2": 4096,
}

CHAT_MODELS = {**GPT4_MODELS, **TURBO_MODELS, **CLAUDE_MODELS}
CHAT_MODELS: Dict[str, int] = {**GPT4_MODELS, **TURBO_MODELS, **CLAUDE_MODELS}
4 changes: 4 additions & 0 deletions agents-api/agents_api/models/agent/create_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
It includes functions to construct and execute datalog queries for inserting new agent records.
"""

from typing import Any, TypeVar
from uuid import UUID, uuid4

from beartype import beartype
Expand All @@ -20,6 +21,9 @@
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")


@rewrap_exceptions(
{
Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/models/agent/create_or_update_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
It includes functions to construct and execute datalog queries for inserting new agent records.
"""

from typing import Any, TypeVar
from uuid import UUID

from beartype import beartype
Expand All @@ -20,6 +21,9 @@
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")


@rewrap_exceptions(
{
Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/models/agent/delete_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
This module contains the implementation of the delete_agent_query function, which is responsible for deleting an agent and its related default settings from the CozoDB database.
"""

from typing import Any, TypeVar
from uuid import UUID

from beartype import beartype
Expand All @@ -20,6 +21,9 @@
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")


@rewrap_exceptions(
{
Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/models/agent/get_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any, TypeVar
from uuid import UUID

from beartype import beartype
Expand All @@ -15,6 +16,9 @@
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")


@rewrap_exceptions(
{
Expand Down
5 changes: 4 additions & 1 deletion agents-api/agents_api/models/agent/list_agents.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Literal
from typing import Any, Literal, TypeVar
from uuid import UUID

from beartype import beartype
Expand All @@ -16,6 +16,9 @@
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")


@rewrap_exceptions(
{
Expand Down
Loading

0 comments on commit 010660b

Please sign in to comment.