diff --git a/dalm/datasets/qa_gen/question_answer_generation.py b/dalm/datasets/qa_gen/question_answer_generation.py index a225dbd..d278895 100644 --- a/dalm/datasets/qa_gen/question_answer_generation.py +++ b/dalm/datasets/qa_gen/question_answer_generation.py @@ -185,11 +185,11 @@ def split_dataset( def generate_qa_from_dataset( - dataset: Dataset, passage_column_name: str, title_column_name: str, sample_size: int, batch_size: int, max_input_tokens: int, load_in_8bit: bool = True + dataset: Dataset, passage_column_name: str, title_column_name: str, sample_size: int, batch_size: int, max_input_tokens: int, qa_model: str = QA_MODEL, load_in_8bit: bool = True ) -> DatasetDict: logger.info(f"Generating question answer pairs with batch size: {batch_size}") - tokenizer = AutoTokenizer.from_pretrained(QA_MODEL) - model = AutoModelForCausalLM.from_pretrained(QA_MODEL, torch_dtype="auto", device_map="auto") + tokenizer = AutoTokenizer.from_pretrained(qa_model) + model = AutoModelForCausalLM.from_pretrained(qa_model, torch_dtype="auto", device_map="auto") # shuffle data dataset.shuffle(seed=42)