Skip to content

Commit

Permalink
Batched embeddings creation
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewcoole committed Nov 25, 2024
1 parent 9ec131f commit 248b078
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion dvc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ stages:
outs:
- ${files.chunked}
create-embeddings:
cmd: python scripts/create_embeddings.py ${files.chunked} ${files.embeddings}
cmd: python scripts/create_embeddings.py ${files.chunked} ${files.embeddings} -m ${hp.embeddings-model}
deps:
- ${files.chunked}
- scripts/create_embeddings.py
Expand Down
25 changes: 15 additions & 10 deletions scripts/create_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
import gc
import json
from argparse import ArgumentParser
from functools import lru_cache
from itertools import batched

import torch
from sentence_transformers import SentenceTransformer
from torch import Tensor
from tqdm import tqdm


@lru_cache(maxsize=4)
def _load_model(model_name: str) -> SentenceTransformer:
    """Load and cache a SentenceTransformer by name so repeat calls reuse it."""
    return SentenceTransformer(model_name)


def create_embedding(text: str, model_name: str = "all-MiniLM-L6-v2") -> Tensor:
    """Return the embedding of *text* produced by the named sentence-transformers model.

    The original re-instantiated the model on every call, which dominates
    runtime; the model is now loaded once per name via an lru_cache. The
    default model preserves the original hard-coded behavior.
    """
    return _load_model(model_name).encode(text)


def main(input_file: str, output_file: str, model_name: str) -> None:
    """Embed every chunk in *input_file* and write the enriched JSON to *output_file*.

    The input is expected to be a JSON array of objects, each with a "chunk"
    text field; an "embedding" field (list of floats) is added to each object
    in place. Texts are encoded in batches of 500 to bound peak memory.

    :param input_file: path to the chunked JSON produced upstream.
    :param output_file: path the JSON-with-embeddings is written to.
    :param model_name: sentence-transformers model identifier to load.
    """
    model = SentenceTransformer(model_name)
    batch_size = 500
    # "src"/"dst" instead of "input": the original shadowed the builtin.
    with open(input_file) as src, open(output_file, "w") as dst:
        data = json.load(src)
        # Iterate batched() lazily; pass total so tqdm still shows progress
        # without materializing every batch up front as list() did.
        total = (len(data) + batch_size - 1) // batch_size
        for batch in tqdm(batched(data, batch_size), total=total):
            texts = [chunk["chunk"] for chunk in batch]
            embeddings = model.encode(texts)
            # batched() yields tuples of references into `data`, so mutating
            # each chunk updates `data` directly — no index bookkeeping needed.
            for chunk, embedding in zip(batch, embeddings):
                chunk["embedding"] = embedding.tolist()
            # Release per-batch tensors promptly to keep GPU memory flat.
            gc.collect()
            torch.cuda.empty_cache()
        json.dump(data, dst)
Expand All @@ -27,5 +29,8 @@ def main(input_file: str, output_file: str) -> None:
# CLI entry point: create_embeddings.py INPUT OUTPUT [-m MODEL]
# Program name fixed — the original said "prepare_data.py", a copy-paste slip
# from a sibling script, which produced a misleading usage message.
parser = ArgumentParser("create_embeddings.py")
parser.add_argument("input", help="The file to be used as input.")
parser.add_argument("output", help="The path to save the processed result.")
parser.add_argument(
    "-m", "--model", help="Embedding model to use.", default="all-MiniLM-L6-v2"
)
args = parser.parse_args()
main(args.input, args.output, args.model)

1 comment on commit 248b078

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

answer_correctness: 0.48633080434696374
answer_relevancy: 0.4942387151164841
context_precision: 0.524277123560025
context_recall: 0.5129272439322693

Please sign in to comment.