diff --git a/aisuite/provider.py b/aisuite/provider.py index 46e8f493..3bf81062 100644 --- a/aisuite/provider.py +++ b/aisuite/provider.py @@ -23,6 +23,7 @@ class ProviderNames(str, Enum): ANTHROPIC = "anthropic" AZURE = "azure" GROQ = "groq" + GOOGLE = "google" class ProviderFactory: @@ -40,6 +41,7 @@ class ProviderFactory: ), ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"), ProviderNames.GROQ: ("aisuite.providers.groq_provider", "GroqProvider"), + ProviderNames.GOOGLE: ("aisuite.providers.google_provider", "GoogleProvider"), } @classmethod diff --git a/aisuite/providers/__init__.py b/aisuite/providers/__init__.py index 816f790e..7f25a128 100644 --- a/aisuite/providers/__init__.py +++ b/aisuite/providers/__init__.py @@ -10,4 +10,4 @@ from .openai_interface import OpenAIInterface from .replicate_interface import ReplicateInterface from .together_interface import TogetherInterface -from .google_interface import GoogleInterface +from .google_provider import GoogleProvider diff --git a/aisuite/providers/google_provider.py b/aisuite/providers/google_provider.py index 88aa3b68..d2a9ddf9 100644 --- a/aisuite/providers/google_provider.py +++ b/aisuite/providers/google_provider.py @@ -4,33 +4,40 @@ from aisuite.framework import ProviderInterface, ChatCompletionResponse -class GoogleInterface(ProviderInterface): +DEFAULT_TEMPERATURE = 0.7 + + +class GoogleProvider(ProviderInterface): """Implements the ProviderInterface for interacting with Google's Vertex AI.""" - def __init__(self): + def __init__(self, **config): """Set up the Google AI client with a project ID.""" import vertexai - project_id = os.getenv("GOOGLE_PROJECT_ID") - location = os.getenv("GOOGLE_REGION") - app_creds_path = os.getenv("GOOGLE_APPLICATION_CREDENTIALS") + """Set up the Google AI client with a project ID.""" + self.project_id = config.get("project_id") or os.getenv("GOOGLE_PROJECT_ID") + self.location = config.get("region") or os.getenv("GOOGLE_REGION") + self.app_creds_path = config.get("application_credentials") or os.getenv( + "GOOGLE_APPLICATION_CREDENTIALS" + ) - if not project_id or not location or not app_creds_path: + if not self.project_id or not self.location or not self.app_creds_path: raise EnvironmentError( "Missing one or more required Google environment variables: " "GOOGLE_PROJECT_ID, GOOGLE_REGION, GOOGLE_APPLICATION_CREDENTIALS. " "Please refer to the setup guide: /guides/google.md." ) - vertexai.init(project=project_id, location=location) + vertexai.init(project=self.project_id, location=self.location) - def chat_completion_create(self, messages=None, model=None, temperature=0): + def chat_completions_create(self, model, messages, **kwargs): """Request chat completions from the Google AI API. Args: ---- model (str): Identifies the specific provider/model to use. messages (list of dict): A list of message objects in chat history. + kwargs (dict): Optional arguments for the Google AI API. Returns: ------- @@ -39,19 +46,30 @@ def chat_completion_create(self, messages=None, model=None, temperature=0): """ from vertexai.generative_models import GenerativeModel, GenerationConfig + # Set the temperature if provided, otherwise use the default + temperature = kwargs.get("temperature", DEFAULT_TEMPERATURE) + + # Transform the roles in the messages transformed_messages = self.transform_roles(messages) + # Convert the messages to the format expected Google final_message_history = self.convert_openai_to_vertex_ai( transformed_messages[:-1] ) + + # Get the last message from the transformed messages last_message = transformed_messages[-1]["content"] + # Create the GenerativeModel with the specified model and generation configuration model = GenerativeModel( model, generation_config=GenerationConfig(temperature=temperature) ) + # Start a chat with the GenerativeModel and send the last message chat = model.start_chat(history=final_message_history) response = chat.send_message(last_message) + + # Convert the response to the format expected by the OpenAI API return self.convert_response_to_openai_format(response) def convert_openai_to_vertex_ai(self, messages): diff --git a/tests/providers/test_google_provider.py b/tests/providers/test_google_provider.py index da6a8237..c6b5bea6 100644 --- a/tests/providers/test_google_provider.py +++ b/tests/providers/test_google_provider.py @@ -1,6 +1,6 @@ import pytest from unittest.mock import patch, MagicMock -from aisuite.providers.google_interface import GoogleInterface +from aisuite.providers.google_provider import GoogleProvider from vertexai.generative_models import Content, Part @@ -16,7 +16,7 @@ def test_missing_env_vars(): """Test that an error is raised if required environment variables are missing.""" with patch.dict("os.environ", {}, clear=True): with pytest.raises(EnvironmentError) as exc_info: - GoogleInterface() + GoogleProvider() assert "Missing one or more required Google environment variables" in str( exc_info.value ) @@ -30,7 +30,7 @@ def test_vertex_interface(): selected_model = "our-favorite-model" response_text_content = "mocked-text-response-from-model" - interface = GoogleInterface() + interface = GoogleProvider() mock_response = MagicMock() mock_response.candidates = [MagicMock()] mock_response.candidates[0].content.parts[0].text = response_text_content @@ -42,7 +42,7 @@ def test_vertex_interface(): mock_model.start_chat.return_value = mock_chat mock_chat.send_message.return_value = mock_response - response = interface.chat_completion_create( + response = interface.chat_completions_create( messages=message_history, model=selected_model, temperature=0.7, @@ -68,7 +68,7 @@ def test_vertex_interface(): def test_convert_openai_to_vertex_ai(): - interface = GoogleInterface() + interface = GoogleProvider() message_history = [{"role": "user", "content": "Hello!"}] result = interface.convert_openai_to_vertex_ai(message_history) assert isinstance(result[0], Content) @@ -79,7 +79,7 @@ def test_convert_openai_to_vertex_ai(): def test_transform_roles(): - interface = GoogleInterface() + interface = GoogleProvider() messages = [ {"role": "system", "content": "Google: system message = 1st user message."},