Batch embeddings #18

Merged
merged 2 commits on Nov 25, 2024
113 changes: 59 additions & 54 deletions dvc.lock
@@ -1,7 +1,7 @@
 schema: '2.0'
 stages:
   fetch-metadata:
-    cmd: python scripts/fetch_eidc_metadata.py data/eidc_metadata.json -s 2000
+    cmd: python scripts/fetch_eidc_metadata.py data/eidc_metadata.json -s 0
     deps:
     - path: scripts/fetch_eidc_metadata.py
       hash: md5
@@ -10,8 +10,8 @@ stages:
     outs:
     - path: data/eidc_metadata.json
       hash: md5
-      md5: 828442e08598fb258894f9d414943330
-      size: 12247809
+      md5: fb338ea98ce71bf7f002be952b6db0e1
+      size: 12275265
   prepare:
     cmd: python scripts/extract_metadata.py data/eidc_metadata.json data/extracted_metadata.json
     deps:
@@ -33,135 +33,140 @@ stages:
     deps:
     - path: data/eidc_metadata.json
       hash: md5
-      md5: 828442e08598fb258894f9d414943330
-      size: 12247809
+      md5: fb338ea98ce71bf7f002be952b6db0e1
+      size: 12275265
     - path: scripts/extract_metadata.py
       hash: md5
       md5: e66f21369c5106eaaad4476612c6fb5e
       size: 1313
     outs:
     - path: data/extracted_metadata.json
       hash: md5
-      md5: f6123510b2b337bc8a2b6a7180e54b36
-      size: 4606527
+      md5: 9f4fc9cb1e8af8e0f2d1c95b311989fc
+      size: 4616342
   chunk-data:
-    cmd: python scripts/chunk_data.py -o data/chunked_data.json -c 250 -ol 75 data/extracted_metadata.json
-      data/supporting-docs.json -m 250
+    cmd: python scripts/chunk_data.py -o data/chunked_data.json -c 250 -ol 75 data/extracted_metadata.json data/supporting-docs.json -m
+      0
     deps:
     - path: data/extracted_metadata.json
       hash: md5
-      md5: f6123510b2b337bc8a2b6a7180e54b36
-      size: 4606527
+      md5: 9f4fc9cb1e8af8e0f2d1c95b311989fc
+      size: 4616342
     - path: data/supporting-docs.json
       hash: md5
-      md5: 9144618bb329984fcd622811a7eac3bb
-      size: 72280322
+      md5: 0b14da8f2e73dc8e15747f693c0f70ce
+      size: 72383140
     - path: scripts/chunk_data.py
       hash: md5
       md5: 3ad449140b03e1c2904b22a5b401a12e
       size: 2705
     outs:
     - path: data/chunked_data.json
       hash: md5
-      md5: 39990a40f23e70dc78424ce0bc408983
-      size: 124484286
+      md5: b107dfb052c12ea47b04a5176e8bab4a
+      size: 176342129
   create-embeddings:
     cmd: python scripts/create_embeddings.py data/chunked_data.json data/embeddings.json
+      -m all-MiniLM-L6-v2
     deps:
     - path: data/chunked_data.json
       hash: md5
-      md5: 39990a40f23e70dc78424ce0bc408983
-      size: 124484286
+      md5: b107dfb052c12ea47b04a5176e8bab4a
+      size: 176342129
     - path: scripts/create_embeddings.py
       hash: md5
-      md5: fa4627c83a65af2e3ea9b2b749f1b29d
-      size: 952
+      md5: 87bd2ed6373552bea229c9f3465fd3db
+      size: 1594
     outs:
     - path: data/embeddings.json
       hash: md5
-      md5: 8971ce1f4d4ade1507b9a469656c36d0
-      size: 1706778742
+      md5: 68a9de7fcf765be8ae2f4d3ff6537228
+      size: 3739724900
   upload-to-docstore:
-    cmd: python scripts/upload_to_docstore.py data/embeddings.json -o data/chroma-data
-      -em all-MiniLM-L6-v2 -c eidc-data
+    cmd: python scripts/upload_to_docstore.py data/embeddings.json -o data/chroma-data -em
+      all-MiniLM-L6-v2 -c eidc-data
     deps:
     - path: data/embeddings.json
       hash: md5
-      md5: 8971ce1f4d4ade1507b9a469656c36d0
-      size: 1706778742
+      md5: 68a9de7fcf765be8ae2f4d3ff6537228
+      size: 3739724900
     - path: scripts/upload_to_docstore.py
       hash: md5
-      md5: 645bdeb372bc79fa7a2e3d8a9eac0d4c
-      size: 2330
+      md5: 930456cedd43723c1d643ad90c146952
+      size: 2793
     outs:
     - path: data/chroma-data
       hash: md5
-      md5: 7d158df1ea32a09783259b756f468666.dir
-      size: 1126480472
+      md5: 486d560a81dc951bdd85772996e62f00.dir
+      size: 1815042692
       nfiles: 6
   run-rag-pipeline:
-    cmd: python scripts/run_rag_pipeline.py data/eidc_rag_test_set.csv data/evaluation_data.csv
-      data/chroma-data -c eidc-data
+    cmd: python scripts/run_rag_pipeline.py -i data/eidc_rag_testset.csv -o data/evaluation_data.csv -ds
+      data/chroma-data -c eidc-data -m llama3.1 -p data/pipeline.yml
     deps:
     - path: data/chroma-data
       hash: md5
-      md5: 7d158df1ea32a09783259b756f468666.dir
-      size: 1126480472
+      md5: 486d560a81dc951bdd85772996e62f00.dir
+      size: 1815042692
       nfiles: 6
-    - path: data/eidc_rag_test_set.csv
+    - path: data/eidc_rag_testset.csv
       hash: md5
-      md5: f301e759e74ce5e71b50e04993ec8c88
-      size: 144597
+      md5: a371d83c5822d256286e80d64d58c3fe
+      size: 7524
     - path: scripts/run_rag_pipeline.py
       hash: md5
-      md5: 0be13da9adedc1c0dad4837523893061
-      size: 3869
+      md5: 2d6dc886728d4bd46676ecd1882f1fd1
+      size: 5838
     outs:
     - path: data/evaluation_data.csv
       hash: md5
-      md5: 61fc8879585c0385277ebdc8a6b82420
-      size: 203253
+      md5: a473732be874c8256f7178ef3f4dc7a9
+      size: 9576
+    - path: data/pipeline.yml
+      hash: md5
+      md5: 8e3c4e49d4d97f613e83468d010a96e9
+      size: 3440
   generate-testset:
-    cmd: cp data/synthetic-datasets/eidc_rag_test_set.csv data/
+    cmd: head -n 101 data/synthetic-datasets/eidc_rag_test_sample.csv > data/eidc_rag_testset.csv
     outs:
-    - path: data/eidc_rag_test_set.csv
+    - path: data/eidc_rag_testset.csv
       hash: md5
-      md5: f301e759e74ce5e71b50e04993ec8c88
-      size: 144597
+      md5: a371d83c5822d256286e80d64d58c3fe
+      size: 7524
   fetch-supporting-docs:
     cmd: python scripts/fetch_supporting_docs.py data/eidc_metadata.json data/supporting-docs.json
     deps:
     - path: data/eidc_metadata.json
       hash: md5
-      md5: 828442e08598fb258894f9d414943330
-      size: 12247809
+      md5: fb338ea98ce71bf7f002be952b6db0e1
+      size: 12275265
     - path: scripts/fetch_supporting_docs.py
       hash: md5
       md5: 02b94a2cc7bff711784cbdec3650b618
       size: 1718
     outs:
     - path: data/supporting-docs.json
       hash: md5
-      md5: 9144618bb329984fcd622811a7eac3bb
-      size: 72280322
+      md5: 0b14da8f2e73dc8e15747f693c0f70ce
+      size: 72383140
   evaluate:
     cmd: python scripts/evaluate.py data/evaluation_data.csv -m data/metrics.json
       -img data/eval.png
     deps:
     - path: data/evaluation_data.csv
       hash: md5
-      md5: 61fc8879585c0385277ebdc8a6b82420
-      size: 203253
+      md5: a473732be874c8256f7178ef3f4dc7a9
+      size: 9576
     - path: scripts/evaluate.py
       hash: md5
       md5: 4154acf8e74c1d8bcd0b0da72af038e0
       size: 2728
     outs:
     - path: data/eval.png
       hash: md5
-      md5: 2a1630782c103077959097db4e06b7d8
-      size: 83362
+      md5: 7bfd424fa4c9a3550d6e9605bb2f6af2
+      size: 89143
     - path: data/metrics.json
       hash: md5
-      md5: dfdf0d0bf1519ccfa78f95263d63c231
-      size: 285
+      md5: f768092fe2696328ff4da565e763e743
+      size: 270
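
The dvc.lock changes above are machine-written: once the stage commands and scripts changed, re-running the pipeline recomputed each stage's dependency and output hashes. A minimal sketch of the standard DVC commands that would regenerate these entries (assuming an already-configured DVC repo; not shown in this PR):

dvc repro    # re-run stages whose dependencies changed and rewrite dvc.lock
dvc status   # confirm the workspace matches the updated lock file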
2 changes: 1 addition & 1 deletion dvc.yaml
@@ -37,7 +37,7 @@ stages:
     outs:
       - ${files.chunked}
   create-embeddings:
-    cmd: python scripts/create_embeddings.py ${files.chunked} ${files.embeddings}
+    cmd: python scripts/create_embeddings.py ${files.chunked} ${files.embeddings} -m ${hp.embeddings-model}
     deps:
       - ${files.chunked}
       - scripts/create_embeddings.py
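
The new -m flag is fed from a ${hp.embeddings-model} template variable. A hedged sketch of the entry DVC would resolve this from (hypothetical — the params/vars file itself is not part of this diff):

# Assumed params entry backing ${hp.embeddings-model}; not in this PR.
hp:
  embeddings-model: all-MiniLM-L6-v2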
31 changes: 23 additions & 8 deletions scripts/create_embeddings.py
@@ -1,23 +1,35 @@
+import gc
 import json
 from argparse import ArgumentParser
+from itertools import islice

+import torch
 from sentence_transformers import SentenceTransformer
-from torch import Tensor
 from tqdm import tqdm


-def create_embedding(text: str) -> Tensor:
-    model = SentenceTransformer("all-MiniLM-L6-v2")
-    return model.encode(text)
+def batched(iterable, n, *, strict=False):
+    if n < 1:
+        raise ValueError('n must be at least one')
+    iterator = iter(iterable)
+    while batch := tuple(islice(iterator, n)):
+        if strict and len(batch) != n:
+            raise ValueError('batched(): incomplete batch')
+        yield batch


-def main(input_file: str, output_file: str) -> None:
+def main(input_file: str, output_file: str, model_name: str) -> None:
+    model = SentenceTransformer(model_name)
     with open(input_file) as input, open(output_file, "w") as output:
         data = json.load(input)
-        for chunk in tqdm(data):
-            chunk["embedding"] = create_embedding(chunk["chunk"]).tolist()
+        batches = list(batched(data, 500))
+        position = 0
+        for batch in tqdm(batches):
+            texts = [chunk["chunk"] for chunk in batch]
+            embeddings = model.encode(texts)
+            for embedding in embeddings:
+                data[position]["embedding"] = embedding.tolist()
+                position += 1
+            gc.collect()
+            torch.cuda.empty_cache()
         json.dump(data, output)
@@ -27,5 +39,8 @@ def main(input_file: str, output_file: str) -> None:
     parser = ArgumentParser("prepare_data.py")
     parser.add_argument("input", help="The file to be used as input.")
     parser.add_argument("output", help="The path to save the processed result.")
+    parser.add_argument(
+        "-m", "--model", help="Embedding model to use.", default="all-MiniLM-L6-v2"
+    )
     args = parser.parse_args()
-    main(args.input, args.output)
+    main(args.input, args.output, args.model)
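
For reviewers who want to sanity-check the batching behaviour in isolation, here is a minimal Python sketch (not part of the diff; it assumes sentence-transformers is installed and downloads the all-MiniLM-L6-v2 weights on first use). It shows the core idea of the PR: model.encode() accepts a list of texts and returns one embedding per text, so each forward pass now covers a whole batch of chunks instead of a single one.

from itertools import islice

from sentence_transformers import SentenceTransformer


def batched(iterable, n):
    # Same grouping as the helper added in this PR; Python 3.12+
    # ships an equivalent as itertools.batched.
    iterator = iter(iterable)
    while batch := tuple(islice(iterator, n)):
        yield batch


model = SentenceTransformer("all-MiniLM-L6-v2")
chunks = [{"chunk": f"example text {i}"} for i in range(1200)]

for batch in batched(chunks, 500):
    texts = [c["chunk"] for c in batch]
    embeddings = model.encode(texts)
    # Batches of 500, 500 and 200 chunks yield embedding arrays of
    # shape (500, 384), (500, 384) and (200, 384) for this model.
    print(len(batch), embeddings.shape)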