diff --git a/embeddings/internal/embedderclient/embedderclient.go b/embeddings/internal/embedderclient/embedderclient.go new file mode 100644 index 000000000..4db5c34fa --- /dev/null +++ b/embeddings/internal/embedderclient/embedderclient.go @@ -0,0 +1,45 @@ +package embedderclient + +import ( + "context" + + "github.com/tmc/langchaingo/embeddings" +) + +// EmbedderClient is the interface LLM clients implement for embeddings. +type EmbedderClient interface { + CreateEmbedding(ctx context.Context, texts []string) ([][]float32, error) +} + +// BatchedEmbed creates embeddings for the given input texts, batching them +// into batches of batchSize if needed. +func BatchedEmbed(ctx context.Context, embedder EmbedderClient, texts []string, batchSize int) ([][]float32, error) { + batchedTexts := embeddings.BatchTexts(texts, batchSize) + + emb := make([][]float32, 0, len(texts)) + for _, texts := range batchedTexts { + curTextEmbeddings, err := embedder.CreateEmbedding(ctx, texts) + if err != nil { + return nil, err + } + // If the size of this batch is 1, don't average/combine the vectors. + if len(texts) == 1 { + emb = append(emb, curTextEmbeddings[0]) + continue + } + + textLengths := make([]int, 0, len(texts)) + for _, text := range texts { + textLengths = append(textLengths, len(text)) + } + + combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths) + if err != nil { + return nil, err + } + + emb = append(emb, combined) + } + + return emb, nil +} diff --git a/embeddings/vertexai/vertexai_palm.go b/embeddings/vertexai/vertexai_palm.go index 126b58183..ff3d1b2da 100644 --- a/embeddings/vertexai/vertexai_palm.go +++ b/embeddings/vertexai/vertexai_palm.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/tmc/langchaingo/embeddings" + "github.com/tmc/langchaingo/embeddings/internal/embedderclient" "github.com/tmc/langchaingo/llms/vertexai" ) @@ -30,32 +31,8 @@ func NewVertexAIPaLM(opts ...Option) (*VertexAIPaLM, error) { // EmbedDocuments creates one vector embedding for each of the texts. func (e VertexAIPaLM) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) { - batchedTexts := embeddings.BatchTexts( - embeddings.MaybeRemoveNewLines(texts, e.StripNewLines), - e.BatchSize, - ) - - emb := make([][]float32, 0, len(texts)) - for _, texts := range batchedTexts { - curTextEmbeddings, err := e.client.CreateEmbedding(ctx, texts) - if err != nil { - return nil, err - } - - textLengths := make([]int, 0, len(texts)) - for _, text := range texts { - textLengths = append(textLengths, len(text)) - } - - combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths) - if err != nil { - return nil, err - } - - emb = append(emb, combined) - } - - return emb, nil + texts = embeddings.MaybeRemoveNewLines(texts, e.StripNewLines) + return embedderclient.BatchedEmbed(ctx, e.client, texts, e.BatchSize) } // EmbedQuery embeds a single text. diff --git a/embeddings/vertexai/vertexaichat/vertexai_chat.go b/embeddings/vertexai/vertexaichat/vertexai_chat.go index 8eef1ed65..bc4c5575e 100644 --- a/embeddings/vertexai/vertexaichat/vertexai_chat.go +++ b/embeddings/vertexai/vertexaichat/vertexai_chat.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/tmc/langchaingo/embeddings" + "github.com/tmc/langchaingo/embeddings/internal/embedderclient" "github.com/tmc/langchaingo/llms/vertexai" ) @@ -29,32 +30,8 @@ func NewChatVertexAI(opts ...ChatOption) (ChatVertexAI, error) { } func (e ChatVertexAI) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) { - batchedTexts := embeddings.BatchTexts( - embeddings.MaybeRemoveNewLines(texts, e.StripNewLines), - e.BatchSize, - ) - - emb := make([][]float32, 0, len(texts)) - for _, texts := range batchedTexts { - curTextEmbeddings, err := e.client.CreateEmbedding(ctx, texts) - if err != nil { - return nil, err - } - - textLengths := make([]int, 0, len(texts)) - for _, text := range texts { - textLengths = append(textLengths, len(text)) - } - - combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths) - if err != nil { - return nil, err - } - - emb = append(emb, combined) - } - - return emb, nil + texts = embeddings.MaybeRemoveNewLines(texts, e.StripNewLines) + return embedderclient.BatchedEmbed(ctx, e.client, texts, e.BatchSize) } func (e ChatVertexAI) EmbedQuery(ctx context.Context, text string) ([]float32, error) {