Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring client and providers. #27

Merged
merged 5 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea/
.vscode/
__pycache__/
env/
1 change: 1 addition & 0 deletions aisuite/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .client import Client
from .provider import ProviderNames
123 changes: 123 additions & 0 deletions aisuite/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from .provider import ProviderFactory, ProviderNames


class Client:
def __init__(self, provider_configs: dict = {}):
"""
Initialize the client with provider configurations.
Use the ProviderFactory to create provider instances.

Args:
provider_configs (dict): A dictionary containing provider configurations.
Each key should be a ProviderNames enum or its string representation,
and the value should be a dictionary of configuration options for that provider.
For example:
{
ProviderNames.OPENAI: {"api_key": "your_openai_api_key"},
"aws-bedrock": {
"aws_access_key": "your_aws_access_key",
"aws_secret_key": "your_aws_secret_key",
"aws_region": "us-west-2"
}
}
"""
self.providers = {}
self.provider_configs = provider_configs
self._chat = None
self._initialize_providers()

def _initialize_providers(self):
"""Helper method to initialize or update providers."""
for provider_key, config in self.provider_configs.items():
provider_key = self._validate_provider_key(provider_key)
self.providers[provider_key.value] = ProviderFactory.create_provider(
provider_key, config
)

def _validate_provider_key(self, provider_key):
"""
Validate if the provider key is part of ProviderNames enum.
Allow strings as well and convert them to ProviderNames.
"""
if isinstance(provider_key, str):
if provider_key not in ProviderNames._value2member_map_:
raise ValueError(f"Provider {provider_key} is not a valid provider")
return ProviderNames(provider_key)

if isinstance(provider_key, ProviderNames):
return provider_key

raise ValueError(
f"Provider {provider_key} should either be a string or enum ProviderNames"
)

def configure(self, provider_configs: dict = None):
"""
Configure the client with provider configurations.
"""
if provider_configs is None:
return

self.provider_configs.update(provider_configs)
self._initialize_providers() # NOTE: This will override existing provider instances.

@property
def chat(self):
"""Return the chat API interface."""
if not self._chat:
self._chat = Chat(self)
return self._chat


class Chat:
def __init__(self, client: "Client"):
self.client = client
self._completions = Completions(self.client)

@property
def completions(self):
"""Return the completions interface."""
return self._completions


class Completions:
def __init__(self, client: "Client"):
self.client = client

def create(self, model: str, messages: list, **kwargs):
"""
Create chat completion based on the model, messages, and any extra arguments.
"""
# Check that correct format is used
if ":" not in model:
raise ValueError(
f"Invalid model format. Expected 'provider:model', got '{model}'"
)
ksolo marked this conversation as resolved.
Show resolved Hide resolved

# Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name"
provider_key, model_name = model.split(":", 1)

if provider_key not in ProviderNames._value2member_map_:
# If the provider key does not match, give a clearer message to guide the user
valid_providers = ", ".join([p.value for p in ProviderNames])
raise ValueError(
f"Invalid provider key '{provider_key}'. Expected one of: {valid_providers}. "
"Make sure the model string is formatted correctly as 'provider:model'."
)

if provider_key not in self.client.providers:
config = {}
if provider_key in self.client.provider_configs:
config = self.client.provider_configs[provider_key]
self.client.providers[provider_key] = ProviderFactory.create_provider(
ProviderNames(provider_key), config
)
ksolo marked this conversation as resolved.
Show resolved Hide resolved

provider = self.client.providers.get(provider_key)
if not provider:
raise ValueError(f"Could not load provider for {provider_key}.")

# Delegate the chat completion to the correct provider's implementation
# Any additional arguments will be passed to the provider's implementation.
# Eg: max_tokens, temperature, etc.
return provider.chat_completions_create(model_name, messages, **kwargs)
3 changes: 0 additions & 3 deletions aisuite/client/__init__.py

This file was deleted.

18 changes: 0 additions & 18 deletions aisuite/client/chat.py

This file was deleted.

90 changes: 0 additions & 90 deletions aisuite/client/client.py

This file was deleted.

37 changes: 0 additions & 37 deletions aisuite/client/completions.py

This file was deleted.

2 changes: 0 additions & 2 deletions aisuite/framework/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
"""Provides the ProviderInterface for defining the interface that all FM providers must implement."""

from .provider_interface import ProviderInterface
from .chat_completion_response import ChatCompletionResponse
68 changes: 68 additions & 0 deletions aisuite/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from abc import ABC, abstractmethod
from enum import Enum
import importlib


class LLMError(Exception):
"""Custom exception for LLM errors."""

def __init__(self, message):
super().__init__(message)


class Provider(ABC):
@abstractmethod
def chat_completions_create(self, model, messages):
"""Abstract method for chat completion calls, to be implemented by each provider."""
pass


class ProviderNames(str, Enum):
OPENAI = "openai"
AWS_BEDROCK = "aws-bedrock"
ANTHROPIC = "anthropic"
AZURE = "azure"


class ProviderFactory:
"""Factory to register and create provider instances based on keys."""

_provider_info = {
ProviderNames.OPENAI: ("aisuite.providers.openai_provider", "OpenAIProvider"),
ProviderNames.AWS_BEDROCK: (
"aisuite.providers.aws_bedrock_provider",
"AWSBedrockProvider",
),
ProviderNames.ANTHROPIC: (
"aisuite.providers.anthropic_provider",
"AnthropicProvider",
),
ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"),
}

@classmethod
def create_provider(cls, provider_key, config):
"""Dynamically import and create an instance of a provider based on the provider key."""
if not isinstance(provider_key, ProviderNames):
raise ValueError(
f"Provider {provider_key} is not a valid ProviderNames enum"
)

module_name, class_name = cls._get_provider_info(provider_key)
if not module_name:
raise ValueError(f"Provider {provider_key.value} is not supported")

# Lazily load the module
try:
module = importlib.import_module(module_name)
except ImportError as e:
raise ImportError(f"Could not import module {module_name}: {str(e)}")
rohitprasad15 marked this conversation as resolved.
Show resolved Hide resolved

# Instantiate the provider class
provider_class = getattr(module, class_name)
return provider_class(**config)

@classmethod
def _get_provider_info(cls, provider_key):
"""Return the module name and class name for a given provider key."""
return cls._provider_info.get(provider_key, (None, None))
40 changes: 40 additions & 0 deletions aisuite/providers/anthropic_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import anthropic
from aisuite.provider import Provider
from aisuite.framework import ChatCompletionResponse

# Define a constant for the default max_tokens value
DEFAULT_MAX_TOKENS = 4096


class AnthropicProvider(Provider):
def __init__(self, **config):
"""
Initialize the Anthropic provider with the given configuration.
Pass the entire configuration dictionary to the Anthropic client constructor.
"""

self.client = anthropic.Anthropic(**config)

def chat_completions_create(self, model, messages, **kwargs):
# Check if the fist message is a system message
if messages[0]["role"] == "system":
system_message = messages[0]["content"]
messages = messages[1:]
else:
system_message = None

# kwargs.setdefault('max_tokens', DEFAULT_MAX_TOKENS)
if "max_tokens" not in kwargs:
kwargs["max_tokens"] = DEFAULT_MAX_TOKENS

return self.normalize_response(
self.client.messages.create(
model=model, system=system_message, messages=messages, **kwargs
)
)

def normalize_response(self, response):
"""Normalize the response from the Anthropic API to match OpenAI's response format."""
normalized_response = ChatCompletionResponse()
normalized_response.choices[0].message.content = response.content[0].text
return normalized_response
Loading
Loading