diff --git a/dvc.yaml b/dvc.yaml index 4dbf7e1..cfe7a72 100644 --- a/dvc.yaml +++ b/dvc.yaml @@ -37,7 +37,7 @@ stages: outs: - ${files.chunked} create-embeddings: - cmd: python scripts/create_embeddings.py ${files.chunked} ${files.embeddings} + cmd: python scripts/create_embeddings.py ${files.chunked} ${files.embeddings} -m ${hp.embeddings-model} deps: - ${files.chunked} - scripts/create_embeddings.py diff --git a/scripts/create_embeddings.py b/scripts/create_embeddings.py index 1ae255e..945c79b 100644 --- a/scripts/create_embeddings.py +++ b/scripts/create_embeddings.py @@ -1,23 +1,25 @@ import gc import json from argparse import ArgumentParser +from itertools import batched import torch from sentence_transformers import SentenceTransformer -from torch import Tensor from tqdm import tqdm -def create_embedding(text: str) -> Tensor: - model = SentenceTransformer("all-MiniLM-L6-v2") - return model.encode(text) - - -def main(input_file: str, output_file: str) -> None: +def main(input_file: str, output_file: str, model_name: str) -> None: + model = SentenceTransformer(model_name) with open(input_file) as input, open(output_file, "w") as output: data = json.load(input) - for chunk in tqdm(data): - chunk["embedding"] = create_embedding(chunk["chunk"]).tolist() + batches = list(batched(data, 500)) + position = 0 + for batch in tqdm(batches): + texts = [chunk["chunk"] for chunk in batch] + embeddings = model.encode(texts) + for embedding in embeddings: + data[position]["embedding"] = embedding.tolist() + position += 1 gc.collect() torch.cuda.empty_cache() json.dump(data, output) @@ -27,5 +29,8 @@ def main(input_file: str, output_file: str) -> None: parser = ArgumentParser("prepare_data.py") parser.add_argument("input", help="The file to be used as input.") parser.add_argument("output", help="The path to save the processed result.") + parser.add_argument( + "-m", "--model", help="Embedding model to use.", default="all-MiniLM-L6-v2" + ) args = parser.parse_args() - main(args.input, args.output) + main(args.input, args.output, args.model)