diff --git a/embeddings/ollama/ollama.go b/embeddings/ollama/ollama.go index 933c38c55..98eabe474 100644 --- a/embeddings/ollama/ollama.go +++ b/embeddings/ollama/ollama.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/tmc/langchaingo/embeddings" + "github.com/tmc/langchaingo/embeddings/internal/embedderclient" "github.com/tmc/langchaingo/llms/ollama" ) @@ -30,32 +31,8 @@ func NewOllama(opts ...Option) (Ollama, error) { // EmbedDocuments creates one vector embedding for each of the texts. func (e Ollama) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) { - batchedTexts := embeddings.BatchTexts( - embeddings.MaybeRemoveNewLines(texts, e.StripNewLines), - e.BatchSize, - ) - - emb := make([][]float32, 0, len(texts)) - for _, texts := range batchedTexts { - curTextEmbeddings, err := e.client.CreateEmbedding(ctx, texts) - if err != nil { - return nil, err - } - - textLengths := make([]int, 0, len(texts)) - for _, text := range texts { - textLengths = append(textLengths, len(text)) - } - - combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths) - if err != nil { - return nil, err - } - - emb = append(emb, combined) - } - - return emb, nil + texts = embeddings.MaybeRemoveNewLines(texts, e.StripNewLines) + return embedderclient.BatchedEmbed(ctx, e.client, texts, e.BatchSize) } // EmbedQuery embeds a single text. diff --git a/embeddings/ollama/ollamachat/ollama_chat.go b/embeddings/ollama/ollamachat/ollama_chat.go index db3e5b527..03cc7f4b4 100644 --- a/embeddings/ollama/ollamachat/ollama_chat.go +++ b/embeddings/ollama/ollamachat/ollama_chat.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/tmc/langchaingo/embeddings" + "github.com/tmc/langchaingo/embeddings/internal/embedderclient" "github.com/tmc/langchaingo/llms/ollama" ) @@ -29,32 +30,8 @@ func NewChatOllama(opts ...ChatOption) (ChatOllama, error) { } func (e ChatOllama) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) { - batchedTexts := embeddings.BatchTexts( - embeddings.MaybeRemoveNewLines(texts, e.StripNewLines), - e.BatchSize, - ) - - emb := make([][]float32, 0, len(texts)) - for _, texts := range batchedTexts { - curTextEmbeddings, err := e.client.CreateEmbedding(ctx, texts) - if err != nil { - return nil, err - } - - textLengths := make([]int, 0, len(texts)) - for _, text := range texts { - textLengths = append(textLengths, len(text)) - } - - combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths) - if err != nil { - return nil, err - } - - emb = append(emb, combined) - } - - return emb, nil + texts = embeddings.MaybeRemoveNewLines(texts, e.StripNewLines) + return embedderclient.BatchedEmbed(ctx, e.client, texts, e.BatchSize) } func (e ChatOllama) EmbedQuery(ctx context.Context, text string) ([]float32, error) {