feat(llm-bridge): add vertexai provider (yomorun#961)
Co-authored-by: venjiang <[email protected]>
Showing 8 changed files with 482 additions and 5 deletions.
@@ -48,3 +48,9 @@ bridge:
   xai:
     api_key:
     model:
+
+  vertexai:
+    project_id:
+    location:
+    model:
+    credentials_file:
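For reference, a filled-in vertexai entry might look like the following sketch; the values are illustrative placeholders, not part of the commit:

```yaml
vertexai:
  project_id: my-gcp-project                       # hypothetical GCP project ID
  location: us-central1                            # hypothetical Vertex AI region
  model: gemini-1.5-pro-002                        # the provider's default when left empty
  credentials_file: /path/to/service-account.json  # hypothetical service-account key path
```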
@@ -0,0 +1,142 @@
// Package vertexai is used to provide the vertexai service
package vertexai

import (
	"context"
	"io"
	"log"
	"time"

	"cloud.google.com/go/vertexai/genai"
	openai "github.com/sashabaranov/go-openai"
	"github.com/yomorun/yomo/core/metadata"
	"github.com/yomorun/yomo/pkg/bridge/ai/provider"
	"github.com/yomorun/yomo/pkg/id"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
)

// Provider is the provider for google vertexai.
type Provider struct {
	model  string
	client *genai.Client
}

var _ provider.LLMProvider = &Provider{}

// NewProvider creates a new vertexai provider.
func NewProvider(projectID, location, model, credentialsFile string) *Provider {
	client, err := genai.NewClient(context.Background(), projectID, location, option.WithCredentialsFile(credentialsFile))
	if err != nil {
		log.Fatal("new vertexai client: ", err)
	}
	if model == "" {
		model = "gemini-1.5-pro-002"
	}

	return &Provider{
		model:  model,
		client: client,
	}
}

func (p *Provider) generativeModel(req openai.ChatCompletionRequest) *genai.GenerativeModel {
	model := p.client.GenerativeModel(p.model)

	model.SetTemperature(req.Temperature)
	model.SetTopP(req.TopP)
	if req.MaxTokens > 0 {
		model.SetMaxOutputTokens(int32(req.MaxTokens))
	}
	if len(req.Stop) != 0 {
		model.StopSequences = req.Stop
	}

	return model
}

// GetChatCompletions implements provider.LLMProvider.
func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, md metadata.M) (openai.ChatCompletionResponse, error) {
	model := p.generativeModel(req)

	chat := model.StartChat()

	parts := convertPart(chat, req, model, md)

	resp, err := chat.SendMessage(ctx, parts...)
	if err != nil {
		return openai.ChatCompletionResponse{}, err
	}

	return convertToResponse(resp, p.model), nil
}
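As an illustration, a non-streaming call through this provider could look like the sketch below. It assumes the package lives alongside the other providers at github.com/yomorun/yomo/pkg/bridge/ai/provider/vertexai and that an empty metadata.M is acceptable here; the project settings are placeholder values, not part of the commit.

```go
package main

import (
	"context"
	"fmt"
	"log"

	openai "github.com/sashabaranov/go-openai"
	"github.com/yomorun/yomo/core/metadata"
	"github.com/yomorun/yomo/pkg/bridge/ai/provider/vertexai" // assumed import path
)

func main() {
	// Illustrative values; note NewProvider calls log.Fatal if the client cannot be created.
	p := vertexai.NewProvider("my-gcp-project", "us-central1", "", "/path/to/service-account.json")

	req := openai.ChatCompletionRequest{
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Say hello in one sentence."},
		},
	}

	resp, err := p.GetChatCompletions(context.Background(), req, metadata.M{})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Choices[0].Message.Content)
}
```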
// GetChatCompletionsStream implements provider.LLMProvider.
func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, md metadata.M) (provider.ResponseRecver, error) {
	model := p.generativeModel(req)

	chat := model.StartChat()

	parts := convertPart(chat, req, model, md)

	resp := chat.SendMessageStream(ctx, parts...)

	includeUsage := false
	if req.StreamOptions != nil && req.StreamOptions.IncludeUsage {
		includeUsage = true
	}

	recver := &recver{
		model:        p.model,
		underlying:   resp,
		includeUsage: includeUsage,
	}

	return recver, nil
}

// Name implements provider.LLMProvider.
func (p *Provider) Name() string {
	return "vertexai"
}

type recver struct {
	done         bool
	id           string
	includeUsage bool
	usage        *openai.Usage
	model        string
	underlying   *genai.GenerateContentResponseIterator
}

// Recv implements provider.ResponseRecver.
func (r *recver) Recv() (response openai.ChatCompletionStreamResponse, err error) {
	if r.done {
		return openai.ChatCompletionStreamResponse{}, io.EOF
	}
	if r.id == "" {
		r.id = "chatcmpl-" + id.New(29)
	}
	if r.usage == nil {
		r.usage = &openai.Usage{}
	}
	resp, err := r.underlying.Next()
	if err == iterator.Done {
		r.usage.TotalTokens = r.usage.PromptTokens + r.usage.CompletionTokens
		usageResp := openai.ChatCompletionStreamResponse{
			ID:      r.id,
			Model:   r.model,
			Object:  "chat.completion.chunk",
			Created: time.Now().Unix(),
			Usage:   r.usage,
			Choices: make([]openai.ChatCompletionStreamChoice, 0),
		}
		r.done = true
		return usageResp, nil
	}
	if err != nil {
		return openai.ChatCompletionStreamResponse{}, err
	}

	return convertToStreamResponse(r.id, resp, r.model, r.usage), nil
}
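To round out the picture, draining the returned ResponseRecver might look like the minimal sketch below, continuing the hypothetical p and req from the previous example (with "io" added to the imports). Per the Recv implementation above, the stream ends with a final chunk that has an empty Choices slice and the accumulated Usage, and only the next Recv returns io.EOF, so consumers should range over Choices rather than index into it:

```go
recver, err := p.GetChatCompletionsStream(context.Background(), req, metadata.M{})
if err != nil {
	log.Fatal(err)
}
for {
	chunk, err := recver.Recv()
	if err == io.EOF {
		break // the final usage chunk has already been delivered
	}
	if err != nil {
		log.Fatal(err)
	}
	// The terminal chunk carries usage only, with no choices.
	for _, choice := range chunk.Choices {
		fmt.Print(choice.Delta.Content)
	}
}
```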