Skip to content

Commit

Permalink
Merge pull request #29 from andrewyng/update-google-provider
Browse files Browse the repository at this point in the history
Update GoogleProvider
  • Loading branch information
standsleeping authored Sep 18, 2024
2 parents 8cda46c + dab5be0 commit 120ed16
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 22 deletions.
2 changes: 2 additions & 0 deletions aisuite/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ProviderNames(str, Enum):
AWS_BEDROCK = "aws-bedrock"
AZURE = "azure"
GROQ = "groq"
GOOGLE = "google"
MISTRAL = "mistral"
OPENAI = "openai"

Expand All @@ -40,6 +41,7 @@ class ProviderFactory:
),
ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"),
ProviderNames.GROQ: ("aisuite.providers.groq_provider", "GroqProvider"),
ProviderNames.GOOGLE: ("aisuite.providers.google_provider", "GoogleProvider"),
ProviderNames.MISTRAL: (
"aisuite.providers.mistral_provider",
"MistralProvider",
Expand Down
1 change: 0 additions & 1 deletion aisuite/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@
from .ollama_interface import OllamaInterface
from .replicate_interface import ReplicateInterface
from .together_interface import TogetherInterface
from .google_interface import GoogleInterface
Original file line number Diff line number Diff line change
@@ -1,58 +1,76 @@
"""The interface to Google's Vertex AI."""

import os

import vertexai
from vertexai.generative_models import GenerativeModel, GenerationConfig

from aisuite.framework import ProviderInterface, ChatCompletionResponse


class GoogleInterface(ProviderInterface):
# Default sampling temperature used by chat_completions_create when the
# caller does not pass an explicit "temperature" kwarg.
DEFAULT_TEMPERATURE = 0.7


class GoogleProvider(ProviderInterface):
"""Implements the ProviderInterface for interacting with Google's Vertex AI."""

def __init__(self, **config):
    """Set up the Google AI client with a project ID.

    Each setting is taken from ``config`` first and falls back to the
    corresponding environment variable, so explicit configuration wins
    over the environment.

    Args:
    ----
        **config: Optional settings: ``project_id`` (falls back to
            GOOGLE_PROJECT_ID), ``region`` (GOOGLE_REGION), and
            ``application_credentials`` (GOOGLE_APPLICATION_CREDENTIALS).

    Raises:
    ------
        EnvironmentError: If any of the three required settings is
            missing from both ``config`` and the environment.
    """
    self.project_id = config.get("project_id") or os.getenv("GOOGLE_PROJECT_ID")
    self.location = config.get("region") or os.getenv("GOOGLE_REGION")
    self.app_creds_path = config.get("application_credentials") or os.getenv(
        "GOOGLE_APPLICATION_CREDENTIALS"
    )

    # Fail fast with a pointer to the setup guide rather than letting
    # vertexai.init raise a less actionable error later.
    if not self.project_id or not self.location or not self.app_creds_path:
        raise EnvironmentError(
            "Missing one or more required Google environment variables: "
            "GOOGLE_PROJECT_ID, GOOGLE_REGION, GOOGLE_APPLICATION_CREDENTIALS. "
            "Please refer to the setup guide: /guides/google.md."
        )

    vertexai.init(project=self.project_id, location=self.location)

def chat_completions_create(self, model, messages, **kwargs):
    """Request chat completions from the Google AI API.

    Args:
    ----
        model (str): Identifies the specific provider/model to use.
        messages (list of dict): A list of message objects in chat history.
        kwargs (dict): Optional arguments for the Google AI API
            (currently only ``temperature`` is consumed).

    Returns:
    -------
        The ChatCompletionResponse with the completion result.
    """
    # Use the caller's temperature when provided, otherwise the module default.
    temperature = kwargs.get("temperature", DEFAULT_TEMPERATURE)

    # Map OpenAI-style roles onto the roles Vertex AI expects.
    transformed_messages = self.transform_roles(messages)

    # All but the last message become the chat history, converted to the
    # format expected by Google.
    final_message_history = self.convert_openai_to_vertex_ai(
        transformed_messages[:-1]
    )

    # The last message is sent as the new prompt rather than as history.
    last_message = transformed_messages[-1]["content"]

    # Create the GenerativeModel with the specified model and generation configuration.
    model = GenerativeModel(
        model, generation_config=GenerationConfig(temperature=temperature)
    )

    # Start a chat seeded with the prior history and send the prompt.
    chat = model.start_chat(history=final_message_history)
    response = chat.send_message(last_message)

    # Convert the response to the format expected by the OpenAI API.
    return self.normalize_response(response)

def convert_openai_to_vertex_ai(self, messages):
"""Convert OpenAI messages to Google AI messages."""
Expand All @@ -78,8 +96,8 @@ def transform_roles(self, messages):
message["role"] = role
return messages

def convert_response_to_openai_format(self, response):
"""Convert Google AI response to OpenAI's ChatCompletionResponse format."""
def normalize_response(self, response):
"""Normalize the response from Google AI to match OpenAI's response format."""
openai_response = ChatCompletionResponse()
openai_response.choices[0].message.content = (
response.candidates[0].content.parts[0].text
Expand Down
16 changes: 16 additions & 0 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ class TestClient(unittest.TestCase):
@patch(
"aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create"
)
@patch("aisuite.providers.google_provider.GoogleProvider.chat_completions_create")
def test_client_chat_completions(
self,
mock_google,
mock_anthropic,
mock_azure,
mock_bedrock,
Expand All @@ -31,6 +33,7 @@ def test_client_chat_completions(
mock_anthropic.return_value = "Anthropic Response"
mock_groq.return_value = "Groq Response"
mock_mistral.return_value = "Mistral Response"
mock_google.return_value = "Google Response"

# Provider configurations
provider_configs = {
Expand All @@ -50,6 +53,11 @@ def test_client_chat_completions(
ProviderNames.MISTRAL: {
"api_key": "mistral-api-key",
},
ProviderNames.GOOGLE: {
"project_id": "test_google_project_id",
"region": "us-west4",
"application_credentials": "test_google_application_credentials",
},
}

# Initialize the client
Expand Down Expand Up @@ -104,6 +112,14 @@ def test_client_chat_completions(
self.assertEqual(mistral_response, "Mistral Response")
mock_mistral.assert_called_once()

# Test Google model
google_model = ProviderNames.GOOGLE + ":" + "google-model"
google_response = client.chat.completions.create(
google_model, messages=messages
)
self.assertEqual(google_response, "Google Response")
mock_google.assert_called_once()

# Test that new instances of Completion are not created each time we make an inference call.
compl_instance = client.chat.completions
next_compl_instance = client.chat.completions
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from unittest.mock import patch, MagicMock
from aisuite.providers.google_interface import GoogleInterface
from aisuite.providers.google_provider import GoogleProvider
from vertexai.generative_models import Content, Part


Expand All @@ -16,7 +16,7 @@ def test_missing_env_vars():
"""Test that an error is raised if required environment variables are missing."""
with patch.dict("os.environ", {}, clear=True):
with pytest.raises(EnvironmentError) as exc_info:
GoogleInterface()
GoogleProvider()
assert "Missing one or more required Google environment variables" in str(
exc_info.value
)
Expand All @@ -30,19 +30,21 @@ def test_vertex_interface():
selected_model = "our-favorite-model"
response_text_content = "mocked-text-response-from-model"

interface = GoogleInterface()
interface = GoogleProvider()
mock_response = MagicMock()
mock_response.candidates = [MagicMock()]
mock_response.candidates[0].content.parts[0].text = response_text_content

with patch("vertexai.generative_models.GenerativeModel") as mock_generative_model:
with patch(
"aisuite.providers.google_provider.GenerativeModel"
) as mock_generative_model:
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_chat = MagicMock()
mock_model.start_chat.return_value = mock_chat
mock_chat.send_message.return_value = mock_response

response = interface.chat_completion_create(
response = interface.chat_completions_create(
messages=message_history,
model=selected_model,
temperature=0.7,
Expand All @@ -68,7 +70,7 @@ def test_vertex_interface():


def test_convert_openai_to_vertex_ai():
interface = GoogleInterface()
interface = GoogleProvider()
message_history = [{"role": "user", "content": "Hello!"}]
result = interface.convert_openai_to_vertex_ai(message_history)
assert isinstance(result[0], Content)
Expand All @@ -79,7 +81,7 @@ def test_convert_openai_to_vertex_ai():


def test_transform_roles():
interface = GoogleInterface()
interface = GoogleProvider()

messages = [
{"role": "system", "content": "Google: system message = 1st user message."},
Expand Down

0 comments on commit 120ed16

Please sign in to comment.