From 4c9c2245229872de9d4c5cdb0b5e25beb5b49f14 Mon Sep 17 00:00:00 2001
From: Shamane Siri
Date: Mon, 30 Sep 2024 15:05:04 +1300
Subject: [PATCH] add a limit to the max_input_length to get rid of the OOM issues.

---
 .../qa_gen/question_answer_generation.py | 21 +++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/dalm/datasets/qa_gen/question_answer_generation.py b/dalm/datasets/qa_gen/question_answer_generation.py
index 455bdfd..a225dbd 100644
--- a/dalm/datasets/qa_gen/question_answer_generation.py
+++ b/dalm/datasets/qa_gen/question_answer_generation.py
@@ -64,12 +64,18 @@ def parse_args() -> argparse.Namespace:
         action="store_true",
         help="Save the files as CSV. If False, will save them as a dataset directory via [`~Dataset.save_to_disk`]",
     )
+    parser.add_argument(
+        "--max_input_tokens",
+        type=int,
+        default=512,
+        help="Maximum number of input tokens for the model.",
+    )
     args = parser.parse_args()
     return args
 
 
 def generate_question_answer_pairs(
-    documents: dict, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, passage_column_name: str
+    documents: dict, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, passage_column_name: str, max_input_tokens: int
 ) -> dict:
     """Generate question answer pairs"""
 
@@ -120,7 +126,7 @@ def generate_question_answer_pairs(
         add_generation_prompt=True
     )
 
-    model_inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True).to(model.device)
+    model_inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=max_input_tokens).to(model.device)
 
     generated_ids = model.generate(
         **model_inputs,
@@ -179,7 +185,7 @@ def split_dataset(
 
 
 def generate_qa_from_dataset(
-    dataset: Dataset, passage_column_name: str, title_column_name: str, sample_size: int, batch_size: int, load_in_8bit: bool = True
+    dataset: Dataset, passage_column_name: str, title_column_name: str, sample_size: int, batch_size: int, max_input_tokens: int, load_in_8bit: bool = True
 ) -> DatasetDict:
     logger.info(f"Generating question answer pairs with batch size: {batch_size}")
     tokenizer = AutoTokenizer.from_pretrained(QA_MODEL)
@@ -197,7 +203,7 @@ def generate_qa_from_dataset(
         f"Test dataset size: {len(small_dataset_splits['test'])}"
     )
     qa_gen_map = partial(
-        generate_question_answer_pairs, model=model, tokenizer=tokenizer, passage_column_name=passage_column_name
+        generate_question_answer_pairs, model=model, tokenizer=tokenizer, passage_column_name=passage_column_name, max_input_tokens=max_input_tokens
     )
     processed_data = small_dataset_splits.map(qa_gen_map, batched=True, batch_size=batch_size)
     # Print all questions from the test split before filtering
@@ -254,9 +260,10 @@ def generate_qa_from_disk(
     batch_size: int,
     output_dir: str,
     as_csv: bool,
+    max_input_tokens: int,
 ) -> None:
     dataset = _load_dataset_from_path(dataset_path)
-    qa_gen_data = generate_qa_from_dataset(dataset, passage_column_name, title_column_name, sample_size, batch_size)
+    qa_gen_data = generate_qa_from_dataset(dataset, passage_column_name, title_column_name, sample_size, batch_size, max_input_tokens)
     output_path = Path(output_dir)
     output_path.mkdir(exist_ok=True)
     for split_name, split_ds in qa_gen_data.items():
@@ -279,6 +286,7 @@ def main() -> None:
         args.batch_size,
         args.output_dir,
         args.as_csv,
+        args.max_input_tokens,
     )
 
 
@@ -290,5 +298,6 @@ def main() -> None:
     --dataset_path=knowledge_dataset.csv \
     --batch_size=8 \
     --sample_size=50 \
-    --output_dir=out
+    --output_dir=out \
+    --max_input_tokens=512
 """
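
Note (not part of the patch): below is a minimal, self-contained sketch of the behavior the new --max_input_tokens flag relies on. The "gpt2" checkpoint and the sample passages are hypothetical stand-ins for the script's QA_MODEL and real data; the point is that passing max_length together with truncation=True bounds every sequence in the batch before it reaches model.generate(), which is what keeps very long passages from inflating the padded batch and triggering OOM.

# Illustrative sketch only -- assumes the public "gpt2" checkpoint as a
# stand-in for QA_MODEL used in question_answer_generation.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
model = AutoModelForCausalLM.from_pretrained("gpt2")

passages = ["a very long passage " * 2000, "a short passage"]

# truncation=True alone falls back to the tokenizer's model_max_length (which
# for some tokenizers is effectively unbounded); adding max_length caps every
# sequence, so the padded batch can never exceed max_input_tokens columns.
max_input_tokens = 512
model_inputs = tokenizer(
    passages,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=max_input_tokens,
)
print(model_inputs["input_ids"].shape)  # torch.Size([2, 512])

with torch.no_grad():
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=20,
        pad_token_id=tokenizer.eos_token_id,
    )

The default of 512 mirrors the patch; in practice the value would be tuned to the QA model's context window and the available GPU memory.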