diff --git a/dalm/eval/eval_rag.py b/dalm/eval/eval_rag.py index faf8285..3866fad 100644 --- a/dalm/eval/eval_rag.py +++ b/dalm/eval/eval_rag.py @@ -196,6 +196,11 @@ def evaluate_rag( ) # peft config and wrapping rag_model.attach_pre_trained_peft_layers(retriever_peft_model_path, generator_peft_model_path, device) + # mapping rag retriever and generator model to device + if retriever_peft_model_path is None: + rag_model.retriever_model.eval().to(device) + if generator_peft_model_path is None: + rag_model.generator_model.eval().to(device) unique_passage_dataset, passage_embeddings_array = get_passage_embeddings( processed_datasets, passage_column_name,