Skip to content

Commit

Permalink
add: custom embeddings implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
iwilltry42 committed Jan 10, 2025
1 parent 69fd5cf commit c8d48dc
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 3 deletions.
6 changes: 3 additions & 3 deletions gemini-vertex-model-provider/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func main() {
if err := validate(ctx); err != nil {
fmt.Printf("{\"error\": \"%s\"}\n", err)
}
os.Exit(0)
os.Exit(1)
}

c, err := configure(ctx)
Expand Down Expand Up @@ -56,7 +56,7 @@ func configure(ctx context.Context) (*genai.Client, error) {
return nil, fmt.Errorf("failed to parse google application credentials json: %w", err)
}

gcreds, err := google.CredentialsFromJSON(ctx, []byte(credsJSON))
gcreds, err := google.CredentialsFromJSON(ctx, []byte(credsJSON), "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to parse google credentials JSON: %w", err)
}
Expand All @@ -77,7 +77,7 @@ func configure(ctx context.Context) (*genai.Client, error) {
if l, ok := creds["location"]; ok {
loc = l.(string)
} else {
pid = os.Getenv("OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CLOUD_LOCATION")
loc = os.Getenv("OBOT_GEMINI_VERTEX_MODEL_PROVIDER_GOOGLE_CLOUD_LOCATION")
}
if loc == "" {
return nil, fmt.Errorf("google cloud location is required")
Expand Down
119 changes: 119 additions & 0 deletions gemini-vertex-model-provider/server/server.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package server

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"

Expand Down Expand Up @@ -531,6 +533,123 @@ func mapFunctionDefinitionFromOpenAI(funcDef *openai.FunctionDefinition) ([]*gen
return functions, nil
}

// openAIEmbeddingRequest mirrors the OpenAI embeddings request body —
// not (yet) provided by the Chat Completion Client package.
// Only the single-string form of "input" is supported here.
type openAIEmbeddingRequest struct {
Input string `json:"input"` // text to embed
Model string `json:"model"` // model ID; interpolated into the Vertex AI predict URL
EncodingFormat string `json:"encoding_format,omitempty"` // accepted for API compatibility; not referenced by the handler in this file
Dimensions *int `json:"dimensions,omitempty"` // optional; forwarded to Vertex as "outputDimensionality"
}

// openAIResponse is the OpenAI-shaped embeddings response returned to the caller.
type openAIResponse struct {
Data []openAIResponseData `json:"data"` // one entry per embedding (this handler emits exactly one)
}

// openAIResponseData is a single embedding entry in the OpenAI response shape.
type openAIResponseData struct {
Embedding []float32 `json:"embedding"` // the embedding vector
}

// vertexEmbeddingResponse is the subset of the Vertex AI :predict response
// body this handler decodes.
type vertexEmbeddingResponse struct {
Predictions []vertexPrediction `json:"predictions"` // one prediction per submitted instance
}

// vertexPrediction is a single prediction object within the Vertex response.
type vertexPrediction struct {
Embeddings vertexEmbeddings `json:"embeddings"` // embedding payload for this instance
}

// vertexEmbeddings holds the embedding vector of a Vertex prediction.
type vertexEmbeddings struct {
Values []float32 `json:"values"` // the embedding values
// leaving out what we don't need just yet (e.g. statistics metadata)
}

// embeddings proxies an OpenAI-style embeddings request to the Vertex AI
// :predict endpoint — embeddings are not (yet) provided by the Google GenAI
// package. It decodes the incoming request, calls the publisher model named
// in the request, and rewrites the first returned prediction into the OpenAI
// embeddings response shape. All upstream failures are reported to the
// caller as HTTP 500 with a descriptive message.
func (s *server) embeddings(w http.ResponseWriter, r *http.Request) {
	var er openAIEmbeddingRequest
	if err := json.NewDecoder(r.Body).Decode(&er); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	cfg := s.client.ClientConfig()
	url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", cfg.Location, cfg.Project, cfg.Location, er.Model)

	// Vertex AI predict takes "instances" plus an optional top-level
	// "parameters" object. Fixed here: the original sent the misspelled key
	// "tast_type" (silently ignored by Vertex) and nested an empty
	// "parameters" map inside the instance, where it has no effect.
	payload := map[string]any{
		"instances": []map[string]any{
			{
				"task_type": "QUESTION_ANSWERING",
				"content":   er.Input,
			},
		},
	}

	if er.Dimensions != nil {
		payload["parameters"] = map[string]any{
			"outputDimensionality": *er.Dimensions,
		}
	}

	reqBody, err := json.Marshal(payload)
	if err != nil {
		http.Error(w, fmt.Sprintf("couldn't marshal request body: %v", err), http.StatusInternalServerError)
		return
	}

	req, err := http.NewRequestWithContext(r.Context(), http.MethodPost, url, bytes.NewReader(reqBody))
	if err != nil {
		http.Error(w, fmt.Sprintf("couldn't create request: %v", err), http.StatusInternalServerError)
		return
	}

	req.Header.Set("Accept", "application/json")
	req.Header.Set("Content-Type", "application/json")

	// NOTE(review): no Authorization header is set here — this relies on
	// cfg.HTTPClient carrying Google credentials; confirm against client setup.
	resp, err := cfg.HTTPClient.Do(req)
	if err != nil {
		http.Error(w, fmt.Sprintf("couldn't make request: %v", err), http.StatusInternalServerError)
		return
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		http.Error(w, fmt.Sprintf("couldn't read response body: %v", err), http.StatusInternalServerError)
		return
	}

	if resp.StatusCode != http.StatusOK {
		// Surface the upstream body so callers can see why Vertex rejected the call.
		http.Error(w, fmt.Sprintf("unexpected status code: %d: %s", resp.StatusCode, string(body)), http.StatusInternalServerError)
		return
	}

	var embeddingResponse vertexEmbeddingResponse
	if err := json.Unmarshal(body, &embeddingResponse); err != nil {
		http.Error(w, fmt.Sprintf("couldn't unmarshal response body: %v", err), http.StatusInternalServerError)
		return
	}

	if len(embeddingResponse.Predictions) == 0 || len(embeddingResponse.Predictions[0].Embeddings.Values) == 0 {
		http.Error(w, "no embeddings found in the response", http.StatusInternalServerError)
		return
	}

	if len(embeddingResponse.Predictions) > 1 {
		fmt.Println("Info: multiple predictions found in the response - using only the first one")
	}

	oaiResp := openAIResponse{
		Data: []openAIResponseData{
			{
				Embedding: embeddingResponse.Predictions[0].Embeddings.Values,
			},
		},
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(oaiResp); err != nil {
		http.Error(w, fmt.Sprintf("couldn't encode response: %v", err), http.StatusInternalServerError)
	}
}

0 comments on commit c8d48dc

Please sign in to comment.