Add a limit on the maximum input length (max_input_tokens) to get rid of the OOM issues.
shamanez committed Sep 30, 2024
1 parent c8a7871 commit 4c9c224
Showing 1 changed file with 15 additions and 6 deletions.
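The change threads a new `--max_input_tokens` flag (default 512) from the CLI down to the tokenizer call, where `truncation=True` together with `max_length` caps each passage before generation, bounding prompt length and hence memory. A minimal sketch of that mechanism outside the repository code (the checkpoint name and passage text here are illustrative placeholders; the script itself loads its own QA_MODEL constant):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint, not the one used by the script.
model_name = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

if tokenizer.pad_token is None:  # some tokenizers ship without a pad token
    tokenizer.pad_token = tokenizer.eos_token

passages = ["A very long passage ... " * 2000]  # long enough to blow up memory if left uncapped

# truncation=True + max_length bounds the sequence length, and therefore the
# activation memory that model.generate needs for the prompt.
inputs = tokenizer(
    passages,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512,
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```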
21 changes: 15 additions & 6 deletions dalm/datasets/qa_gen/question_answer_generation.py
@@ -64,12 +64,18 @@ def parse_args() -> argparse.Namespace:
         action="store_true",
         help="Save the files as CSV. If False, will save them as a dataset directory via [`~Dataset.save_to_disk`]",
     )
+    parser.add_argument(
+        "--max_input_tokens",
+        type=int,
+        default=512,
+        help="Maximum number of input tokens for the model.",
+    )
     args = parser.parse_args()
     return args


 def generate_question_answer_pairs(
-    documents: dict, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, passage_column_name: str
+    documents: dict, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, passage_column_name: str, max_input_tokens: int
 ) -> dict:

     """Generate question answer pairs"""
@@ -120,7 +126,7 @@ def generate_question_answer_pairs(
         add_generation_prompt=True
     )

-    model_inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True).to(model.device)
+    model_inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=max_input_tokens).to(model.device)

     generated_ids = model.generate(
         **model_inputs,
@@ -179,7 +185,7 @@ def split_dataset(


 def generate_qa_from_dataset(
-    dataset: Dataset, passage_column_name: str, title_column_name: str, sample_size: int, batch_size: int, load_in_8bit: bool = True
+    dataset: Dataset, passage_column_name: str, title_column_name: str, sample_size: int, batch_size: int, max_input_tokens: int, load_in_8bit: bool = True
 ) -> DatasetDict:
     logger.info(f"Generating question answer pairs with batch size: {batch_size}")
     tokenizer = AutoTokenizer.from_pretrained(QA_MODEL)
@@ -197,7 +203,7 @@ def generate_qa_from_dataset(
         f"Test dataset size: {len(small_dataset_splits['test'])}"
     )
     qa_gen_map = partial(
-        generate_question_answer_pairs, model=model, tokenizer=tokenizer, passage_column_name=passage_column_name
+        generate_question_answer_pairs, model=model, tokenizer=tokenizer, passage_column_name=passage_column_name, max_input_tokens=max_input_tokens
     )
     processed_data = small_dataset_splits.map(qa_gen_map, batched=True, batch_size=batch_size)
     # Print all questions from the test split before filtering
@@ -254,9 +260,10 @@ def generate_qa_from_disk(
     batch_size: int,
     output_dir: str,
     as_csv: bool,
+    max_input_tokens: int,
 ) -> None:
     dataset = _load_dataset_from_path(dataset_path)
-    qa_gen_data = generate_qa_from_dataset(dataset, passage_column_name, title_column_name, sample_size, batch_size)
+    qa_gen_data = generate_qa_from_dataset(dataset, passage_column_name, title_column_name, sample_size, batch_size, max_input_tokens)
     output_path = Path(output_dir)
     output_path.mkdir(exist_ok=True)
     for split_name, split_ds in qa_gen_data.items():
@@ -279,6 +286,7 @@ def main() -> None:
         args.batch_size,
         args.output_dir,
         args.as_csv,
+        args.max_input_tokens,
     )

@@ -290,5 +298,6 @@ def main() -> None:
 --dataset_path=knowledge_dataset.csv \
 --batch_size=8 \
 --sample_size=50 \
---output_dir=out
+--output_dir=out \
+--max_input_tokens=512
 """

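For completeness, a sketch of calling the updated Python API directly instead of the CLI shown in the docstring above. The import path and parameter names come from this diff; the column names and toy data are made up for illustration:

```python
from datasets import Dataset

from dalm.datasets.qa_gen.question_answer_generation import generate_qa_from_dataset

# Toy knowledge dataset; real usage would load the CSV from the docstring example.
dataset = Dataset.from_dict(
    {
        "text": ["Retrieval-augmented models fetch passages before answering."],
        "title": ["RAG overview"],
    }
)

splits = generate_qa_from_dataset(
    dataset,
    passage_column_name="text",
    title_column_name="title",
    sample_size=1,
    batch_size=1,
    max_input_tokens=512,  # the new cap introduced by this commit
    load_in_8bit=False,    # skip bitsandbytes quantization for this sketch
)
print(splits)
```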