diff --git a/cli/serve.go b/cli/serve.go index 768e3ec2a..36d988f5b 100644 --- a/cli/serve.go +++ b/cli/serve.go @@ -38,6 +38,7 @@ import ( "github.com/yomorun/yomo/pkg/bridge/ai/provider/githubmodels" "github.com/yomorun/yomo/pkg/bridge/ai/provider/ollama" "github.com/yomorun/yomo/pkg/bridge/ai/provider/openai" + "github.com/yomorun/yomo/pkg/bridge/ai/provider/vertexai" "github.com/yomorun/yomo/pkg/bridge/ai/provider/xai" ) @@ -162,6 +163,13 @@ func registerAIProvider(aiConfig *ai.Config) error { providerpkg.RegisterProvider(anthropic.NewProvider(provider["api_key"], provider["model"])) case "xai": providerpkg.RegisterProvider(xai.NewProvider(provider["api_key"], provider["model"])) + case "vertexai": + providerpkg.RegisterProvider(vertexai.NewProvider( + provider["project_id"], + provider["location"], + provider["model"], + provider["credentials_file"], + )) default: log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name) } diff --git a/example/10-ai/zipper.yaml b/example/10-ai/zipper.yaml index 10f5e0906..b7ed0a83e 100644 --- a/example/10-ai/zipper.yaml +++ b/example/10-ai/zipper.yaml @@ -48,3 +48,9 @@ bridge: xai: api_key: model: + + vertexai: + project_id: + location: + model: + credentials_file: diff --git a/go.mod b/go.mod index 72dd9ac0f..ce97167a4 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/yomorun/yomo go 1.21 require ( + cloud.google.com/go/vertexai v0.13.2 github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.4 github.com/briandowns/spinner v1.23.0 github.com/bytecodealliance/wasmtime-go/v9 v9.0.0 @@ -41,10 +42,12 @@ require ( require ( cloud.google.com/go v0.116.0 // indirect cloud.google.com/go/ai v0.8.0 // indirect + cloud.google.com/go/aiplatform v1.68.0 // indirect cloud.google.com/go/auth v0.10.2 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.5 // indirect cloud.google.com/go/compute/metadata v0.5.2 // indirect - cloud.google.com/go/longrunning v0.6.0 // indirect + cloud.google.com/go/iam v1.2.2 // indirect + 
cloud.google.com/go/longrunning v0.6.2 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect @@ -108,6 +111,7 @@ require ( golang.org/x/term v0.26.0 // indirect golang.org/x/text v0.20.0 // indirect golang.org/x/time v0.8.0 // indirect + google.golang.org/genproto v0.0.0-20241104194629-dd2ea8efbc28 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20241104194629-dd2ea8efbc28 // indirect google.golang.org/grpc v1.67.1 // indirect diff --git a/go.sum b/go.sum index ac2b06abe..b1bde699a 100644 --- a/go.sum +++ b/go.sum @@ -6,14 +6,20 @@ cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= cloud.google.com/go/ai v0.8.0 h1:rXUEz8Wp2OlrM8r1bfmpF2+VKqc1VJpafE3HgzRnD/w= cloud.google.com/go/ai v0.8.0/go.mod h1:t3Dfk4cM61sytiggo2UyGsDVW3RF1qGZaUKDrZFyqkE= +cloud.google.com/go/aiplatform v1.68.0 h1:EPPqgHDJpBZKRvv+OsB3cr0jYz3EL2pZ+802rBPcG8U= +cloud.google.com/go/aiplatform v1.68.0/go.mod h1:105MFA3svHjC3Oazl7yjXAmIR89LKhRAeNdnDKJczME= cloud.google.com/go/auth v0.10.2 h1:oKF7rgBfSHdp/kuhXtqU/tNDr0mZqhYbEh+6SiqzkKo= cloud.google.com/go/auth v0.10.2/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI= cloud.google.com/go/auth/oauth2adapt v0.2.5 h1:2p29+dePqsCHPP1bqDJcKj4qxRyYCcbzKpFyKGt3MTk= cloud.google.com/go/auth/oauth2adapt v0.2.5/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8= cloud.google.com/go/compute/metadata v0.5.2 h1:UxK4uu/Tn+I3p2dYWTfiX4wva7aYlKixAHn3fyqngqo= cloud.google.com/go/compute/metadata v0.5.2/go.mod h1:C66sj2AluDcIqakBq/M8lw8/ybHgOZqin2obFxa/E5k= -cloud.google.com/go/longrunning v0.6.0 h1:mM1ZmaNsQsnb+5n1DNPeL0KwQd9jQRqSqSDEkBZr+aI= -cloud.google.com/go/longrunning v0.6.0/go.mod h1:uHzSZqW89h7/pasCWNYdUpwGz3PcVWhrWupreVPYLts= 
+cloud.google.com/go/iam v1.2.2 h1:ozUSofHUGf/F4tCNy/mu9tHLTaxZFLOUiKzjcgWHGIA= +cloud.google.com/go/iam v1.2.2/go.mod h1:0Ys8ccaZHdI1dEUilwzqng/6ps2YB6vRsjIe00/+6JY= +cloud.google.com/go/longrunning v0.6.2 h1:xjDfh1pQcWPEvnfjZmwjKQEcHnpz6lHjfy7Fo0MK+hc= +cloud.google.com/go/longrunning v0.6.2/go.mod h1:k/vIs83RN4bE3YCswdXC5PFfWVILjm3hpEUlSko4PiI= +cloud.google.com/go/vertexai v0.13.2 h1:dOnvkMDZy3GdKAz8Isd2d6KV3jQpk6CKvYao1SIupuk= +cloud.google.com/go/vertexai v0.13.2/go.mod h1:+nmz1z8AeYILA5QM2yii3CED1PqGknZH1CUNDVatIg4= dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= @@ -428,6 +434,8 @@ google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto v0.0.0-20241104194629-dd2ea8efbc28 h1:KJjNNclfpIkVqrZlTWcgOOaVQ00LdBnoEaRfkUx760s= +google.golang.org/genproto v0.0.0-20241104194629-dd2ea8efbc28/go.mod h1:mt9/MofW7AWQ+Gy179ChOnvmJatV8YHUmrcedo9CIFI= google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 h1:M0KvPgPmDZHPlbRbaNU1APr28TvwvvdUPlSv7PUvy8g= google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28/go.mod h1:dguCy7UOdZhTvLzDyt15+rOrawrpM4q7DD9dQ1P11P4= google.golang.org/genproto/googleapis/rpc v0.0.0-20241104194629-dd2ea8efbc28 h1:XVhgTWWV3kGQlwJHR3upFWZeTsei6Oks1apkZSeonIE= diff --git 
a/pkg/bridge/ai/provider/provider_test.go b/pkg/bridge/ai/provider/provider_test.go index 349c05d8a..d09d782af 100644 --- a/pkg/bridge/ai/provider/provider_test.go +++ b/pkg/bridge/ai/provider/provider_test.go @@ -28,9 +28,8 @@ func TestProviders(t *testing.T) { assert.Equal(t, p1, p) }) t.Run("name is empty", func(t *testing.T) { - p, err := GetProvider("") + _, err := GetProvider("") assert.NoError(t, err) - assert.Equal(t, p1, p) }) t.Run("not found", func(t *testing.T) { p, err := GetProvider("name-x") diff --git a/pkg/bridge/ai/provider/vertexai/provider.go b/pkg/bridge/ai/provider/vertexai/provider.go new file mode 100644 index 000000000..77dac04da --- /dev/null +++ b/pkg/bridge/ai/provider/vertexai/provider.go @@ -0,0 +1,142 @@ +// Package vertexai is used to provide the vertexai service +package vertexai + +import ( + "context" + "io" + "log" + "time" + + "cloud.google.com/go/vertexai/genai" + openai "github.com/sashabaranov/go-openai" + "github.com/yomorun/yomo/core/metadata" + "github.com/yomorun/yomo/pkg/bridge/ai/provider" + "github.com/yomorun/yomo/pkg/id" + "google.golang.org/api/iterator" + "google.golang.org/api/option" +) + +// Provider is the provider for google vertexai. +type Provider struct { + model string + client *genai.Client +} + +var _ provider.LLMProvider = &Provider{} + +// NewProvider creates a new vertexai provider. 
+func NewProvider(projectID, location, model, credentialsFile string) *Provider { + client, err := genai.NewClient(context.Background(), projectID, location, option.WithCredentialsFile(credentialsFile)) + if err != nil { + log.Fatal("new vertexai client: ", err) + } + if model == "" { + model = "gemini-1.5-pro-002" + } + + return &Provider{ + model: model, + client: client, + } +} + +func (p *Provider) generativeModel(req openai.ChatCompletionRequest) *genai.GenerativeModel { + model := p.client.GenerativeModel(p.model) + + model.SetTemperature(req.Temperature) + model.SetTopP(req.TopP) + if req.MaxTokens > 0 { + model.SetMaxOutputTokens(int32(req.MaxTokens)) + } + if len(req.Stop) != 0 { + model.StopSequences = req.Stop + } + + return model +} + +// GetChatCompletions implements provider.LLMProvider. +func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, md metadata.M) (openai.ChatCompletionResponse, error) { + model := p.generativeModel(req) + + chat := model.StartChat() + + parts := convertPart(chat, req, model, md) + + resp, err := chat.SendMessage(ctx, parts...) + if err != nil { + return openai.ChatCompletionResponse{}, err + } + + return convertToResponse(resp, p.model), nil +} + +// GetChatCompletionsStream implements provider.LLMProvider. +func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, md metadata.M) (provider.ResponseRecver, error) { + model := p.generativeModel(req) + + chat := model.StartChat() + + parts := convertPart(chat, req, model, md) + + resp := chat.SendMessageStream(ctx, parts...) + + includeUsage := false + if req.StreamOptions != nil && req.StreamOptions.IncludeUsage { + includeUsage = true + } + + recver := &recver{ + model: p.model, + underlying: resp, + includeUsage: includeUsage, + } + + return recver, nil +} + +// Name implements provider.LLMProvider. 
+func (p *Provider) Name() string { + return "vertexai" +} + +type recver struct { + done bool + id string + includeUsage bool + usage *openai.Usage + model string + underlying *genai.GenerateContentResponseIterator +} + +// Recv implements provider.ResponseRecver. +func (r *recver) Recv() (response openai.ChatCompletionStreamResponse, err error) { + if r.done { + return openai.ChatCompletionStreamResponse{}, io.EOF + } + if r.id == "" { + r.id = "chatcmpl-" + id.New(29) + } + if r.usage == nil { + r.usage = &openai.Usage{} + } + resp, err := r.underlying.Next() + if err == iterator.Done { + r.usage.TotalTokens = r.usage.PromptTokens + r.usage.CompletionTokens + usageResp := openai.ChatCompletionStreamResponse{ + ID: r.id, + Model: r.model, + Object: "chat.completion.chunk", + Created: time.Now().Unix(), + Usage: r.usage, + Choices: make([]openai.ChatCompletionStreamChoice, 0), + } + r.done = true + return usageResp, nil + } + if err != nil { + return openai.ChatCompletionStreamResponse{}, err + } + + return convertToStreamResponse(r.id, resp, r.model, r.usage), nil +} diff --git a/pkg/bridge/ai/provider/vertexai/request_convert.go b/pkg/bridge/ai/provider/vertexai/request_convert.go new file mode 100644 index 000000000..23e43d70e --- /dev/null +++ b/pkg/bridge/ai/provider/vertexai/request_convert.go @@ -0,0 +1,169 @@ +package vertexai + +import ( + "encoding/json" + "fmt" + "strings" + + "cloud.google.com/go/vertexai/genai" + openai "github.com/sashabaranov/go-openai" + "github.com/yomorun/yomo/ai" + "github.com/yomorun/yomo/core/metadata" +) + +func convertPart(chat *genai.ChatSession, req openai.ChatCompletionRequest, model *genai.GenerativeModel, md metadata.M) []genai.Part { + parts := []genai.Part{} + history := []*genai.Content{} + + if len(req.Tools) > 0 { + tools := convertTools(req.Tools) + model.Tools = tools + data, _ := json.Marshal(tools) + md.Set("tools", string(data)) + } else { + if data, ok := md.Get("tools"); ok { + var tools []*genai.Tool + _ = 
json.Unmarshal([]byte(data), &tools) + model.Tools = tools + } + } + + isHistory := false + for i := len(req.Messages) - 1; i >= 0; i-- { + message := req.Messages[i] + + switch message.Role { + case openai.ChatMessageRoleUser: + part := genai.Text(message.Content) + if isHistory { + history = prepend(history, genai.NewUserContent(part)) + } else { + parts = prepend[genai.Part](parts, part) + } + + case openai.ChatMessageRoleSystem: + if message.Content != "" { + model.SystemInstruction = &genai.Content{Parts: []genai.Part{genai.Text(message.Content)}} + } + case openai.ChatMessageRoleAssistant: + if message.Content != "" { + isHistory = true + history = prepend(history, &genai.Content{ + Role: "model", + Parts: []genai.Part{genai.Text(message.Content)}, + }) + } + if len(message.ToolCalls) == 0 { + continue + } + fcParts := []genai.Part{} + for _, tc := range message.ToolCalls { + args := map[string]any{} + _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) + fcParts = append(fcParts, genai.FunctionCall{ + Name: tc.Function.Name, + Args: args, + }) + } + history = append(history, &genai.Content{ + Role: "model", + Parts: fcParts, + }) + + case openai.ChatMessageRoleTool: + resp := map[string]any{} + if err := json.Unmarshal([]byte(message.Content), &resp); err != nil { + resp["result"] = message.Content + } + + sl := strings.Split(message.ToolCallID, "-") + if len(sl) > 1 { + name := sl[0] + parts = prepend[genai.Part](parts, genai.FunctionResponse{ + Name: name, + Response: resp, + }) + } + } + } + + chat.History = history + return parts +} + +func prepend[T any](parts []T, part T) []T { + return append([]T{part}, parts...) 
+} + +func convertTools(tools []openai.Tool) []*genai.Tool { + var result []*genai.Tool + + for _, tool := range tools { + params := &ai.FunctionParameters{} + + raw, _ := json.Marshal(tool.Function.Parameters) + _ = json.Unmarshal(raw, params) + + item := &genai.Tool{ + FunctionDeclarations: []*genai.FunctionDeclaration{{ + Name: tool.Function.Name, + Description: tool.Function.Description, + Parameters: convertFunctionParameters(params), + }}, + } + result = append(result, item) + } + + return result +} + +func convertFunctionParameters(params *ai.FunctionParameters) *genai.Schema { + genaiSchema := &genai.Schema{ + Type: genai.TypeObject, + Required: params.Required, + Properties: make(map[string]*genai.Schema, len(params.Properties)), + } + + for k, v := range params.Properties { + genaiSchema.Properties[k] = convertProperty(v) + } + + return genaiSchema +} + +// convertType converts jsonschema type to gemini type +// https://datatracker.ietf.org/doc/html/draft-bhutton-json-schema-validation-00#section-6.1.1 +func convertType(t string) genai.Type { + tt, ok := typeMap[t] + if !ok { + return genai.TypeUnspecified + } + return tt +} + +var typeMap = map[string]genai.Type{ + "string": genai.TypeString, + "integer": genai.TypeInteger, + "number": genai.TypeNumber, + "boolean": genai.TypeBoolean, + "array": genai.TypeArray, + "object": genai.TypeObject, + "null": genai.TypeUnspecified, +} + +func convertProperty(prop *ai.ParameterProperty) *genai.Schema { + enums := []string{} + for _, v := range prop.Enum { + switch v := v.(type) { + case string: + enums = append(enums, v) + default: + enums = append(enums, fmt.Sprintf("%v", v)) + } + } + return &genai.Schema{ + Type: convertType(prop.Type), + Description: prop.Description, + Enum: enums, + } +} diff --git a/pkg/bridge/ai/provider/vertexai/response_convert.go b/pkg/bridge/ai/provider/vertexai/response_convert.go new file mode 100644 index 000000000..39252a0fa --- /dev/null +++ 
b/pkg/bridge/ai/provider/vertexai/response_convert.go @@ -0,0 +1,141 @@ +package vertexai + +import ( + "encoding/json" + "fmt" + "time" + + "cloud.google.com/go/vertexai/genai" + openai "github.com/sashabaranov/go-openai" + "github.com/yomorun/yomo/pkg/id" +) + +func convertToResponse(in *genai.GenerateContentResponse, model string) (out openai.ChatCompletionResponse) { + out = openai.ChatCompletionResponse{ + ID: "chatcmpl-" + id.New(29), + Model: model, + Object: "chat.completion", + Created: time.Now().Unix(), + Choices: make([]openai.ChatCompletionChoice, 0), + } + + if in.UsageMetadata != nil { + out.Usage = openai.Usage{ + PromptTokens: int(in.UsageMetadata.PromptTokenCount), + CompletionTokens: int(in.UsageMetadata.CandidatesTokenCount), + TotalTokens: int(in.UsageMetadata.TotalTokenCount), + } + } + + count := 0 + toolCalls := make([]openai.ToolCall, 0) + for _, candidate := range in.Candidates { + for _, part := range candidate.Content.Parts { + index := count + switch pp := part.(type) { + case genai.Text: + out.Choices = append(out.Choices, openai.ChatCompletionChoice{ + Index: int(index), + Message: openai.ChatCompletionMessage{ + Content: string(pp), + Role: openai.ChatMessageRoleAssistant, + }, + FinishReason: toOpenAIFinishReason(candidate.FinishReason), + }) + case genai.FunctionCall: + args, _ := json.Marshal(pp.Args) + toolCalls = append(toolCalls, openai.ToolCall{ + Index: genai.Ptr(int(index)), + ID: fmt.Sprintf("%s-%d", pp.Name, index), + Type: openai.ToolTypeFunction, + Function: openai.FunctionCall{Name: pp.Name, Arguments: string(args)}, + }) + } + count++ + } + } + + if len(toolCalls) > 0 { + out.Choices = append(out.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + ToolCalls: toolCalls, + Role: openai.ChatMessageRoleAssistant, + }, + FinishReason: openai.FinishReasonToolCalls, + }) + } + + return +} + +func convertToStreamResponse(id string, in *genai.GenerateContentResponse, model string, usage 
*openai.Usage) openai.ChatCompletionStreamResponse { + out := openai.ChatCompletionStreamResponse{ + ID: id, + Model: model, + Object: "chat.completion.chunk", + Created: time.Now().Unix(), + Choices: make([]openai.ChatCompletionStreamChoice, 0), + } + + if in.UsageMetadata != nil { + usage.PromptTokens = int(in.UsageMetadata.PromptTokenCount) + usage.CompletionTokens += int(in.UsageMetadata.CandidatesTokenCount) + } + + count := 0 + toolCalls := make([]openai.ToolCall, 0) + + for _, candidate := range in.Candidates { + parts := candidate.Content.Parts + for _, part := range parts { + index := count + switch pp := part.(type) { + case genai.Text: + out.Choices = append(out.Choices, openai.ChatCompletionStreamChoice{ + Index: index, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: string(pp), + Role: openai.ChatMessageRoleAssistant, + }, + FinishReason: toOpenAIFinishReason(candidate.FinishReason), + }) + case genai.FunctionCall: + args, _ := json.Marshal(pp.Args) + + toolCalls = append(toolCalls, openai.ToolCall{ + Index: genai.Ptr(int(index)), + ID: fmt.Sprintf("%s-%d", pp.Name, index), + Type: openai.ToolTypeFunction, + Function: openai.FunctionCall{Name: pp.Name, Arguments: string(args)}, + }) + } + count++ + } + } + if len(toolCalls) > 0 { + out.Choices = append(out.Choices, openai.ChatCompletionStreamChoice{ + Delta: openai.ChatCompletionStreamChoiceDelta{ + ToolCalls: toolCalls, + Role: openai.ChatMessageRoleAssistant, + }, + FinishReason: openai.FinishReasonToolCalls, + }) + } + + return out +} + +var mapFinishReason = map[genai.FinishReason]openai.FinishReason{ + genai.FinishReasonUnspecified: openai.FinishReasonNull, + genai.FinishReasonStop: openai.FinishReasonStop, + genai.FinishReasonMaxTokens: openai.FinishReasonLength, + genai.FinishReasonSafety: openai.FinishReasonContentFilter, +} + +func toOpenAIFinishReason(reason genai.FinishReason) openai.FinishReason { + val, ok := mapFinishReason[reason] + if ok { + return val + } + return 
openai.FinishReason(fmt.Sprintf("FinishReason(%s)", reason)) +}