Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fixes client.py #114

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 72 additions & 57 deletions aisuite/client.py
Original file line number Diff line number Diff line change
@@ -1,117 +1,132 @@
from typing import Dict, Any, Optional
from .provider import ProviderFactory


class Client:
def __init__(self, provider_configs: dict = {}):
def __init__(self, provider_configs: Optional[Dict[str, Dict[str, Any]]] = None):
"""
Initialize the client with provider configurations.
Use the ProviderFactory to create provider instances.
Initialize the client with provider configurations using the ProviderFactory.

Args:
provider_configs (dict): A dictionary containing provider configurations.
Each key should be a provider string (e.g., "google" or "aws-bedrock"),
and the value should be a dictionary of configuration options for that provider.
For example:
{
"openai": {"api_key": "your_openai_api_key"},
"aws-bedrock": {
"aws_access_key": "your_aws_access_key",
"aws_secret_key": "your_aws_secret_key",
"aws_region": "us-west-2"
provider_configs (Optional[Dict[str, Dict[str, Any]]]):
A dictionary containing provider configurations.
Example:
{
"openai": {"api_key": "your_openai_api_key"},
"aws-bedrock": {
"aws_access_key": "your_aws_access_key",
"aws_secret_key": "your_aws_secret_key",
"aws_region": "us-west-2"
}
}
}
"""
self.provider_configs = provider_configs or {}
self.providers = {}
self.provider_configs = provider_configs
self._chat = None
self._initialize_providers()

def _initialize_providers(self):
"""Helper method to initialize or update providers."""
"""Initialize providers using the ProviderFactory."""
for provider_key, config in self.provider_configs.items():
provider_key = self._validate_provider_key(provider_key)
self._validate_provider_key(provider_key)
self.providers[provider_key] = ProviderFactory.create_provider(
provider_key, config
)

def _validate_provider_key(self, provider_key):
"""
Validate if the provider key corresponds to a supported provider.
"""
@staticmethod
def _validate_provider_key(provider_key: str):
"""Validate if the provider key corresponds to a supported provider."""
supported_providers = ProviderFactory.get_supported_providers()

if provider_key not in supported_providers:
raise ValueError(
f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
"Make sure the model string is formatted correctly as 'provider:model'."
"Ensure the model string is formatted as 'provider:model'."
)

return provider_key

def configure(self, provider_configs: dict = None):
def configure(self, provider_configs: Optional[Dict[str, Dict[str, Any]]] = None):
"""
Configure the client with provider configurations.
"""
if provider_configs is None:
return
Configure or update provider configurations.

self.provider_configs.update(provider_configs)
self._initialize_providers() # NOTE: This will override existing provider instances.
Args:
provider_configs (Optional[Dict[str, Dict[str, Any]]]): New provider configurations.
"""
if provider_configs:
self.provider_configs.update(provider_configs)
self._initialize_providers()

@property
def chat(self):
"""Return the chat API interface."""
"""Return the chat API interface, initializing it lazily."""
if not self._chat:
self._chat = Chat(self)
return self._chat


class Chat:
def __init__(self, client: "Client"):
def __init__(self, client: Client):
"""Initialize Chat with a reference to the Client."""
self.client = client
self._completions = Completions(self.client)
self._completions = None

@property
def completions(self):
"""Return the completions interface."""
"""Return the completions interface, initializing it lazily."""
if not self._completions:
self._completions = Completions(self.client)
return self._completions


class Completions:
def __init__(self, client: "Client"):
def __init__(self, client: Client):
"""Initialize Completions with a reference to the Client."""
self.client = client

def create(self, model: str, messages: list, **kwargs):
"""
Create chat completion based on the model, messages, and any extra arguments.

Args:
model (str): Model identifier in the format 'provider:model'.
messages (list): List of message objects.
**kwargs: Additional arguments for the provider's chat completion.

Returns:
Response from the provider's chat completion.
"""
# Check that correct format is used
provider_key, model_name = self._extract_provider_and_model(model)

# Ensure provider is initialized
if provider_key not in self.client.providers:
self.client.providers[provider_key] = self._initialize_provider(provider_key)

provider = self.client.providers[provider_key]
if not provider:
raise ValueError(f"Could not load provider for '{provider_key}'.")

# Delegate the chat completion to the provider
return provider.chat_completions_create(model_name, messages, **kwargs)

def _extract_provider_and_model(self, model: str):
"""Extract provider and model from the model string."""
if ":" not in model:
raise ValueError(
f"Invalid model format. Expected 'provider:model', got '{model}'"
f"Invalid model format. Expected 'provider:model', got '{model}'."
)

# Extract the provider key from the model identifier, e.g., "google:gemini-xx"
provider_key, model_name = model.split(":", 1)
self._validate_provider_key(provider_key)
return provider_key, model_name

# Validate if the provider is supported
def _initialize_provider(self, provider_key: str):
"""Initialize a provider if not already done."""
config = self.client.provider_configs.get(provider_key, {})
return ProviderFactory.create_provider(provider_key, config)

@staticmethod
def _validate_provider_key(provider_key: str):
"""Validate if the provider key corresponds to a supported provider."""
supported_providers = ProviderFactory.get_supported_providers()
if provider_key not in supported_providers:
raise ValueError(
f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
"Make sure the model string is formatted correctly as 'provider:model'."
)

# Initialize provider if not already initialized
if provider_key not in self.client.providers:
config = self.client.provider_configs.get(provider_key, {})
self.client.providers[provider_key] = ProviderFactory.create_provider(
provider_key, config
"Ensure the model string is formatted as 'provider:model'."
)

provider = self.client.providers.get(provider_key)
if not provider:
raise ValueError(f"Could not load provider for '{provider_key}'.")

# Delegate the chat completion to the correct provider's implementation
return provider.chat_completions_create(model_name, messages, **kwargs)