From f4b94004d52d900741a6a0bfc3323d8baccc39a3 Mon Sep 17 00:00:00 2001
From: Yijie Xu
Date: Sun, 6 Oct 2024 16:51:37 +0000
Subject: [PATCH 1/6] [Edit] add flash attention 2 support in HFCausalLMGenerator and extra handling for Llama 3

---
 flashrag/generator/generator.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/flashrag/generator/generator.py b/flashrag/generator/generator.py
index 8d6c492..103aba9 100644
--- a/flashrag/generator/generator.py
+++ b/flashrag/generator/generator.py
@@ -297,6 +297,7 @@ def _load_model(self, model=None):
                 self.model_path,
                 torch_dtype="auto",
                 device_map="auto",
+                attn_implementation="flash_attention_2",
                 trust_remote_code=True,
             )
         else:
@@ -307,6 +308,8 @@ def _load_model(self, model=None):
         )
         if "qwen" not in self.model_name:
             tokenizer.pad_token = tokenizer.eos_token
+        if "Llama-3" in self.model_name:
+            tokenizer.pad_token_id = tokenizer.eos_token_id
         tokenizer.padding_side = "left"

         return model, tokenizer

From 3c731b76d2f405ac6edfe748c17f3d564ec16a7e Mon Sep 17 00:00:00 2001
From: Yijie Xu
Date: Sun, 6 Oct 2024 16:54:40 +0000
Subject: [PATCH 2/6] [Edit] add multi-GPU support for the refiner

---
 flashrag/pipeline/pipeline.py | 152 ++++++++++++++++++++++++++++++++--
 1 file changed, 144 insertions(+), 8 deletions(-)

diff --git a/flashrag/pipeline/pipeline.py b/flashrag/pipeline/pipeline.py
index 1775980..78bbda5 100644
--- a/flashrag/pipeline/pipeline.py
+++ b/flashrag/pipeline/pipeline.py
@@ -2,6 +2,12 @@
 from flashrag.dataset.utils import split_dataset, merge_dataset
 from flashrag.utils import get_retriever, get_generator, get_refiner, get_judger
 from flashrag.prompt import PromptTemplate
+import torch
+import copy
+from torch.multiprocessing import Pool
+import os
+from flashrag.dataset.dataset import Dataset
+from tqdm import tqdm


 class BasicPipeline:
@@ -73,13 +79,17 @@ def __init__(self, config, prompt_template=None, retriever=None, generator=None)
     def naive_run(self, dataset, do_eval=True, pred_process_fun=None):
         # direct generation without RAG
-        input_prompts = [self.prompt_template.get_string(question=q) for q in dataset.question]
+        input_prompts = [
+            self.prompt_template.get_string(question=q) for q in dataset.question
+        ]
         dataset.update_output("prompt", input_prompts)

         pred_answer_list = self.generator.generate(input_prompts)
         dataset.update_output("pred", pred_answer_list)

-        dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
+        dataset = self.evaluate(
+            dataset, do_eval=do_eval, pred_process_fun=pred_process_fun
+        )
         return dataset

     def run(self, dataset, do_eval=True, pred_process_fun=None):
@@ -127,11 +137,127 @@ def run(self, dataset, do_eval=True, pred_process_fun=None):

         pred_answer_list = self.generator.generate(input_prompts)
         dataset.update_output("pred", pred_answer_list)

-        dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
+        dataset = self.evaluate(
+            dataset, do_eval=do_eval, pred_process_fun=pred_process_fun
+        )
+
+        return dataset
+
+    def run_with_refiner_on_multicard(
+        self, dataset, do_eval=True, pred_process_fun=None
+    ):
+        input_query = dataset.question
+
+        retrieval_results = self.retriever.batch_search(input_query)
+        dataset.update_output("retrieval_result", retrieval_results)
+
+        if self.refiner:
+            if hasattr(self.refiner, 'refiner'):
+                del self.refiner.refiner
+            del self.refiner
+            torch.cuda.empty_cache()
+            torch.multiprocessing.set_start_method("spawn", force=True)
+            # get the list of GPU ids to use
+            gpu_ids = [int(id.strip()) for id in self.config["gpu_id"].split(",")]
+            num_gpus = len(gpu_ids)
+            # split the dataset into equal chunks, one per GPU
+            data_chunks = split_list(dataset.data, num_gpus)
+            # build the argument list for each worker process
+            args_list = []
+            for i in range(num_gpus):
+                chunk_data = data_chunks[i]
+                gpu_id = gpu_ids[i]
+                args = (chunk_data, gpu_id, self.config)
+                args_list.append(args)
+            # run the refiner in parallel, one process per GPU
+            with Pool(processes=num_gpus) as pool:
+                results = pool.map(process_refiner_chunk, args_list)
+            # merge the chunk results, preserving the original order
+            refine_results = []
+            for chunk_result in results:
+                refine_results.extend(chunk_result)
+            dataset.update_output("refine_result", refine_results)
+            # build the input prompts from the refined references
+            input_prompts = [
+                self.prompt_template.get_string(question=q, formatted_reference=r)
+                for q, r in zip(dataset.question, refine_results)
+            ]
+        else:
+            input_prompts = [
+                self.prompt_template.get_string(question=q, retrieval_result=r)
+                for q, r in zip(dataset.question, dataset.retrieval_result)
+            ]
+
+        dataset.update_output("prompt", input_prompts)
+
+        if self.use_fid:
+            print("Use FiD generation")
+            input_prompts = []
+            for item in dataset:
+                q = item.question
+                docs = item.retrieval_result
+                input_prompts.append([q + " " + doc for doc in docs])
+        # the refiner was deleted above to release memory; re-create the generator
+        if self.config["refiner_name"] is not None:
+            if "kg" in self.config["refiner_name"].lower():
+                self.refiner = get_refiner(self.config, self.retriever, self.generator)
+                self.generator = self.refiner.generator
+            else:
+                self.generator = get_generator(self.config)
+
+        # generate answers
+        pred_answer_list_final = []
+        score_answer_list_final = []
+        for i in tqdm(range(self.generation_count)):
+            pred_answer_list = self.generator.generate(
+                input_prompts, return_scores=True
+            )
+            pred_answer_list_final.append(pred_answer_list[0])
+            score_answer_list_final.append(pred_answer_list[1])
+            if i == 0:
+                dataset.update_output("pred", pred_answer_list[0])
+        final_preds = [[] for _ in range(len(dataset.data))]
+        final_scores = [[] for _ in range(len(dataset.data))]
+        for generated_set in pred_answer_list_final:
+            for i in range(len(dataset.data)):
+                final_preds[i].append(generated_set[i])
+        for generated_set in score_answer_list_final:
+            for i in range(len(dataset.data)):
+                final_scores[i].append(generated_set[i])
+        dataset.update_output("preds", final_preds)
+        dataset.update_output("scores", final_scores)
+        dataset = self.evaluate(
+            dataset, do_eval=do_eval, pred_process_fun=pred_process_fun
+        )

         return dataset

+
+def split_list(lst, n):
+    """Split list `lst` into `n` approximately equal parts."""
+    k, m = divmod(len(lst), n)
+    return [lst[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)]
+
+
+def process_refiner_chunk(args):
+    chunk_data, gpu_id, config = args
+    # bind this worker to its assigned GPU
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+    torch.cuda.set_device(0)  # inside each process the visible device index starts at 0
+    # update the config for this process
+    config_copy = copy.deepcopy(config)
+    config_copy["device"] = torch.device("cuda:0")
+    # create the refiner
+    refiner = get_refiner(config_copy)
+    # build the sub-dataset for this chunk
+    sub_dataset = Dataset(config=config_copy, data=chunk_data)
+    # run the refiner
+    refine_results = refiner.batch_run(sub_dataset)
+    # release refiner resources
+    del refiner
+    torch.cuda.empty_cache()
+    return refine_results
+

 class ConditionalPipeline(BasicPipeline):
     def __init__(self, config, prompt_template=None):
         """
@@ -171,7 +297,9 @@ def run(self, dataset, do_eval=True, pred_process_fun=None):
         # merge datasets into original format
         dataset = merge_dataset(dataset_split, judge_result)

-        dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
+        dataset = self.evaluate(
+            dataset, do_eval=do_eval, pred_process_fun=pred_process_fun
+        )

         return dataset

@@ -232,17 +360,25 @@ def run(self, dataset, do_eval=True, pred_process_fun=None):
         dataset_split = split_dataset(dataset, judge_result)
         for symbol, symbol_dataset in dataset_split.items():
             if symbol == "A":
-                symbol_dataset = self.norag_pipeline.naive_run(symbol_dataset, do_eval=False)
+                symbol_dataset = self.norag_pipeline.naive_run(
+                    symbol_dataset, do_eval=False
+                )
             elif symbol == "B":
-                symbol_dataset = self.single_hop_pipeline.run(symbol_dataset, do_eval=False)
+                symbol_dataset = self.single_hop_pipeline.run(
+                    symbol_dataset, do_eval=False
+                )
             elif symbol == "C":
-                symbol_dataset = self.multi_hop_pipeline.run(symbol_dataset, do_eval=False)
+                symbol_dataset = self.multi_hop_pipeline.run(
+                    symbol_dataset, do_eval=False
+                )
             else:
                 assert False, "Unknown symbol!"

         # merge datasets into original format
         dataset = merge_dataset(dataset_split, judge_result)

-        dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
+        dataset = self.evaluate(
+            dataset, do_eval=do_eval, pred_process_fun=pred_process_fun
+        )

         return dataset

From ff621ca4bac0d2b556f1569991846b5c4b1254fb Mon Sep 17 00:00:00 2001
From: Yijie Xu
Date: Sun, 6 Oct 2024 17:00:24 +0000
Subject: [PATCH 3/6] [Edit] fix an out-of-bounds bug in selective-context

---
 flashrag/refiner/selective_context_compressor.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/flashrag/refiner/selective_context_compressor.py b/flashrag/refiner/selective_context_compressor.py
index e3f12c6..be9dd72 100644
--- a/flashrag/refiner/selective_context_compressor.py
+++ b/flashrag/refiner/selective_context_compressor.py
@@ -188,6 +188,8 @@ def _unit_info(tokens, self_info, units):
         unit_self_info = [[] for _ in range(len(units))]

         for idx, (token, info) in enumerate(zip(tokens, self_info)):
+            if current_unit_idx >= len(units):
+                break  # guard: stop once every unit has been consumed
             current_position += len(token)
             if current_position == len(units[current_unit_idx]):
                 unit_self_info[current_unit_idx].append(info)

From e8187f4cfdd0d5794a4f0c6102e941b0678c2495 Mon Sep 17 00:00:00 2001
From: Yijie Xu
Date: Sun, 6 Oct 2024 17:04:37 +0000
Subject: [PATCH 4/6] [Edit] add more default model paths for the recomp refiner

---
 flashrag/utils/utils.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/flashrag/utils/utils.py b/flashrag/utils/utils.py
index 979f7c8..464f7f3 100644
--- a/flashrag/utils/utils.py
+++ b/flashrag/utils/utils.py
@@ -86,8 +86,11 @@ def get_refiner(config, retriever=None, generator=None):
     # predefined dictionary of default model paths
     DEFAULT_PATH_DICT = {
         "recomp_abstractive_nq": "fangyuan/nq_abstractive_compressor",
-        "recomp:abstractive_tqa": "fangyuan/tqa_abstractive_compressor",
-        "recomp:abstractive_hotpotqa": "fangyuan/hotpotqa_abstractive",
+        "recomp_abstractive_tqa": "fangyuan/tqa_abstractive_compressor",
+        "recomp_abstractive_hotpotqa": "fangyuan/hotpotqa_abstractive",
+        "recomp_extractive_nq": "fangyuan/nq_extractive_compressor",
+        "recomp_extractive_tqa": "fangyuan/tqa_extractive_compressor",
+        "recomp_extractive_hotpotqa": "fangyuan/hotpotqa_extractive_compressor",
     }

     REFINER_MODULE = importlib.import_module("flashrag.refiner")
@@ -97,7 +100,8 @@ def get_refiner(config, retriever=None, generator=None):
         if config["refiner_model_path"] is not None
         else DEFAULT_PATH_DICT.get(refiner_name, None)
     )
-
+    if not config['refiner_model_path']:
+        config['refiner_model_path'] = refiner_path
     try:
         model_config = AutoConfig.from_pretrained(refiner_path)
         arch = model_config.architectures[0].lower()

From 0e5c145604a947eea7bb80ed28b30a7f242a3636 Mon Sep 17 00:00:00 2001
From: Yijie Xu
Date: Sun, 6 Oct 2024 17:09:14 +0000
Subject: [PATCH 5/6] [Edit] Fix a bug in intermediate_data.json saving when using bm25s

---
 flashrag/retriever/retriever.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/flashrag/retriever/retriever.py b/flashrag/retriever/retriever.py
index 1868b13..6b9c776 100644
--- a/flashrag/retriever/retriever.py
+++ b/flashrag/retriever/retriever.py
@@ -259,6 +259,7 @@ def _batch_search(self, query_list, num: int = None, return_score=False):
         elif self.backend == 'bm25s':
             query_tokens = self.tokenizer.tokenize(query_list)
             results, scores = self.searcher.retrieve(query_tokens, k=num)
+            results, scores = results.tolist(), scores.tolist()
         else:
             assert False, 'Invalid bm25 backend!'

From ff2fb4c4f28703902d088deb85c2895ed99e439c Mon Sep 17 00:00:00 2001
From: Jiajie Jin <55341185+ignorejjj@users.noreply.github.com>
Date: Fri, 11 Oct 2024 17:14:39 +0800
Subject: [PATCH 6/6] Update generator.py

---
 flashrag/generator/generator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flashrag/generator/generator.py b/flashrag/generator/generator.py
index 103aba9..27091b8 100644
--- a/flashrag/generator/generator.py
+++ b/flashrag/generator/generator.py
@@ -308,8 +308,8 @@ def _load_model(self, model=None):
         )
         if "qwen" not in self.model_name:
             tokenizer.pad_token = tokenizer.eos_token
-        if "Llama-3" in self.model_name:
             tokenizer.pad_token_id = tokenizer.eos_token_id
+
         tokenizer.padding_side = "left"

         return model, tokenizer
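
Note on PATCH 1/6 and PATCH 6/6 (illustrative sketch, not part of the applied diff): taken together, the two generator.py changes amount to the loading sequence below. The model path is a hypothetical example, and attn_implementation="flash_attention_2" assumes the flash-attn 2 package is installed and the GPU supports it; otherwise the argument should simply be omitted.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "meta-llama/Meta-Llama-3-8B-Instruct"  # hypothetical example path

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="flash_attention_2",  # requires flash-attn >= 2
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # Llama 3 ships without a pad token; after PATCH 6/6 both pad fields are taken
    # from the EOS token for any non-Qwen model, and padding is done on the left.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"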
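
Note on PATCH 2/6 (illustrative sketch, not part of the applied diff): run_with_refiner_on_multicard splits the dataset with split_list (e.g. split_list(list(range(10)), 3) gives [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]), refines one chunk per GPU in a spawned worker process, and merges the chunks back in the original order before generation. A minimal invocation sketch, assuming the standard FlashRAG quick-start APIs (Config, get_dataset, SequentialPipeline, which the surrounding hunks suggest is the class receiving the new method) and a placeholder config file name:

    from flashrag.config import Config
    from flashrag.utils import get_dataset
    from flashrag.pipeline import SequentialPipeline

    # "my_config.yaml" is a placeholder; gpu_id and refiner_name are the keys the
    # new method actually reads, everything else comes from the config file.
    config = Config(
        "my_config.yaml",
        {"gpu_id": "0,1,2,3", "refiner_name": "recomp_abstractive_nq"},
    )
    test_data = get_dataset(config)["test"]

    pipeline = SequentialPipeline(config)
    # One refiner process is spawned per GPU listed in gpu_id; refined chunks are
    # merged in order, then the generator is re-created for answer generation.
    result = pipeline.run_with_refiner_on_multicard(test_data, do_eval=True)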
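
Note on PATCH 4/6 (illustrative sketch, not part of the applied diff): besides correcting the two mistyped recomp:abstractive_* keys, the patch lets every abstractive and extractive RECOMP checkpoint be selected via refiner_name alone, with refiner_model_path back-filled into the config when it is unset. A small sketch of that resolution logic, using a truncated copy of the mapping:

    # Truncated copy of the mapping added in utils.py, for illustration only.
    DEFAULT_PATH_DICT = {
        "recomp_abstractive_nq": "fangyuan/nq_abstractive_compressor",
        "recomp_extractive_nq": "fangyuan/nq_extractive_compressor",
        # ... tqa / hotpotqa entries as in the patch
    }

    config = {"refiner_name": "recomp_extractive_nq", "refiner_model_path": None}
    refiner_path = (
        config["refiner_model_path"]
        if config["refiner_model_path"] is not None
        else DEFAULT_PATH_DICT.get(config["refiner_name"], None)
    )
    if not config["refiner_model_path"]:
        config["refiner_model_path"] = refiner_path  # back-filled, as in the patch
    assert refiner_path == "fangyuan/nq_extractive_compressor"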