From a1d2aa80ae8dee71ebd5e48988a6d29dbf6ab16a Mon Sep 17 00:00:00 2001 From: mpc Date: Fri, 20 Dec 2024 10:12:34 +0000 Subject: [PATCH] adds unified embeddings using the dataset title to text chunks --- scripts/chunk_data.py | 1 + scripts/create_embeddings.py | 18 +++++++++++++++--- scripts/extract_metadata.py | 1 + scripts/fetch_supporting_docs.py | 15 ++++++++------- 4 files changed, 25 insertions(+), 10 deletions(-) diff --git a/scripts/chunk_data.py b/scripts/chunk_data.py index d2e70d6..5346f54 100644 --- a/scripts/chunk_data.py +++ b/scripts/chunk_data.py @@ -25,6 +25,7 @@ def chunk_metadata_value( "field": metada_value["field"], "id": metada_value["id"], "index": i, + "dataset": metada_value["dataset"], } for i in range(len(chunks)) ] diff --git a/scripts/create_embeddings.py b/scripts/create_embeddings.py index 9df7d3e..0c10a64 100644 --- a/scripts/create_embeddings.py +++ b/scripts/create_embeddings.py @@ -18,14 +18,25 @@ def batched(iterable, n, *, strict=False): yield batch -def main(input_file: str, output_file: str, model_name: str) -> None: +def create_unified_text_to_embed(batch: list) -> list: + return [f"Metadata: Dataset: {chunk['dataset']}\nText: {chunk['chunk']}" for chunk in batch] + + +def create_texts_to_embed(use_unified_embeddings: bool, batch: list) -> list: + if use_unified_embeddings: + return create_unified_text_to_embed(batch) + else: + return [chunk["chunk"] for chunk in batch] + + +def main(input_file: str, output_file: str, model_name: str, use_unified_embeddings: bool) -> None: model = SentenceTransformer(model_name) with open(input_file) as input, open(output_file, "w") as output: data = json.load(input) batches = list(batched(data, 500)) position = 0 for batch in tqdm(batches): - texts = [chunk["chunk"] for chunk in batch] + texts = create_texts_to_embed(use_unified_embeddings, batch) embeddings = model.encode(texts) for embedding in embeddings: data[position]["embedding"] = embedding.tolist() @@ -42,5 +53,6 @@ def main(input_file: str, output_file: str, model_name: str) -> None: parser.add_argument( "-m", "--model", help="Embedding model to use.", default="all-MiniLM-L6-v2" ) + parser.add_argument("-u", "--unified-embeddings", help="Use unified embeddings.", action="store_true") args = parser.parse_args() - main(args.input, args.output, args.model) + main(args.input, args.output, args.model, args.unified_embeddings) diff --git a/scripts/extract_metadata.py b/scripts/extract_metadata.py index 9bd4c3c..7f374b5 100644 --- a/scripts/extract_metadata.py +++ b/scripts/extract_metadata.py @@ -13,6 +13,7 @@ def extact_eidc_metadata_fields( if json_data[field]: metadata = {} metadata["id"] = json_data["identifier"] + metadata["dataset"] = json_data["title"] metadata["field"] = field metadata["value"] = json_data[field] metadatas.append(metadata) diff --git a/scripts/fetch_supporting_docs.py b/scripts/fetch_supporting_docs.py index d95493b..5be745c 100644 --- a/scripts/fetch_supporting_docs.py +++ b/scripts/fetch_supporting_docs.py @@ -11,14 +11,15 @@ logger = logging.getLogger(__name__) -def extract_ids(metadata_file: str) -> List[str]: +def extract_ids_and_titles(metadata_file: str) -> List[str]: with open(metadata_file) as f: json_data = json.load(f) + titles = [dataset["title"] for dataset in json_data["results"]] ids = [dataset["identifier"] for dataset in json_data["results"]] - return ids + return list(zip(titles, ids)) -def get_supporting_docs(eidc_id: str, user: str, password: str) -> List[Dict[str, str]]: +def get_supporting_docs(datset_title: str, eidc_id: str, user: str, password: str) -> List[Dict[str, str]]: try: res = requests.get( f"https://legilo.eds-infra.ceh.ac.uk/{eidc_id}/documents", @@ -27,7 +28,7 @@ def get_supporting_docs(eidc_id: str, user: str, password: str) -> List[Dict[str json_data = res.json() docs = [] for key, val in json_data["success"].items(): - docs.append({"id": eidc_id, "field": key, "value": val}) + docs.append({"dataset": datset_title, "id": eidc_id, "field": key, "value": val}) return docs except Exception as e: logger.error( @@ -40,10 +41,10 @@ def main(metadata_file: str, supporting_docs_file: str) -> None: load_dotenv() user = os.getenv("username") password = os.getenv("password") - ids = extract_ids(metadata_file) + ids_and_titles = extract_ids_and_titles(metadata_file) docs = [] - for id in tqdm(ids): - docs.extend(get_supporting_docs(id, user, password)) + for id_title in tqdm(ids_and_titles): + docs.extend(get_supporting_docs(id_title[0], id_title[1], user, password)) with open(supporting_docs_file, "w") as f: json.dump(docs, f, indent=4)