From 189cc048577c7b93b59117f7fe77510bc9d08b4c Mon Sep 17 00:00:00 2001 From: Tony Hung <102973178+a172166@users.noreply.github.com> Date: Mon, 27 Nov 2023 18:33:22 -0500 Subject: [PATCH] move retriever/generator to device --- dalm/eval/eval_rag.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dalm/eval/eval_rag.py b/dalm/eval/eval_rag.py index faf8285..613c53b 100644 --- a/dalm/eval/eval_rag.py +++ b/dalm/eval/eval_rag.py @@ -196,6 +196,11 @@ def evaluate_rag( ) # peft config and wrapping rag_model.attach_pre_trained_peft_layers(retriever_peft_model_path, generator_peft_model_path, device) + + #move retriever and generator to appropriate device + rag_model.retriever_model.eval().to(device) + rag_model.generator_model.eval().to(device) + unique_passage_dataset, passage_embeddings_array = get_passage_embeddings( processed_datasets, passage_column_name,