Skip to content

Commit

Permalink
feedback
Browse files Browse the repository at this point in the history
Signed-off-by: Adrian Cole <[email protected]>
  • Loading branch information
codefromthecrypt committed Nov 4, 2024
1 parent f53c503 commit e8049a4
Show file tree
Hide file tree
Showing 18 changed files with 149 additions and 128 deletions.
9 changes: 4 additions & 5 deletions packages/exchange/src/exchange/providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@
from exchange import Message, Tool
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from exchange.providers.utils import get_env_url
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status, raise_for_status
from exchange.langfuse_wrapper import observe_wrapper

ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages"

retry_procedure = retry(
wait=wait_fixed(2),
stop=stop_after_attempt(2),
Expand All @@ -24,6 +21,8 @@ class AnthropicProvider(Provider):
"""Provides chat completions for models hosted directly by Anthropic."""

PROVIDER_NAME = "anthropic"
BASE_URL_ENV_VAR = "ANTHROPIC_HOST"
BASE_URL_DEFAULT = "https://api.anthropic.com/v1/messages"
REQUIRED_ENV_VARS = ["ANTHROPIC_API_KEY"]

def __init__(self, client: httpx.Client) -> None:
Expand All @@ -32,7 +31,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["AnthropicProvider"]) -> "AnthropicProvider":
cls.check_env_vars()
url = get_env_url("ANTHROPIC_HOST", ANTHROPIC_HOST)
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
key = os.environ.get("ANTHROPIC_API_KEY")
client = httpx.Client(
base_url=url,
Expand Down Expand Up @@ -165,5 +164,5 @@ def recommended_models() -> tuple[str, str]:

@retry_procedure
def _post(self, payload: dict) -> httpx.Response:
response = self.client.post(ANTHROPIC_HOST, json=payload)
response = self.client.post(self.BASE_URL_DEFAULT, json=payload)
return raise_for_status(response).json()
5 changes: 2 additions & 3 deletions packages/exchange/src/exchange/providers/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import os

from exchange.providers import OpenAiProvider
from exchange.providers.utils import get_env_url


class AzureProvider(OpenAiProvider):
"""Provides chat completions for models hosted by the Azure OpenAI Service."""

PROVIDER_NAME = "azure"
BASE_URL_ENV_VAR = "AZURE_CHAT_COMPLETIONS_HOST_NAME"
REQUIRED_ENV_VARS = [
"AZURE_CHAT_COMPLETIONS_HOST_NAME",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION",
"AZURE_CHAT_COMPLETIONS_KEY",
Expand All @@ -22,7 +21,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["AzureProvider"]) -> "AzureProvider":
cls.check_env_vars()
url = get_env_url("AZURE_CHAT_COMPLETIONS_HOST_NAME")
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
deployment_name = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME")
api_version = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION")
key = os.environ.get("AZURE_CHAT_COMPLETIONS_KEY")
Expand Down
17 changes: 16 additions & 1 deletion packages/exchange/src/exchange/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import httpx
import os
from abc import ABC, abstractmethod
from attrs import define, field
Expand All @@ -22,6 +23,8 @@ def __init__(self, provider_cls: str) -> None:

class Provider(ABC):
PROVIDER_NAME: str
BASE_URL_ENV_VAR: str = ""
BASE_URL_DEFAULT: str = ""
REQUIRED_ENV_VARS: list[str] = []

@classmethod
Expand All @@ -32,11 +35,23 @@ def from_env(cls: type["Provider"]) -> "Provider":

@classmethod
def check_env_vars(cls: type["Provider"], instructions_url: Optional[str] = None) -> None:
provider = cls.PROVIDER_NAME
missing_vars = [x for x in cls.REQUIRED_ENV_VARS if x not in os.environ]

url_var = cls.BASE_URL_ENV_VAR
if url_var:
val = os.environ.get(url_var, cls.BASE_URL_DEFAULT)
if not val:
raise KeyError(url_var)
else:
url = httpx.URL(val)

if url.scheme not in ["http", "https"]:
raise ValueError(f"Expected {url_var} to be a 'http' or 'https' url: {val}")

if missing_vars:
env_vars = ", ".join(missing_vars)
raise MissingProviderEnvVariableError(env_vars, cls.PROVIDER_NAME, instructions_url)
raise MissingProviderEnvVariableError(env_vars, provider, instructions_url)

@abstractmethod
def complete(
Expand Down
10 changes: 4 additions & 6 deletions packages/exchange/src/exchange/providers/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from exchange.message import Message
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import raise_for_status, retry_if_status, get_env_url
from exchange.providers.utils import raise_for_status, retry_if_status
from exchange.providers.utils import (
messages_to_openai_spec,
openai_response_to_message,
Expand All @@ -31,10 +31,8 @@ class DatabricksProvider(Provider):
"""

PROVIDER_NAME = "databricks"
REQUIRED_ENV_VARS = [
"DATABRICKS_HOST",
"DATABRICKS_TOKEN",
]
BASE_URL_ENV_VAR = "DATABRICKS_HOST"
REQUIRED_ENV_VARS = ["DATABRICKS_TOKEN"]
instructions_url = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"

def __init__(self, client: httpx.Client) -> None:
Expand All @@ -43,7 +41,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["DatabricksProvider"]) -> "DatabricksProvider":
cls.check_env_vars(cls.instructions_url)
url = get_env_url("DATABRICKS_HOST")
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR))
key = os.environ.get("DATABRICKS_TOKEN")
client = httpx.Client(
base_url=url,
Expand Down
9 changes: 4 additions & 5 deletions packages/exchange/src/exchange/providers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import raise_for_status, retry_if_status, get_env_url, encode_image
from exchange.providers.utils import raise_for_status, retry_if_status, encode_image
from exchange.langfuse_wrapper import observe_wrapper


GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta"

retry_procedure = retry(
wait=wait_fixed(2),
stop=stop_after_attempt(2),
Expand All @@ -24,6 +21,8 @@ class GoogleProvider(Provider):
"""Provides chat completions for models hosted by Google, including Gemini and other experimental models."""

PROVIDER_NAME = "google"
BASE_URL_ENV_VAR = "GOOGLE_HOST"
BASE_URL_DEFAULT = "https://generativelanguage.googleapis.com/v1beta"
REQUIRED_ENV_VARS = ["GOOGLE_API_KEY"]
instructions_url = "https://ai.google.dev/gemini-api/docs/api-key"

Expand All @@ -33,7 +32,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["GoogleProvider"]) -> "GoogleProvider":
cls.check_env_vars(cls.instructions_url)
url = get_env_url("GOOGLE_HOST", GOOGLE_HOST)
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
key = os.environ.get("GOOGLE_API_KEY")
client = httpx.Client(
base_url=url,
Expand Down
7 changes: 3 additions & 4 deletions packages/exchange/src/exchange/providers/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,11 @@
openai_single_message_context_length_exceeded,
raise_for_status,
tools_to_openai_spec,
get_env_url,
)
from exchange.tool import Tool
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status

GROQ_HOST = "https://api.groq.com/openai/"

retry_procedure = retry(
wait=wait_fixed(5),
stop=stop_after_attempt(5),
Expand All @@ -31,6 +28,8 @@ class GroqProvider(Provider):
"""Provides chat completions for models hosted directly by OpenAI."""

PROVIDER_NAME = "groq"
BASE_URL_ENV_VAR = "GROQ_HOST"
BASE_URL_DEFAULT = "https://api.groq.com/openai/"
REQUIRED_ENV_VARS = ["GROQ_API_KEY"]
instructions_url = "https://console.groq.com/docs/quickstart"

Expand All @@ -40,7 +39,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["GroqProvider"]) -> "GroqProvider":
cls.check_env_vars(cls.instructions_url)
url = get_env_url("GROQ_HOST", GROQ_HOST)
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
key = os.environ.get("GROQ_API_KEY")

client = httpx.Client(
Expand Down
8 changes: 5 additions & 3 deletions packages/exchange/src/exchange/providers/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@

from typing import Type
from exchange.providers.openai import OpenAiProvider
from exchange.providers.utils import get_env_url

OLLAMA_HOST = "http://localhost:11434/"
OLLAMA_MODEL = "qwen2.5"


Expand All @@ -26,14 +24,18 @@ class OllamaProvider(OpenAiProvider):
requires: {{}}
"""
PROVIDER_NAME = "ollama"
BASE_URL_ENV_VAR = "OLLAMA_HOST"
BASE_URL_DEFAULT = "http://localhost:11434/"
REQUIRED_ENV_VARS = []

def __init__(self, client: httpx.Client) -> None:
print("PLEASE NOTE: the ollama provider is experimental, use with care")
super().__init__(client)

@classmethod
def from_env(cls: Type["OllamaProvider"]) -> "OllamaProvider":
ollama_url = get_env_url("OLLAMA_HOST", OLLAMA_HOST)
cls.check_env_vars(cls.instructions_url)
ollama_url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
timeout = httpx.Timeout(60 * 10)

# from_env is expected to fail if required ENV variables are not
Expand Down
6 changes: 3 additions & 3 deletions packages/exchange/src/exchange/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@
openai_single_message_context_length_exceeded,
raise_for_status,
tools_to_openai_spec,
get_env_url,
)
from exchange.tool import Tool
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.langfuse_wrapper import observe_wrapper

OPENAI_HOST = "https://api.openai.com/"

retry_procedure = retry(
wait=wait_fixed(2),
Expand All @@ -31,6 +29,8 @@ class OpenAiProvider(Provider):
"""Provides chat completions for models hosted directly by OpenAI."""

PROVIDER_NAME = "openai"
BASE_URL_ENV_VAR = "OPENAI_HOST"
BASE_URL_DEFAULT = "https://api.openai.com/"
REQUIRED_ENV_VARS = ["OPENAI_API_KEY"]
instructions_url = "https://platform.openai.com/docs/api-reference/api-keys"

Expand All @@ -40,7 +40,7 @@ def __init__(self, client: httpx.Client) -> None:
@classmethod
def from_env(cls: type["OpenAiProvider"]) -> "OpenAiProvider":
cls.check_env_vars(cls.instructions_url)
url = get_env_url("OPENAI_HOST", OPENAI_HOST)
url = httpx.URL(os.environ.get(cls.BASE_URL_ENV_VAR, cls.BASE_URL_DEFAULT))
key = os.environ.get("OPENAI_API_KEY")

client = httpx.Client(
Expand Down
22 changes: 0 additions & 22 deletions packages/exchange/src/exchange/providers/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import base64
import json
import os
import re
from typing import Optional

Expand All @@ -11,27 +10,6 @@
from tenacity import retry_if_exception


def get_env_url(key: str, default: str = "") -> httpx.URL:
"""
Returns a valid 'http' or 'https' URL.
:param key: The environment key
:param default: The URL default value
:raises ValueError: If the URL scheme is not 'http' or 'https'
"""

val = os.environ.get(key, default)
if val == "":
raise ValueError(f"{key} was empty")

url = httpx.URL(val)

if url.scheme not in ["http", "https"]:
raise ValueError(f"expected {key} to be a 'http' or 'https' url: {val}")

return url


def retry_if_status(codes: Optional[list[int]] = None, above: Optional[int] = None) -> callable:
codes = codes or []

Expand Down
2 changes: 1 addition & 1 deletion packages/exchange/tests/providers/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_from_env_throw_error_when_invalid_host(monkeypatch):
monkeypatch.setenv("ANTHROPIC_HOST", "localhost:1234")
monkeypatch.setenv("ANTHROPIC_API_KEY", "test_api_key")

with pytest.raises(ValueError, match="expected ANTHROPIC_HOST to be a 'http' or 'https' url: localhost:1234"):
with pytest.raises(ValueError, match="Expected ANTHROPIC_HOST to be a 'http' or 'https' url: localhost:1234"):
AnthropicProvider.from_env()


Expand Down
4 changes: 1 addition & 3 deletions packages/exchange/tests/providers/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@ def test_from_env_throw_error_when_invalid_host(monkeypatch):
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_KEY", "test_api_key")

with pytest.raises(
ValueError, match="expected AZURE_CHAT_COMPLETIONS_HOST_NAME to be a 'http' or 'https' url: localhost:1234"
ValueError, match="Expected AZURE_CHAT_COMPLETIONS_HOST_NAME to be a 'http' or 'https' url: localhost:1234"
):
AzureProvider.from_env()


@pytest.mark.parametrize(
"env_var_name",
[
"AZURE_CHAT_COMPLETIONS_HOST_NAME",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION",
"AZURE_CHAT_COMPLETIONS_KEY",
Expand All @@ -36,7 +35,6 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name):
with patch.dict(
os.environ,
{
"AZURE_CHAT_COMPLETIONS_HOST_NAME": "test_host_name",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test_deployment_name",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "test_api_version",
"AZURE_CHAT_COMPLETIONS_KEY": "test_api_key",
Expand Down
Loading

0 comments on commit e8049a4

Please sign in to comment.