Skip to content

Commit

Permalink
feat(llm-bridge): add vertexai provider (yomorun#961)
Browse files Browse the repository at this point in the history
Co-authored-by: venjiang <[email protected]>
  • Loading branch information
woorui and venjiang authored Dec 14, 2024
1 parent c6e3be9 commit e5cd4ff
Show file tree
Hide file tree
Showing 8 changed files with 482 additions and 5 deletions.
8 changes: 8 additions & 0 deletions cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/yomorun/yomo/pkg/bridge/ai/provider/githubmodels"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/ollama"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/openai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/vertexai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/xai"
)

Expand Down Expand Up @@ -162,6 +163,13 @@ func registerAIProvider(aiConfig *ai.Config) error {
providerpkg.RegisterProvider(anthropic.NewProvider(provider["api_key"], provider["model"]))
case "xai":
providerpkg.RegisterProvider(xai.NewProvider(provider["api_key"], provider["model"]))
case "vertexai":
providerpkg.RegisterProvider(vertexai.NewProvider(
provider["project_id"],
provider["location"],
provider["model"],
provider["credentials_file"],
))
default:
log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name)
}
Expand Down
6 changes: 6 additions & 0 deletions example/10-ai/zipper.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,9 @@ bridge:
xai:
api_key:
model:

vertexai:
project_id:
location:
model:
credentials_file:
6 changes: 5 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/yomorun/yomo
go 1.21

require (
cloud.google.com/go/vertexai v0.13.2
github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.4
github.com/briandowns/spinner v1.23.0
github.com/bytecodealliance/wasmtime-go/v9 v9.0.0
Expand Down Expand Up @@ -41,10 +42,12 @@ require (
require (
cloud.google.com/go v0.116.0 // indirect
cloud.google.com/go/ai v0.8.0 // indirect
cloud.google.com/go/aiplatform v1.68.0 // indirect
cloud.google.com/go/auth v0.10.2 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.5 // indirect
cloud.google.com/go/compute/metadata v0.5.2 // indirect
cloud.google.com/go/longrunning v0.6.0 // indirect
cloud.google.com/go/iam v1.2.2 // indirect
cloud.google.com/go/longrunning v0.6.2 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
Expand Down Expand Up @@ -108,6 +111,7 @@ require (
golang.org/x/term v0.26.0 // indirect
golang.org/x/text v0.20.0 // indirect
golang.org/x/time v0.8.0 // indirect
google.golang.org/genproto v0.0.0-20241104194629-dd2ea8efbc28 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241104194629-dd2ea8efbc28 // indirect
google.golang.org/grpc v1.67.1 // indirect
Expand Down
12 changes: 10 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@ cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE=
cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U=
cloud.google.com/go/ai v0.8.0 h1:rXUEz8Wp2OlrM8r1bfmpF2+VKqc1VJpafE3HgzRnD/w=
cloud.google.com/go/ai v0.8.0/go.mod h1:t3Dfk4cM61sytiggo2UyGsDVW3RF1qGZaUKDrZFyqkE=
cloud.google.com/go/aiplatform v1.68.0 h1:EPPqgHDJpBZKRvv+OsB3cr0jYz3EL2pZ+802rBPcG8U=
cloud.google.com/go/aiplatform v1.68.0/go.mod h1:105MFA3svHjC3Oazl7yjXAmIR89LKhRAeNdnDKJczME=
cloud.google.com/go/auth v0.10.2 h1:oKF7rgBfSHdp/kuhXtqU/tNDr0mZqhYbEh+6SiqzkKo=
cloud.google.com/go/auth v0.10.2/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI=
cloud.google.com/go/auth/oauth2adapt v0.2.5 h1:2p29+dePqsCHPP1bqDJcKj4qxRyYCcbzKpFyKGt3MTk=
cloud.google.com/go/auth/oauth2adapt v0.2.5/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8=
cloud.google.com/go/compute/metadata v0.5.2 h1:UxK4uu/Tn+I3p2dYWTfiX4wva7aYlKixAHn3fyqngqo=
cloud.google.com/go/compute/metadata v0.5.2/go.mod h1:C66sj2AluDcIqakBq/M8lw8/ybHgOZqin2obFxa/E5k=
cloud.google.com/go/longrunning v0.6.0 h1:mM1ZmaNsQsnb+5n1DNPeL0KwQd9jQRqSqSDEkBZr+aI=
cloud.google.com/go/longrunning v0.6.0/go.mod h1:uHzSZqW89h7/pasCWNYdUpwGz3PcVWhrWupreVPYLts=
cloud.google.com/go/iam v1.2.2 h1:ozUSofHUGf/F4tCNy/mu9tHLTaxZFLOUiKzjcgWHGIA=
cloud.google.com/go/iam v1.2.2/go.mod h1:0Ys8ccaZHdI1dEUilwzqng/6ps2YB6vRsjIe00/+6JY=
cloud.google.com/go/longrunning v0.6.2 h1:xjDfh1pQcWPEvnfjZmwjKQEcHnpz6lHjfy7Fo0MK+hc=
cloud.google.com/go/longrunning v0.6.2/go.mod h1:k/vIs83RN4bE3YCswdXC5PFfWVILjm3hpEUlSko4PiI=
cloud.google.com/go/vertexai v0.13.2 h1:dOnvkMDZy3GdKAz8Isd2d6KV3jQpk6CKvYao1SIupuk=
cloud.google.com/go/vertexai v0.13.2/go.mod h1:+nmz1z8AeYILA5QM2yii3CED1PqGknZH1CUNDVatIg4=
dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU=
dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU=
dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4=
Expand Down Expand Up @@ -428,6 +434,8 @@ google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk
google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto v0.0.0-20241104194629-dd2ea8efbc28 h1:KJjNNclfpIkVqrZlTWcgOOaVQ00LdBnoEaRfkUx760s=
google.golang.org/genproto v0.0.0-20241104194629-dd2ea8efbc28/go.mod h1:mt9/MofW7AWQ+Gy179ChOnvmJatV8YHUmrcedo9CIFI=
google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 h1:M0KvPgPmDZHPlbRbaNU1APr28TvwvvdUPlSv7PUvy8g=
google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28/go.mod h1:dguCy7UOdZhTvLzDyt15+rOrawrpM4q7DD9dQ1P11P4=
google.golang.org/genproto/googleapis/rpc v0.0.0-20241104194629-dd2ea8efbc28 h1:XVhgTWWV3kGQlwJHR3upFWZeTsei6Oks1apkZSeonIE=
Expand Down
3 changes: 1 addition & 2 deletions pkg/bridge/ai/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ func TestProviders(t *testing.T) {
assert.Equal(t, p1, p)
})
t.Run("name is empty", func(t *testing.T) {
p, err := GetProvider("")
_, err := GetProvider("")
assert.NoError(t, err)
assert.Equal(t, p1, p)
})
t.Run("not found", func(t *testing.T) {
p, err := GetProvider("name-x")
Expand Down
142 changes: 142 additions & 0 deletions pkg/bridge/ai/provider/vertexai/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Package vertexai implements the bridge's LLM provider on top of Google Vertex AI generative models.
package vertexai

import (
"context"
"io"
"log"
"time"

"cloud.google.com/go/vertexai/genai"
openai "github.com/sashabaranov/go-openai"
"github.com/yomorun/yomo/core/metadata"
"github.com/yomorun/yomo/pkg/bridge/ai/provider"
"github.com/yomorun/yomo/pkg/id"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)

// Provider adapts Google Vertex AI generative models to the bridge's
// provider.LLMProvider interface.
type Provider struct {
	// client is the Vertex AI client from which generative models are obtained.
	client *genai.Client
	// model is the name of the generative model requested from client.
	model string
}

// Compile-time check that *Provider satisfies provider.LLMProvider.
var _ provider.LLMProvider = (*Provider)(nil)

// NewProvider builds a vertexai Provider for the given project and location,
// passing credentialsFile to the client as its credentials-file option.
//
// An empty model selects "gemini-1.5-pro-002". If the Vertex AI client cannot
// be created, the error is reported through log.Fatal, which terminates the
// process.
func NewProvider(projectID, location, model, credentialsFile string) *Provider {
	creds := option.WithCredentialsFile(credentialsFile)

	cli, err := genai.NewClient(context.Background(), projectID, location, creds)
	if err != nil {
		log.Fatal("new vertexai client: ", err)
	}

	p := &Provider{client: cli, model: model}
	if p.model == "" {
		// Fall back to the default model when none is configured.
		p.model = "gemini-1.5-pro-002"
	}
	return p
}

// generativeModel returns a generative model handle for p.model, configured
// from the sampling options carried by req.
//
// Temperature and TopP are always forwarded exactly as given (including zero
// values). The output-token limit is applied only when MaxTokens is positive,
// and stop sequences only when at least one is supplied.
func (p *Provider) generativeModel(req openai.ChatCompletionRequest) *genai.GenerativeModel {
	gm := p.client.GenerativeModel(p.model)

	gm.SetTemperature(req.Temperature)
	gm.SetTopP(req.TopP)

	if limit := req.MaxTokens; limit > 0 {
		gm.SetMaxOutputTokens(int32(limit))
	}
	if len(req.Stop) > 0 {
		gm.StopSequences = req.Stop
	}

	return gm
}

// GetChatCompletions implements provider.LLMProvider.
//
// It configures a model from req, starts a chat session, converts the request
// into message parts with convertPart, and sends them as a single message.
// The reply is converted with convertToResponse. If sending fails, the error
// is returned unchanged together with a zero-value response.
func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, md metadata.M) (openai.ChatCompletionResponse, error) {
	gm := p.generativeModel(req)
	session := gm.StartChat()

	// convertPart receives the session and model as well as the request, so it
	// is called before the message is sent.
	msgParts := convertPart(session, req, gm, md)

	reply, err := session.SendMessage(ctx, msgParts...)
	if err != nil {
		return openai.ChatCompletionResponse{}, err
	}
	return convertToResponse(reply, p.model), nil
}

// GetChatCompletionsStream implements provider.LLMProvider.
//
// It prepares the model and chat session the same way as GetChatCompletions,
// but sends the message in streaming mode and wraps the resulting iterator in
// a recver. The returned error is always nil.
func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, md metadata.M) (provider.ResponseRecver, error) {
	gm := p.generativeModel(req)
	session := gm.StartChat()
	msgParts := convertPart(session, req, gm, md)

	stream := session.SendMessageStream(ctx, msgParts...)

	// Record whether the request asked for usage through its stream options.
	opts := req.StreamOptions
	wantUsage := opts != nil && opts.IncludeUsage

	return &recver{
		model:        p.model,
		underlying:   stream,
		includeUsage: wantUsage,
	}, nil
}

// Name implements provider.LLMProvider and returns this provider's fixed
// name, "vertexai".
func (p *Provider) Name() string {
	const providerName = "vertexai"
	return providerName
}

type recver struct {
done bool
id string
includeUsage bool
usage *openai.Usage
model string
underlying *genai.GenerateContentResponseIterator
}

// Recv implements provider.ResponseRecver.
//
// Each call pulls the next response from the underlying iterator and returns
// it converted by convertToStreamResponse. When the iterator reports
// iterator.Done, Recv returns one final chunk with an empty Choices slice and
// the usage totals; every call after that returns io.EOF. Any other iterator
// error is returned unchanged with a zero-value response.
//
// NOTE(review): includeUsage is not consulted here, so the final usage chunk
// is emitted regardless of the request's stream options. Confirm that callers
// expect this.
func (r *recver) Recv() (response openai.ChatCompletionStreamResponse, err error) {
	if r.done {
		return openai.ChatCompletionStreamResponse{}, io.EOF
	}

	// Lazily set up per-stream state on the first call.
	if r.id == "" {
		r.id = "chatcmpl-" + id.New(29)
	}
	if r.usage == nil {
		r.usage = &openai.Usage{}
	}

	chunk, err := r.underlying.Next()
	switch {
	case err == iterator.Done:
		// The stream is exhausted: finalise the totals and emit the closing chunk.
		r.done = true
		r.usage.TotalTokens = r.usage.PromptTokens + r.usage.CompletionTokens
		return openai.ChatCompletionStreamResponse{
			ID:      r.id,
			Model:   r.model,
			Object:  "chat.completion.chunk",
			Created: time.Now().Unix(),
			Usage:   r.usage,
			Choices: []openai.ChatCompletionStreamChoice{},
		}, nil
	case err != nil:
		return openai.ChatCompletionStreamResponse{}, err
	}

	return convertToStreamResponse(r.id, chunk, r.model, r.usage), nil
}
Loading

0 comments on commit e5cd4ff

Please sign in to comment.