[Update] new functions and bug fixes #90

Open · wants to merge 6 commits into base: main
3 changes: 3 additions & 0 deletions flashrag/generator/generator.py
@@ -297,6 +297,7 @@ def _load_model(self, model=None):
self.model_path,
torch_dtype="auto",
device_map="auto",
attn_implementation="flash_attention_2",
trust_remote_code=True,
)
else:
@@ -307,6 +308,8 @@
)
if "qwen" not in self.model_name:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

tokenizer.padding_side = "left"

return model, tokenizer
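
A minimal sketch of how the updated `_load_model` options fit together, assuming the standard Hugging Face transformers API; the flash-attn fallback shown here is illustrative and is not part of this diff:

```python
# Sketch only: how attn_implementation and the pad token fix interact,
# assuming the standard Hugging Face API. The flash-attn fallback is
# illustrative and NOT part of this PR.
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_causal_lm(model_path: str):
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype="auto",
            device_map="auto",
            attn_implementation="flash_attention_2",  # requires the flash-attn package
            trust_remote_code=True,
        )
    except (ImportError, ValueError):
        # Fall back to the default attention implementation if flash-attn is unavailable.
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype="auto", device_map="auto", trust_remote_code=True
        )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id  # keep the id in sync with the token
    tokenizer.padding_side = "left"
    return model, tokenizer
```
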
152 changes: 144 additions & 8 deletions flashrag/pipeline/pipeline.py
@@ -2,6 +2,12 @@
from flashrag.dataset.utils import split_dataset, merge_dataset
from flashrag.utils import get_retriever, get_generator, get_refiner, get_judger
from flashrag.prompt import PromptTemplate
import torch
import copy
from torch.multiprocessing import Pool
import os
from flashrag.dataset.dataset import Dataset
from tqdm import tqdm


class BasicPipeline:
@@ -73,13 +79,17 @@ def __init__(self, config, prompt_template=None, retriever=None, generator=None)

def naive_run(self, dataset, do_eval=True, pred_process_fun=None):
# direct generation without RAG
input_prompts = [self.prompt_template.get_string(question=q) for q in dataset.question]
input_prompts = [
self.prompt_template.get_string(question=q) for q in dataset.question
]
dataset.update_output("prompt", input_prompts)

pred_answer_list = self.generator.generate(input_prompts)
dataset.update_output("pred", pred_answer_list)

dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
dataset = self.evaluate(
dataset, do_eval=do_eval, pred_process_fun=pred_process_fun
)
return dataset

def run(self, dataset, do_eval=True, pred_process_fun=None):
@@ -127,11 +137,127 @@ def run(self, dataset, do_eval=True, pred_process_fun=None):
pred_answer_list = self.generator.generate(input_prompts)
dataset.update_output("pred", pred_answer_list)

dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
dataset = self.evaluate(
dataset, do_eval=do_eval, pred_process_fun=pred_process_fun
)

return dataset

def run_with_refiner_on_multicard(
self, dataset, do_eval=True, pred_process_fun=None
):
input_query = dataset.question

retrieval_results = self.retriever.batch_search(input_query)
dataset.update_output("retrieval_result", retrieval_results)

if self.refiner:
if hasattr(self.refiner, 'refiner'):
del self.refiner.refiner
del self.refiner
torch.cuda.empty_cache()
torch.multiprocessing.set_start_method("spawn", force=True)
# Get the list of GPU ids from the config
gpu_ids = [int(id.strip()) for id in self.config["gpu_id"].split(",")]
num_gpus = len(gpu_ids)
# Split the dataset into roughly equal chunks, one per GPU
data_chunks = split_list(dataset.data, num_gpus)
# Prepare the argument list for the worker processes
args_list = []
for i in range(num_gpus):
chunk_data = data_chunks[i]
gpu_id = gpu_ids[i]
args = (chunk_data, gpu_id, self.config)
args_list.append(args)
# Run the refiner in parallel, one process per GPU
with Pool(processes=num_gpus) as pool:
results = pool.map(process_refiner_chunk, args_list)
# Merge the chunk results, preserving the original order
refine_results = []
for chunk_result in results:
refine_results.extend(chunk_result)
dataset.update_output("refine_result", refine_results)
# Build the input prompts from the refined references
input_prompts = [
self.prompt_template.get_string(question=q, formatted_reference=r)
for q, r in zip(dataset.question, refine_results)
]
else:
input_prompts = [
self.prompt_template.get_string(question=q, retrieval_result=r)
for q, r in zip(dataset.question, dataset.retrieval_result)
]

dataset.update_output("prompt", input_prompts)

if self.use_fid:
print("Use FiD generation")
input_prompts = []
for item in dataset:
q = item.question
docs = item.retrieval_result
input_prompts.append([q + " " + doc for doc in docs])
# The refiner was released earlier to free memory; load the generator now (a KG refiner provides its own)
if self.config["refiner_name"] is not None:
if "kg" in self.config["refiner_name"].lower():
self.refiner = get_refiner(self.config, self.retriever, self.generator)
self.generator = self.refiner.generator
else:
self.generator = get_generator(self.config)

# Generate answers (possibly over multiple rounds)
pred_answer_list_final = []
score_answer_list_final = []
for i in tqdm(range(self.generation_count)):
pred_answer_list = self.generator.generate(
input_prompts, return_scores=True
)
pred_answer_list_final.append(pred_answer_list[0])
score_answer_list_final.append(pred_answer_list[1])
if i == 0:
dataset.update_output("pred", pred_answer_list[0])
final_preds = [[] for _ in range(len(dataset.data))]
final_scores = [[] for _ in range(len(dataset.data))]
for generated_set in pred_answer_list_final:
for i in range(len(dataset.data)):
final_preds[i].append(generated_set[i])
for generated_set in score_answer_list_final:
for i in range(len(dataset.data)):
final_scores[i].append(generated_set[i])
dataset.update_output("preds", final_preds)
dataset.update_output("scores", final_scores)
dataset = self.evaluate(
dataset, do_eval=do_eval, pred_process_fun=pred_process_fun
)

return dataset


def split_list(lst, n):
"""Split list `lst` into `n` approximately equal parts."""
k, m = divmod(len(lst), n)
return [lst[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)]


def process_refiner_chunk(args):
chunk_data, gpu_id, config = args
# Restrict this worker process to its assigned GPU
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
torch.cuda.set_device(0)  # within each process, the visible device index starts at 0
# Update the config copy for this process
config_copy = copy.deepcopy(config)
config_copy["device"] = torch.device("cuda:0")
# Create the refiner on this GPU
refiner = get_refiner(config_copy)
# Build a sub-dataset from this chunk
sub_dataset = Dataset(config=config_copy, data=chunk_data)
# Run the refiner on the chunk
refine_results = refiner.batch_run(sub_dataset)
# Release refiner resources
del refiner
torch.cuda.empty_cache()
return refine_results

class ConditionalPipeline(BasicPipeline):
def __init__(self, config, prompt_template=None):
"""
@@ -171,7 +297,9 @@ def run(self, dataset, do_eval=True, pred_process_fun=None):
# merge datasets into original format
dataset = merge_dataset(dataset_split, judge_result)

dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
dataset = self.evaluate(
dataset, do_eval=do_eval, pred_process_fun=pred_process_fun
)

return dataset

@@ -232,17 +360,25 @@ def run(self, dataset, do_eval=True, pred_process_fun=None):
dataset_split = split_dataset(dataset, judge_result)
for symbol, symbol_dataset in dataset_split.items():
if symbol == "A":
symbol_dataset = self.norag_pipeline.naive_run(symbol_dataset, do_eval=False)
symbol_dataset = self.norag_pipeline.naive_run(
symbol_dataset, do_eval=False
)
elif symbol == "B":
symbol_dataset = self.single_hop_pipeline.run(symbol_dataset, do_eval=False)
symbol_dataset = self.single_hop_pipeline.run(
symbol_dataset, do_eval=False
)
elif symbol == "C":
symbol_dataset = self.multi_hop_pipeline.run(symbol_dataset, do_eval=False)
symbol_dataset = self.multi_hop_pipeline.run(
symbol_dataset, do_eval=False
)
else:
assert False, "Unknown symbol!"

# merge datasets into original format
dataset = merge_dataset(dataset_split, judge_result)

dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
dataset = self.evaluate(
dataset, do_eval=do_eval, pred_process_fun=pred_process_fun
)

return dataset
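
A minimal sketch of how the new multi-GPU refiner path would be driven. Only `split_list` and `run_with_refiner_on_multicard` come from this diff; the `Config`, `get_dataset`, and `SequentialPipeline` wiring, the config values, and the assumption that the method lands on `SequentialPipeline` (as the surrounding `run()` hunks suggest) are illustrative:

```python
# Illustration only: config path, values, and the SequentialPipeline wiring are
# assumptions; split_list and run_with_refiner_on_multicard come from this PR.
from flashrag.config import Config
from flashrag.utils import get_dataset
from flashrag.pipeline import SequentialPipeline
from flashrag.pipeline.pipeline import split_list

# split_list produces chunks whose sizes differ by at most one:
chunks = split_list(list(range(10)), 4)
print([len(c) for c in chunks])  # -> [3, 3, 2, 2]

config = Config(
    "my_config.yaml",  # placeholder config file
    config_dict={"gpu_id": "0,1,2,3", "refiner_name": "recomp_abstractive_nq"},
)
dataset = get_dataset(config)["test"]  # split name depends on the dataset
pipeline = SequentialPipeline(config)
# Retrieval runs once; the refiner is then sharded across the four listed GPUs.
# Note: the method also reads self.generation_count, which must be set on the pipeline.
result = pipeline.run_with_refiner_on_multicard(dataset, do_eval=True)
```
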
2 changes: 2 additions & 0 deletions flashrag/refiner/selective_context_compressor.py
@@ -188,6 +188,8 @@ def _unit_info(tokens, self_info, units):
unit_self_info = [[] for _ in range(len(units))]

for idx, (token, info) in enumerate(zip(tokens, self_info)):
if current_unit_idx >= len(units):
break  # bounds check: stop once every unit has been processed
current_position += len(token)
if current_position == len(units[current_unit_idx]):
unit_self_info[current_unit_idx].append(info)
1 change: 1 addition & 0 deletions flashrag/retriever/retriever.py
@@ -259,6 +259,7 @@ def _batch_search(self, query_list, num: int = None, return_score=False):
elif self.backend == 'bm25s':
query_tokens = self.tokenizer.tokenize(query_list)
results, scores = self.searcher.retrieve(query_tokens, k=num)
results, scores = results.tolist(), scores.tolist()
else:
assert False, 'Invalid bm25 backend!'

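The added `.tolist()` converts the numpy arrays returned by bm25s into plain Python lists, which is what the downstream retriever code expects. A tiny illustration with stand-in values:

```python
# Illustration of the conversion; document ids and scores are stand-in values,
# not real retrieval output.
import numpy as np

results = np.array([["doc_3", "doc_7"]])  # shape (n_queries, k)
scores = np.array([[12.4, 9.8]])
results, scores = results.tolist(), scores.tolist()
assert isinstance(results, list) and isinstance(scores[0][0], float)
```
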
10 changes: 7 additions & 3 deletions flashrag/utils/utils.py
@@ -86,8 +86,11 @@ def get_refiner(config, retriever=None, generator=None):
# Predefined default model paths for known refiners
DEFAULT_PATH_DICT = {
"recomp_abstractive_nq": "fangyuan/nq_abstractive_compressor",
"recomp:abstractive_tqa": "fangyuan/tqa_abstractive_compressor",
"recomp:abstractive_hotpotqa": "fangyuan/hotpotqa_abstractive",
"recomp_abstractive_tqa": "fangyuan/tqa_abstractive_compressor",
"recomp_abstractive_hotpotqa": "fangyuan/hotpotqa_abstractive",
"recomp_extractive_nq": "fangyuan/nq_extractive_compressor",
"recomp_extractive_tqa": "fangyuan/tqa_extractive_compressor",
"recomp_extractive_hotpotqa": "fangyuan/hotpotqa_extractive_compressor",
}
REFINER_MODULE = importlib.import_module("flashrag.refiner")

@@ -97,7 +100,8 @@
if config["refiner_model_path"] is not None
else DEFAULT_PATH_DICT.get(refiner_name, None)
)

if not config['refiner_model_path']:
config['refiner_model_path'] = refiner_path
try:
model_config = AutoConfig.from_pretrained(refiner_path)
arch = model_config.architectures[0].lower()
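
The corrected keys and the new fallback assignment amount to the resolution order sketched below; `resolve_refiner_path` is a hypothetical helper and the dictionary is trimmed for brevity, but the logic mirrors the diff:

```python
# Sketch of the resolution order implied by the change above; resolve_refiner_path
# is a hypothetical helper, and the dictionary is trimmed to two entries.
DEFAULT_PATH_DICT = {
    "recomp_abstractive_nq": "fangyuan/nq_abstractive_compressor",
    "recomp_extractive_nq": "fangyuan/nq_extractive_compressor",
}

def resolve_refiner_path(config: dict):
    # An explicit refiner_model_path wins; otherwise fall back to the default repo.
    refiner_name = config["refiner_name"].lower()
    explicit = config.get("refiner_model_path")
    return explicit if explicit else DEFAULT_PATH_DICT.get(refiner_name)

assert resolve_refiner_path(
    {"refiner_name": "recomp_extractive_nq", "refiner_model_path": None}
) == "fangyuan/nq_extractive_compressor"
```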