Skip to content

Commit

Permalink
embeddings: refactor EmbedderClient interface to reduce code duplication
Browse files Browse the repository at this point in the history
Starting with the vertex embedder, but this is applicable to others too
  • Loading branch information
eliben committed Nov 19, 2023
1 parent 65725eb commit 01f5f1f
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 52 deletions.
45 changes: 45 additions & 0 deletions embeddings/internal/embedderclient/embedderclient.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package embedderclient

import (
"context"

"github.com/tmc/langchaingo/embeddings"
)

// EmbedderClient is the interface LLM clients implement for embeddings.
type EmbedderClient interface {
CreateEmbedding(ctx context.Context, texts []string) ([][]float32, error)
}

// BatchedEmbed creates embeddings for the given input texts, batching them
// into batches of batchSize if needed.
func BatchedEmbed(ctx context.Context, embedder EmbedderClient, texts []string, batchSize int) ([][]float32, error) {
batchedTexts := embeddings.BatchTexts(texts, batchSize)

emb := make([][]float32, 0, len(texts))
for _, texts := range batchedTexts {
curTextEmbeddings, err := embedder.CreateEmbedding(ctx, texts)
if err != nil {
return nil, err
}
// If the size of this batch is 1, don't average/combine the vectors.
if len(texts) == 1 {
emb = append(emb, curTextEmbeddings[0])
continue
}

textLengths := make([]int, 0, len(texts))
for _, text := range texts {
textLengths = append(textLengths, len(text))
}

combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths)
if err != nil {
return nil, err
}

emb = append(emb, combined)
}

return emb, nil
}
29 changes: 3 additions & 26 deletions embeddings/vertexai/vertexai_palm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/embeddings/internal/embedderclient"
"github.com/tmc/langchaingo/llms/vertexai"
)

Expand All @@ -30,32 +31,8 @@ func NewVertexAIPaLM(opts ...Option) (*VertexAIPaLM, error) {

// EmbedDocuments creates one vector embedding for each of the texts.
func (e VertexAIPaLM) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) {
batchedTexts := embeddings.BatchTexts(
embeddings.MaybeRemoveNewLines(texts, e.StripNewLines),
e.BatchSize,
)

emb := make([][]float32, 0, len(texts))
for _, texts := range batchedTexts {
curTextEmbeddings, err := e.client.CreateEmbedding(ctx, texts)
if err != nil {
return nil, err
}

textLengths := make([]int, 0, len(texts))
for _, text := range texts {
textLengths = append(textLengths, len(text))
}

combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths)
if err != nil {
return nil, err
}

emb = append(emb, combined)
}

return emb, nil
texts = embeddings.MaybeRemoveNewLines(texts, e.StripNewLines)
return embedderclient.BatchedEmbed(ctx, e.client, texts, e.BatchSize)
}

// EmbedQuery embeds a single text.
Expand Down
29 changes: 3 additions & 26 deletions embeddings/vertexai/vertexaichat/vertexai_chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/embeddings/internal/embedderclient"
"github.com/tmc/langchaingo/llms/vertexai"
)

Expand All @@ -29,32 +30,8 @@ func NewChatVertexAI(opts ...ChatOption) (ChatVertexAI, error) {
}

func (e ChatVertexAI) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) {
batchedTexts := embeddings.BatchTexts(
embeddings.MaybeRemoveNewLines(texts, e.StripNewLines),
e.BatchSize,
)

emb := make([][]float32, 0, len(texts))
for _, texts := range batchedTexts {
curTextEmbeddings, err := e.client.CreateEmbedding(ctx, texts)
if err != nil {
return nil, err
}

textLengths := make([]int, 0, len(texts))
for _, text := range texts {
textLengths = append(textLengths, len(text))
}

combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths)
if err != nil {
return nil, err
}

emb = append(emb, combined)
}

return emb, nil
texts = embeddings.MaybeRemoveNewLines(texts, e.StripNewLines)
return embedderclient.BatchedEmbed(ctx, e.client, texts, e.BatchSize)
}

func (e ChatVertexAI) EmbedQuery(ctx context.Context, text string) ([]float32, error) {
Expand Down

0 comments on commit 01f5f1f

Please sign in to comment.