Skip to content

Commit

Permalink
Update GoogleProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
standsleeping committed Sep 14, 2024
1 parent ac79070 commit a9915d0
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 15 deletions.
2 changes: 2 additions & 0 deletions aisuite/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class ProviderNames(str, Enum):
ANTHROPIC = "anthropic"
AZURE = "azure"
GROQ = "groq"
GOOGLE = "google"


class ProviderFactory:
Expand All @@ -40,6 +41,7 @@ class ProviderFactory:
),
ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"),
ProviderNames.GROQ: ("aisuite.providers.groq_provider", "GroqProvider"),
ProviderNames.GOOGLE: ("aisuite.providers.google_provider", "GoogleProvider"),
}

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion aisuite/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
from .openai_interface import OpenAIInterface
from .replicate_interface import ReplicateInterface
from .together_interface import TogetherInterface
from .google_interface import GoogleInterface
from .google_provider import GoogleProvider
34 changes: 26 additions & 8 deletions aisuite/providers/google_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,40 @@
from aisuite.framework import ProviderInterface, ChatCompletionResponse


class GoogleInterface(ProviderInterface):
DEFAULT_TEMPERATURE = 0.7


class GoogleProvider(ProviderInterface):
"""Implements the ProviderInterface for interacting with Google's Vertex AI."""

def __init__(self):
def __init__(self, **config):
"""Set up the Google AI client with a project ID."""
import vertexai

project_id = os.getenv("GOOGLE_PROJECT_ID")
location = os.getenv("GOOGLE_REGION")
app_creds_path = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
"""Set up the Google AI client with a project ID."""
self.project_id = config.get("project_id") or os.getenv("GOOGLE_PROJECT_ID")
self.location = config.get("region") or os.getenv("GOOGLE_REGION")
self.app_creds_path = config.get("application_credentials") or os.getenv(
"GOOGLE_APPLICATION_CREDENTIALS"
)

if not project_id or not location or not app_creds_path:
if not self.project_id or not self.location or not self.app_creds_path:
raise EnvironmentError(
"Missing one or more required Google environment variables: "
"GOOGLE_PROJECT_ID, GOOGLE_REGION, GOOGLE_APPLICATION_CREDENTIALS. "
"Please refer to the setup guide: /guides/google.md."
)

vertexai.init(project=project_id, location=location)
vertexai.init(project=self.project_id, location=self.location)

def chat_completion_create(self, messages=None, model=None, temperature=0):
def chat_completions_create(self, model, messages, **kwargs):
"""Request chat completions from the Google AI API.
Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
kwargs (dict): Optional arguments for the Google AI API.
Returns:
-------
Expand All @@ -39,19 +46,30 @@ def chat_completion_create(self, messages=None, model=None, temperature=0):
"""
from vertexai.generative_models import GenerativeModel, GenerationConfig

# Set the temperature if provided, otherwise use the default
temperature = kwargs.get("temperature", DEFAULT_TEMPERATURE)

# Transform the roles in the messages
transformed_messages = self.transform_roles(messages)

# Convert the messages to the format expected Google
final_message_history = self.convert_openai_to_vertex_ai(
transformed_messages[:-1]
)

# Get the last message from the transformed messages
last_message = transformed_messages[-1]["content"]

# Create the GenerativeModel with the specified model and generation configuration
model = GenerativeModel(
model, generation_config=GenerationConfig(temperature=temperature)
)

# Start a chat with the GenerativeModel and send the last message
chat = model.start_chat(history=final_message_history)
response = chat.send_message(last_message)

# Convert the response to the format expected by the OpenAI API
return self.convert_response_to_openai_format(response)

def convert_openai_to_vertex_ai(self, messages):
Expand Down
12 changes: 6 additions & 6 deletions tests/providers/test_google_provider.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from unittest.mock import patch, MagicMock
from aisuite.providers.google_interface import GoogleInterface
from aisuite.providers.google_provider import GoogleProvider
from vertexai.generative_models import Content, Part


Expand All @@ -16,7 +16,7 @@ def test_missing_env_vars():
"""Test that an error is raised if required environment variables are missing."""
with patch.dict("os.environ", {}, clear=True):
with pytest.raises(EnvironmentError) as exc_info:
GoogleInterface()
GoogleProvider()
assert "Missing one or more required Google environment variables" in str(
exc_info.value
)
Expand All @@ -30,7 +30,7 @@ def test_vertex_interface():
selected_model = "our-favorite-model"
response_text_content = "mocked-text-response-from-model"

interface = GoogleInterface()
interface = GoogleProvider()
mock_response = MagicMock()
mock_response.candidates = [MagicMock()]
mock_response.candidates[0].content.parts[0].text = response_text_content
Expand All @@ -42,7 +42,7 @@ def test_vertex_interface():
mock_model.start_chat.return_value = mock_chat
mock_chat.send_message.return_value = mock_response

response = interface.chat_completion_create(
response = interface.chat_completions_create(
messages=message_history,
model=selected_model,
temperature=0.7,
Expand All @@ -68,7 +68,7 @@ def test_vertex_interface():


def test_convert_openai_to_vertex_ai():
interface = GoogleInterface()
interface = GoogleProvider()
message_history = [{"role": "user", "content": "Hello!"}]
result = interface.convert_openai_to_vertex_ai(message_history)
assert isinstance(result[0], Content)
Expand All @@ -79,7 +79,7 @@ def test_convert_openai_to_vertex_ai():


def test_transform_roles():
interface = GoogleInterface()
interface = GoogleProvider()

messages = [
{"role": "system", "content": "Google: system message = 1st user message."},
Expand Down

0 comments on commit a9915d0

Please sign in to comment.