Skip to content

Commit

Permalink
Convention based loading of Provider modules. (#39)
Browse files Browse the repository at this point in the history
* Convention based loading of Provider modules.

Loads the Provider class based on below
convention.
Eg:
  For "aws:model-name",
  1) look for providers/aws_provider.py
  2) load AwsProvider class from above file.

This allows convention based addition of
new providers.
  • Loading branch information
rohitprasad15 authored Oct 3, 2024
1 parent ad46773 commit f2f05a5
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 118 deletions.
1 change: 0 additions & 1 deletion aisuite/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from .client import Client
from .provider import ProviderNames
48 changes: 21 additions & 27 deletions aisuite/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .provider import ProviderFactory, ProviderNames
from .provider import ProviderFactory


class Client:
Expand All @@ -9,11 +9,11 @@ def __init__(self, provider_configs: dict = {}):
Args:
provider_configs (dict): A dictionary containing provider configurations.
Each key should be a ProviderNames enum or its string representation,
Each key should be a provider string (e.g., "google" or "aws-bedrock"),
and the value should be a dictionary of configuration options for that provider.
For example:
{
ProviderNames.OPENAI: {"api_key": "your_openai_api_key"},
"openai": {"api_key": "your_openai_api_key"},
"aws-bedrock": {
"aws_access_key": "your_aws_access_key",
"aws_secret_key": "your_aws_secret_key",
Expand All @@ -30,26 +30,23 @@ def _initialize_providers(self):
"""Helper method to initialize or update providers."""
for provider_key, config in self.provider_configs.items():
provider_key = self._validate_provider_key(provider_key)
self.providers[provider_key.value] = ProviderFactory.create_provider(
self.providers[provider_key] = ProviderFactory.create_provider(
provider_key, config
)

def _validate_provider_key(self, provider_key):
"""
Validate if the provider key is part of ProviderNames enum.
Allow strings as well and convert them to ProviderNames.
Validate if the provider key corresponds to a supported provider.
"""
if isinstance(provider_key, str):
if provider_key not in ProviderNames._value2member_map_:
raise ValueError(f"Provider {provider_key} is not a valid provider")
return ProviderNames(provider_key)
supported_providers = ProviderFactory.get_supported_providers()

if isinstance(provider_key, ProviderNames):
return provider_key
if provider_key not in supported_providers:
raise ValueError(
f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
"Make sure the model string is formatted correctly as 'provider:model'."
)

raise ValueError(
f"Provider {provider_key} should either be a string or enum ProviderNames"
)
return provider_key

def configure(self, provider_configs: dict = None):
"""
Expand Down Expand Up @@ -94,30 +91,27 @@ def create(self, model: str, messages: list, **kwargs):
f"Invalid model format. Expected 'provider:model', got '{model}'"
)

# Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name"
# Extract the provider key from the model identifier, e.g., "google:gemini-xx"
provider_key, model_name = model.split(":", 1)

if provider_key not in ProviderNames._value2member_map_:
# If the provider key does not match, give a clearer message to guide the user
valid_providers = ", ".join([p.value for p in ProviderNames])
# Validate if the provider is supported
supported_providers = ProviderFactory.get_supported_providers()
if provider_key not in supported_providers:
raise ValueError(
f"Invalid provider key '{provider_key}'. Expected one of: {valid_providers}. "
f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
"Make sure the model string is formatted correctly as 'provider:model'."
)

# Initialize provider if not already initialized
if provider_key not in self.client.providers:
config = {}
if provider_key in self.client.provider_configs:
config = self.client.provider_configs[provider_key]
config = self.client.provider_configs.get(provider_key, {})
self.client.providers[provider_key] = ProviderFactory.create_provider(
ProviderNames(provider_key), config
provider_key, config
)

provider = self.client.providers.get(provider_key)
if not provider:
raise ValueError(f"Could not load provider for {provider_key}.")
raise ValueError(f"Could not load provider for '{provider_key}'.")

# Delegate the chat completion to the correct provider's implementation
# Any additional arguments will be passed to the provider's implementation.
# Eg: max_tokens, temperature, etc.
return provider.chat_completions_create(model_name, messages, **kwargs)
81 changes: 20 additions & 61 deletions aisuite/provider.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
import importlib
import os
import functools


class LLMError(Exception):
Expand All @@ -17,78 +19,35 @@ def chat_completions_create(self, model, messages):
pass


class ProviderNames(str, Enum):
ANTHROPIC = "anthropic"
AWS = "aws"
AZURE = "azure"
FIREWORKS = "fireworks"
GROQ = "groq"
GOOGLE = "google"
HUGGINGFACE = "huggingface"
MISTRAL = "mistral"
OLLAMA = "ollama"
OPENAI = "openai"
TOGETHER = "together"


class ProviderFactory:
"""Factory to register and create provider instances based on keys."""
"""Factory to dynamically load provider instances based on naming conventions."""

_provider_info = {
ProviderNames.ANTHROPIC: (
"aisuite.providers.anthropic_provider",
"AnthropicProvider",
),
ProviderNames.AWS: (
"aisuite.providers.aws_bedrock_provider",
"AWSBedrockProvider",
),
ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"),
ProviderNames.GOOGLE: ("aisuite.providers.google_provider", "GoogleProvider"),
ProviderNames.GROQ: ("aisuite.providers.groq_provider", "GroqProvider"),
ProviderNames.HUGGINGFACE: (
"aisuite.providers.huggingface_provider",
"HuggingFaceProvider",
),
ProviderNames.MISTRAL: (
"aisuite.providers.mistral_provider",
"MistralProvider",
),
ProviderNames.OLLAMA: ("aisuite.providers.ollama_provider", "OllamaProvider"),
ProviderNames.OPENAI: ("aisuite.providers.openai_provider", "OpenAIProvider"),
ProviderNames.FIREWORKS: (
"aisuite.providers.fireworks_provider",
"FireworksProvider",
),
ProviderNames.TOGETHER: (
"aisuite.providers.together_provider",
"TogetherProvider",
),
}
PROVIDERS_DIR = Path(__file__).parent / "providers"

@classmethod
def create_provider(cls, provider_key, config):
"""Dynamically import and create an instance of a provider based on the provider key."""
if not isinstance(provider_key, ProviderNames):
raise ValueError(
f"Provider {provider_key} is not a valid ProviderNames enum"
)
"""Dynamically load and create an instance of a provider based on the naming convention."""
# Convert provider_key to the expected module and class names
provider_class_name = f"{provider_key.capitalize()}Provider"
provider_module_name = f"{provider_key}_provider"

module_name, class_name = cls._get_provider_info(provider_key)
if not module_name:
raise ValueError(f"Provider {provider_key.value} is not supported")
module_path = f"aisuite.providers.{provider_module_name}"

# Lazily load the module
try:
module = importlib.import_module(module_name)
module = importlib.import_module(module_path)
except ImportError as e:
raise ImportError(f"Could not import module {module_name}: {str(e)}")
raise ImportError(
f"Could not import module {module_path}: {str(e)}. Please ensure the provider is supported by doing ProviderFactory.get_supported_providers()"
)

# Instantiate the provider class
provider_class = getattr(module, class_name)
provider_class = getattr(module, provider_class_name)
return provider_class(**config)

@classmethod
def _get_provider_info(cls, provider_key):
"""Return the module name and class name for a given provider key."""
return cls._provider_info.get(provider_key, (None, None))
@functools.cache
def get_supported_providers(cls):
"""List all supported provider names based on files present in the providers directory."""
provider_files = Path(cls.PROVIDERS_DIR).glob("*_provider.py")
return {file.stem.replace("_provider", "") for file in provider_files}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aisuite.framework import ChatCompletionResponse


class AWSBedrockProvider(Provider):
class AwsProvider(Provider):
def __init__(self, **config):
"""
Initialize the AWS Bedrock provider with the given configuration.
Expand Down
2 changes: 1 addition & 1 deletion aisuite/providers/openai_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from aisuite.provider import Provider, LLMError


class OpenAIProvider(Provider):
class OpenaiProvider(Provider):
def __init__(self, **config):
"""
Initialize the OpenAI provider with the given configuration.
Expand Down
50 changes: 23 additions & 27 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import unittest
from unittest.mock import patch
from aisuite import Client
from aisuite import ProviderNames


class TestClient(unittest.TestCase):
@patch("aisuite.providers.mistral_provider.MistralProvider.chat_completions_create")
@patch("aisuite.providers.groq_provider.GroqProvider.chat_completions_create")
@patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
@patch(
"aisuite.providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create"
)
@patch("aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create")
@patch("aisuite.providers.aws_provider.AwsProvider.chat_completions_create")
@patch("aisuite.providers.azure_provider.AzureProvider.chat_completions_create")
@patch(
"aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create"
Expand Down Expand Up @@ -42,29 +39,29 @@ def test_client_chat_completions(

# Provider configurations
provider_configs = {
ProviderNames.OPENAI: {"api_key": "test_openai_api_key"},
ProviderNames.AWS: {
"openai": {"api_key": "test_openai_api_key"},
"aws": {
"aws_access_key": "test_aws_access_key",
"aws_secret_key": "test_aws_secret_key",
"aws_session_token": "test_aws_session_token",
"aws_region": "us-west-2",
},
ProviderNames.AZURE: {
"azure": {
"api_key": "azure-api-key",
"base_url": "https://model.ai.azure.com",
},
ProviderNames.GROQ: {
"groq": {
"api_key": "groq-api-key",
},
ProviderNames.MISTRAL: {
"mistral": {
"api_key": "mistral-api-key",
},
ProviderNames.GOOGLE: {
"google": {
"project_id": "test_google_project_id",
"region": "us-west4",
"application_credentials": "test_google_application_credentials",
},
ProviderNames.FIREWORKS: {
"fireworks": {
"api_key": "fireworks-api-key",
},
}
Expand All @@ -78,59 +75,59 @@ def test_client_chat_completions(
]

# Test OpenAI model
open_ai_model = ProviderNames.OPENAI + ":" + "gpt-4o"
open_ai_model = "openai" + ":" + "gpt-4o"
openai_response = client.chat.completions.create(
open_ai_model, messages=messages
)
self.assertEqual(openai_response, "OpenAI Response")
mock_openai.assert_called_once()

# Test AWS Bedrock model
bedrock_model = ProviderNames.AWS + ":" + "claude-v3"
bedrock_model = "aws" + ":" + "claude-v3"
bedrock_response = client.chat.completions.create(
bedrock_model, messages=messages
)
self.assertEqual(bedrock_response, "AWS Bedrock Response")
mock_bedrock.assert_called_once()

# Test Azure model
azure_model = ProviderNames.AZURE + ":" + "azure-model"
azure_model = "azure" + ":" + "azure-model"
azure_response = client.chat.completions.create(azure_model, messages=messages)
self.assertEqual(azure_response, "Azure Response")
mock_azure.assert_called_once()

# Test Anthropic model
anthropic_model = ProviderNames.ANTHROPIC + ":" + "anthropic-model"
anthropic_model = "anthropic" + ":" + "anthropic-model"
anthropic_response = client.chat.completions.create(
anthropic_model, messages=messages
)
self.assertEqual(anthropic_response, "Anthropic Response")
mock_anthropic.assert_called_once()

# Test Groq model
groq_model = ProviderNames.GROQ + ":" + "groq-model"
groq_model = "groq" + ":" + "groq-model"
groq_response = client.chat.completions.create(groq_model, messages=messages)
self.assertEqual(groq_response, "Groq Response")
mock_groq.assert_called_once()

# Test Mistral model
mistral_model = ProviderNames.MISTRAL + ":" + "mistral-model"
mistral_model = "mistral" + ":" + "mistral-model"
mistral_response = client.chat.completions.create(
mistral_model, messages=messages
)
self.assertEqual(mistral_response, "Mistral Response")
mock_mistral.assert_called_once()

# Test Google model
google_model = ProviderNames.GOOGLE + ":" + "google-model"
google_model = "google" + ":" + "google-model"
google_response = client.chat.completions.create(
google_model, messages=messages
)
self.assertEqual(google_response, "Google Response")
mock_google.assert_called_once()

# Test Fireworks model
fireworks_model = ProviderNames.FIREWORKS + ":" + "fireworks-model"
fireworks_model = "fireworks" + ":" + "fireworks-model"
fireworks_response = client.chat.completions.create(
fireworks_model, messages=messages
)
Expand All @@ -142,11 +139,10 @@ def test_client_chat_completions(
next_compl_instance = client.chat.completions
assert compl_instance is next_compl_instance

@patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
def test_invalid_provider_in_client_config(self, mock_openai):
def test_invalid_provider_in_client_config(self):
# Testing an invalid provider name in the configuration
invalid_provider_configs = {
"INVALID_PROVIDER": {"api_key": "invalid_api_key"},
"invalid_provider": {"api_key": "invalid_api_key"},
}

# Expect ValueError when initializing Client with invalid provider
Expand All @@ -155,19 +151,19 @@ def test_invalid_provider_in_client_config(self, mock_openai):

# Verify the error message
self.assertIn(
"Provider INVALID_PROVIDER is not a valid provider",
"Invalid provider key 'invalid_provider'. Supported providers: ",
str(context.exception),
)

@patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
@patch("aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create")
def test_invalid_model_format_in_create(self, mock_openai):
# Valid provider configurations
provider_configs = {
ProviderNames.OPENAI: {"api_key": "test_openai_api_key"},
"openai": {"api_key": "test_openai_api_key"},
}

# Initialize the client with valid provider
client = Client(provider_configs)
client = Client()
client.configure(provider_configs)

messages = [
Expand Down

0 comments on commit f2f05a5

Please sign in to comment.