Add support for Google VertexAI language model provider
codingbandit committed Feb 5, 2025
1 parent 65af46d commit f8c41b0
Showing 6 changed files with 67 additions and 37 deletions.
@@ -16,13 +16,18 @@ public static class APIEndpointProviders
public const string OPENAI = "openai";

/// <summary>
/// Bedrock
/// Amazon Bedrock
/// </summary>
public const string BEDROCK = "bedrock";

/// <summary>
/// Google VertexAI
/// </summary>
public const string VERTEXAI = "vertexai";

/// <summary>
/// All providers.
/// </summary>
public readonly static string[] All = [MICROSOFT, OPENAI, BEDROCK];
public readonly static string[] All = [MICROSOFT, OPENAI, BEDROCK, VERTEXAI];
}
}
11 changes: 6 additions & 5 deletions src/python/LangChainAPI/requirements.txt
@@ -8,17 +8,18 @@ azure-search-documents==11.6.0b9
azure-storage-blob==12.21.0
email-validator==2.2.0
fastapi==0.115.6
langchain==0.3.7
langchain-aws==0.2.7
langchain-experimental==0.3.3
langchain==0.3.17
langchain-aws==0.2.12
langchain-experimental==0.3.4
langchain-google-vertexai==2.0.13
langchain-openai==0.3.3
langchain-azure-dynamic-sessions==0.2.0
langchain-deepseek-official==0.1.0
langgraph==0.2.53
openai==1.60.2
openai==1.61.0
pandas==2.2.2
pyarrow==18.1.0
pydantic==2.8.2
pydantic==2.10.6
pylint==3.2.6
pyodbc==5.2.0
sqlalchemy==2.0.36
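The pinned versions above bring in the new langchain-google-vertexai dependency alongside the upgraded LangChain and OpenAI packages. A minimal sanity check (not part of the commit) that the Vertex AI integration resolves after installing these requirements might look like this:

```python
# Assumed sanity check: confirm the newly pinned Vertex AI integration
# (langchain-google-vertexai==2.0.13 above) imports cleanly in the environment.
from langchain_google_vertexai import ChatVertexAI

print(ChatVertexAI)  # prints the class; an ImportError means the pin is not installed
```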
@@ -1,11 +1,13 @@
import boto3
import json
from abc import abstractmethod
from typing import List
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

from google.oauth2 import service_account
from langchain_core.language_models import BaseLanguageModel
from langchain_aws import ChatBedrockConverse
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from langchain_google_vertexai import ChatVertexAI
from langchain_openai import AzureChatOpenAI, ChatOpenAI, OpenAI
from openai import AsyncAzureOpenAI as async_aoi
from foundationallm.config import Configuration, UserIdentity
@@ -397,6 +399,26 @@ def _get_language_model(self, override_operation_type: OperationTypes = None) ->
aws_access_key_id = access_key,
aws_secret_access_key = secret_key
)
case LanguageModelProvider.VERTEXAI:
# Only service account authentication is supported, using JSON credentials stored in Key Vault.
# The authentication parameter service_account_credentials holds the application configuration key for that value.
try:
service_account_credentials_definition = json.loads(self.config.get_value(self.api_endpoint.authentication_parameters.get('service_account_credentials')))
except Exception as e:
raise LangChainException(f"Failed to retrieve service account credentials: {str(e)}", 500)

if not service_account_credentials_definition:
raise LangChainException("Service account credentials are missing from the configuration settings.", 400)

service_account_credentials = service_account.Credentials.from_service_account_info(service_account_credentials_definition)
language_model = ChatVertexAI(
model=self.ai_model.deployment_name,
temperature=0,
max_tokens=None,
max_retries=6,
stop=None,
credentials=service_account_credentials
)

# Set model parameters.
for key, value in self.ai_model.model_parameters.items():
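For reference, a standalone sketch of the same Vertex AI wiring is shown below. It assumes a local service-account key file and an example Gemini model id; in the commit itself the JSON credentials come from the application configuration / Key Vault and the model name comes from the AI model resource's deployment_name.

```python
# Illustrative sketch only, not the repository's code.
# Assumptions: a local "service-account.json" key file and the "gemini-1.5-pro"
# model id; the commit resolves both from configuration instead.
import json

from google.oauth2 import service_account
from langchain_google_vertexai import ChatVertexAI

with open("service-account.json", "r", encoding="utf-8") as f:  # hypothetical path
    service_account_info = json.load(f)

# Build Google credentials from the service account JSON definition.
credentials = service_account.Credentials.from_service_account_info(service_account_info)

chat_model = ChatVertexAI(
    model="gemini-1.5-pro",  # assumed deployment name
    temperature=0,
    max_tokens=None,
    max_retries=6,
    stop=None,
    credentials=credentials,
)

response = chat_model.invoke("Say hello from Vertex AI.")
print(response.content)
```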
@@ -491,7 +491,7 @@ async def invoke_async(self, request: KnowledgeManagementCompletionRequest) -> C
if isinstance(item, ContentArtifact):
content_artifacts.append(item)

final_message = response["messages"][-1]
final_message = response["messages"][-1]
response_content = OpenAITextMessageContentItem(
value = final_message.content,
agent_capability_category = AgentCapabilityCategories.FOUNDATIONALLM_KNOWLEDGE_MANAGEMENT
@@ -599,27 +599,7 @@ async def invoke_async(self, request: KnowledgeManagementCompletionRequest) -> C

retvalue = None

if self.api_endpoint.provider == LanguageModelProvider.BEDROCK:
if self.has_retriever:
completion = chain.invoke(request.user_prompt)
else:
completion = await chain.ainvoke(request.user_prompt)
response_content = OpenAITextMessageContentItem(
value = completion.content,
agent_capability_category = AgentCapabilityCategories.FOUNDATIONALLM_KNOWLEDGE_MANAGEMENT
)
retvalue = CompletionResponse(
operation_id = request.operation_id,
content = [response_content],
user_prompt = request.user_prompt,
user_prompt_rewrite = request.user_prompt_rewrite,
full_prompt = self.full_prompt.text,
completion_tokens = completion.usage_metadata["output_tokens"] + image_analysis_token_usage.completion_tokens,
prompt_tokens = completion.usage_metadata["input_tokens"] + image_analysis_token_usage.prompt_tokens,
total_tokens = completion.usage_metadata["total_tokens"] + image_analysis_token_usage.total_tokens,
total_cost = 0
)
else:
if self.api_endpoint.provider == LanguageModelProvider.MICROSOFT or self.api_endpoint.provider == LanguageModelProvider.OPENAI:
# OpenAI compatible models
with get_openai_callback() as cb:
# add output parser to openai callback
@@ -647,6 +627,26 @@ async def invoke_async(self, request: KnowledgeManagementCompletionRequest) -> C
)
except Exception as e:
raise LangChainException(f"An unexpected exception occurred when executing the completion request: {str(e)}", 500)
else:
if self.has_retriever:
completion = chain.invoke(request.user_prompt)
else:
completion = await chain.ainvoke(request.user_prompt)
response_content = OpenAITextMessageContentItem(
value = completion.content,
agent_capability_category = AgentCapabilityCategories.FOUNDATIONALLM_KNOWLEDGE_MANAGEMENT
)
retvalue = CompletionResponse(
operation_id = request.operation_id,
content = [response_content],
user_prompt = request.user_prompt,
user_prompt_rewrite = request.user_prompt_rewrite,
full_prompt = self.full_prompt.text,
completion_tokens = completion.usage_metadata["output_tokens"] + image_analysis_token_usage.completion_tokens,
prompt_tokens = completion.usage_metadata["input_tokens"] + image_analysis_token_usage.prompt_tokens,
total_tokens = completion.usage_metadata["total_tokens"] + image_analysis_token_usage.total_tokens,
total_cost = 0
)

if isinstance(retriever, ContentArtifactRetrievalBase):
retvalue.content_artifacts = retriever.get_document_content_artifacts() or []
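The net effect of this restructuring is that only the OpenAI-compatible providers (Microsoft, OpenAI) run inside the OpenAI callback for token accounting, while every other provider (now Bedrock and Vertex AI) takes the generic path and reads token counts from the completion's usage_metadata. A simplified sketch of that split, assuming the chain returns a LangChain AIMessage, could look like:

```python
# Simplified illustration of the provider branching above; not the repository's code.
from langchain_community.callbacks import get_openai_callback


async def invoke_with_usage(chain, provider: str, user_prompt: str):
    """Return (completion, completion_tokens, prompt_tokens, total_tokens)."""
    if provider in ("microsoft", "openai"):
        # OpenAI-compatible models report token usage through the callback.
        with get_openai_callback() as cb:
            completion = await chain.ainvoke(user_prompt)
            return completion, cb.completion_tokens, cb.prompt_tokens, cb.total_tokens

    # Bedrock, Vertex AI and other chat models expose usage on the message itself.
    completion = await chain.ainvoke(user_prompt)
    usage = completion.usage_metadata
    return completion, usage["output_tokens"], usage["input_tokens"], usage["total_tokens"]
```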
@@ -6,3 +6,4 @@ class LanguageModelProvider(str, Enum):
MICROSOFT = "microsoft"
OPENAI = "openai"
BEDROCK = "bedrock"
VERTEXAI = "vertexai"
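Because LanguageModelProvider derives from str, the new member round-trips cleanly against the lowercase provider string used in endpoint configurations, mirroring the VERTEXAI constant added on the .NET side. A small illustration, assuming the enum as defined above:

```python
from enum import Enum


class LanguageModelProvider(str, Enum):
    MICROSOFT = "microsoft"
    OPENAI = "openai"
    BEDROCK = "bedrock"
    VERTEXAI = "vertexai"


# Configuration values stored as plain strings resolve to the enum member...
assert LanguageModelProvider("vertexai") is LanguageModelProvider.VERTEXAI
# ...and the member compares equal to the raw string, since the enum derives from str.
assert LanguageModelProvider.VERTEXAI == "vertexai"
```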
15 changes: 8 additions & 7 deletions src/python/PythonSDK/requirements.txt
@@ -6,17 +6,18 @@ azure-monitor-opentelemetry==1.6.2
azure-monitor-opentelemetry-exporter==1.0.0b28
azure-search-documents==11.6.0b9
azure-storage-blob==12.21.0
boto3==1.35.67
langchain==0.3.7
langchain-aws==0.2.7
langchain-experimental==0.3.3
langchain-openai==0.2.9
boto3==1.36.13
langchain==0.3.17
langchain-aws==0.2.12
langchain-experimental==0.3.4
langchain-google-vertexai==2.0.13
langchain-openai==0.3.3
langchain-azure-dynamic-sessions==0.2.0
langgraph==0.2.53
openai==1.55.3
openai==1.61.0
opentelemetry-api==1.27.0
opentelemetry-sdk==1.27.0
pandas==2.2.2
pydantic==2.8.2
pydantic==2.10.6
unidecode==1.3.8
wikipedia==1.4.0
