From a8476405dbc92a6acb0784094c73f70044843354 Mon Sep 17 00:00:00 2001 From: sagorbrur Date: Thu, 23 Nov 2023 09:50:52 +0600 Subject: [PATCH 1/2] map rag model to device in eval rag --- dalm/eval/eval_rag.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dalm/eval/eval_rag.py b/dalm/eval/eval_rag.py index faf8285..74a40e1 100644 --- a/dalm/eval/eval_rag.py +++ b/dalm/eval/eval_rag.py @@ -196,6 +196,11 @@ def evaluate_rag( ) # peft config and wrapping rag_model.attach_pre_trained_peft_layers(retriever_peft_model_path, generator_peft_model_path, device) + # mapping rag retriever and generator model to device + if retriever_peft_model_path is None: + rag_model.retriever_model.eval().to(device) + if generator_peft_model_path is None: + rag_model.generator_model.eval().to(device) unique_passage_dataset, passage_embeddings_array = get_passage_embeddings( processed_datasets, passage_column_name, From 91178cc532f2ad792825c5412926a022a1655685 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Mon, 4 Dec 2023 18:18:27 -0500 Subject: [PATCH 2/2] Appease format step --- dalm/eval/eval_rag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dalm/eval/eval_rag.py b/dalm/eval/eval_rag.py index 74a40e1..3866fad 100644 --- a/dalm/eval/eval_rag.py +++ b/dalm/eval/eval_rag.py @@ -198,7 +198,7 @@ def evaluate_rag( rag_model.attach_pre_trained_peft_layers(retriever_peft_model_path, generator_peft_model_path, device) # mapping rag retriever and generator model to device if retriever_peft_model_path is None: - rag_model.retriever_model.eval().to(device) + rag_model.retriever_model.eval().to(device) if generator_peft_model_path is None: rag_model.generator_model.eval().to(device) unique_passage_dataset, passage_embeddings_array = get_passage_embeddings(