Commit

Merge pull request #26 from NERC-CEH/unified-embeddings

Unified embeddings

matthewcoole authored Dec 20, 2024
2 parents 40f4625 + dcd49c1 commit a4afd7f
Showing 5 changed files with 26 additions and 11 deletions.
dvc.yaml: 1 addition, 1 deletion

@@ -37,7 +37,7 @@ stages:
     outs:
       - ${files.chunked}
   create-embeddings:
-    cmd: uv run scripts/create_embeddings.py ${files.chunked} ${files.embeddings} -m ${hp.embeddings-model}
+    cmd: uv run scripts/create_embeddings.py ${files.chunked} ${files.embeddings} -m ${hp.embeddings-model} -u
     deps:
       - ${files.chunked}
       - scripts/create_embeddings.py
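The only pipeline change is the new -u flag on the create-embeddings stage, which switches scripts/create_embeddings.py into unified-embeddings mode (see the script changes below). With the default model and illustrative paths (the real values are resolved from the DVC params file, so these are assumptions), the stage would run something like:

uv run scripts/create_embeddings.py data/chunked.json data/embeddings.json -m all-MiniLM-L6-v2 -u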
scripts/chunk_data.py: 1 addition, 0 deletions

@@ -25,6 +25,7 @@ def chunk_metadata_value(
             "field": metada_value["field"],
             "id": metada_value["id"],
             "index": i,
+            "dataset": metada_value["dataset"],
         }
         for i in range(len(chunks))
     ]
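Each chunk record now carries the title of the dataset it was cut from. A representative record after this change (all field values here are illustrative, not taken from the repository):

{
    "chunk": "Rainfall was recorded at 15-minute intervals...",
    "field": "description",
    "id": "example-eidc-identifier",
    "index": 0,
    "dataset": "Example dataset title"
}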
scripts/create_embeddings.py: 15 additions, 3 deletions

@@ -18,14 +18,25 @@ def batched(iterable, n, *, strict=False):
         yield batch


-def main(input_file: str, output_file: str, model_name: str) -> None:
+def create_unified_text_to_embed(batch: list) -> list:
+    return [f"Metadata: Dataset: {chunk['dataset']}\nText: {chunk['chunk']}" for chunk in batch]
+
+
+def create_texts_to_embed(use_unified_embeddings: bool, batch: list) -> list:
+    if use_unified_embeddings:
+        return create_unified_text_to_embed(batch)
+    else:
+        return [chunk["chunk"] for chunk in batch]
+
+
+def main(input_file: str, output_file: str, model_name: str, use_unified_embeddings: bool) -> None:
     model = SentenceTransformer(model_name)
     with open(input_file) as input, open(output_file, "w") as output:
         data = json.load(input)
         batches = list(batched(data, 500))
         position = 0
         for batch in tqdm(batches):
-            texts = [chunk["chunk"] for chunk in batch]
+            texts = create_texts_to_embed(use_unified_embeddings, batch)
             embeddings = model.encode(texts)
             for embedding in embeddings:
                 data[position]["embedding"] = embedding.tolist()
@@ -42,5 +53,6 @@ def main(input_file: str, output_file: str, model_name: str) -> None:
     parser.add_argument(
         "-m", "--model", help="Embedding model to use.", default="all-MiniLM-L6-v2"
     )
+    parser.add_argument("-u", "--unified-embeddings", help="Use unified embeddings.", action="store_true")
     args = parser.parse_args()
-    main(args.input, args.output, args.model)
+    main(args.input, args.output, args.model, args.unified_embeddings)
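With -u set, each chunk is prefixed with its dataset title before encoding, so the embedding captures the metadata and the text together; without the flag the behaviour is unchanged and only the raw chunk text is embedded. For an illustrative record (values assumed), create_unified_text_to_embed produces:

>>> create_unified_text_to_embed([{"dataset": "Example dataset title", "chunk": "Rainfall was recorded at 15-minute intervals..."}])
['Metadata: Dataset: Example dataset title\nText: Rainfall was recorded at 15-minute intervals...']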
scripts/extract_metadata.py: 1 addition, 0 deletions

@@ -13,6 +13,7 @@ def extact_eidc_metadata_fields(
         if json_data[field]:
             metadata = {}
             metadata["id"] = json_data["identifier"]
+            metadata["dataset"] = json_data["title"]
             metadata["field"] = field
             metadata["value"] = json_data[field]
             metadatas.append(metadata)
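The title is captured at extraction time from the record's title field, so every downstream stage sees it. Each extracted metadata entry now has this shape (values illustrative):

{"id": "example-eidc-identifier", "dataset": "Example dataset title", "field": "description", "value": "Rainfall was recorded at 15-minute intervals..."}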
scripts/fetch_supporting_docs.py: 8 additions, 7 deletions

@@ -11,14 +11,15 @@
 logger = logging.getLogger(__name__)


-def extract_ids(metadata_file: str) -> List[str]:
+def extract_ids_and_titles(metadata_file: str) -> List[str]:
     with open(metadata_file) as f:
         json_data = json.load(f)
+        titles = [dataset["title"] for dataset in json_data["results"]]
         ids = [dataset["identifier"] for dataset in json_data["results"]]
-        return ids
+        return list(zip(titles, ids))


-def get_supporting_docs(eidc_id: str, user: str, password: str) -> List[Dict[str, str]]:
+def get_supporting_docs(datset_title: str, eidc_id: str, user: str, password: str) -> List[Dict[str, str]]:
     try:
         res = requests.get(
             f"https://legilo.eds-infra.ceh.ac.uk/{eidc_id}/documents",
@@ -27,7 +28,7 @@ def get_supporting_docs(eidc_id: str, user: str, password: str) -> List[Dict[str, str]]:
         json_data = res.json()
         docs = []
         for key, val in json_data["success"].items():
-            docs.append({"id": eidc_id, "field": key, "value": val})
+            docs.append({"dataset": datset_title, "id": eidc_id, "field": key, "value": val})
         return docs
     except Exception as e:
         logger.error(
@@ -40,10 +41,10 @@ def main(metadata_file: str, supporting_docs_file: str) -> None:
     load_dotenv()
     user = os.getenv("username")
     password = os.getenv("password")
-    ids = extract_ids(metadata_file)
+    ids_and_titles = extract_ids_and_titles(metadata_file)
     docs = []
-    for id in tqdm(ids):
-        docs.extend(get_supporting_docs(id, user, password))
+    for id_title in tqdm(ids_and_titles):
+        docs.extend(get_supporting_docs(id_title[0], id_title[1], user, password))
     with open(supporting_docs_file, "w") as f:
         json.dump(docs, f, indent=4)
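extract_ids_and_titles zips titles before ids, so in main the loop variable id_title is a (title, id) tuple: id_title[0] is the title passed as datset_title (spelled that way in the commit) and id_title[1] is the identifier. Note the return annotation still reads List[str] although the function now returns a list of tuples. A sketch of an equivalent loop that unpacks the tuples directly, not part of this commit:

for title, eidc_id in tqdm(ids_and_titles):
    docs.extend(get_supporting_docs(title, eidc_id, user, password))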
