From 85438ec278009cf34ff4d0eb346c690589827a4b Mon Sep 17 00:00:00 2001
From: lionHC
Date: Tue, 18 Jun 2024 15:34:01 +0800
Subject: [PATCH] feat: add logic to support training with plain-text data

---
 finetune/dataset.py | 63 ++++++++++++++++++++++++++-------------------
 1 file changed, 37 insertions(+), 26 deletions(-)

diff --git a/finetune/dataset.py b/finetune/dataset.py
index 92807c3..ad946ee 100644
--- a/finetune/dataset.py
+++ b/finetune/dataset.py
@@ -43,29 +43,40 @@ def __len__(self):
         return len(self.raw_data)
 
     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
-        image = Image.open(self.raw_data[i]["image"]).convert("RGB")
-        ret = preprocess(
-            image,
-            self.raw_data[i]["conversations"],
-            self.tokenizer,
-            self.transform,
-            query_nums=self.query_nums,
-            slice_config=self.slice_config,
-            llm_type=self.llm_type,
-            patch_size=self.patch_size,
-            batch_vision=self.batch_vision,
-        )
-        ret = dict(
-            input_ids=ret["input_ids"],
-            position_ids=ret["position_ids"],
-            labels=ret["target"],
-            attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
-            pixel_values=ret["pixel_values"],
-            tgt_sizes=ret["tgt_sizes"],
-            image_bound=ret["image_bound"],
-        )
-
-        return ret
+        if "image" in self.raw_data[i]:
+            image = Image.open(self.raw_data[i]["image"]).convert("RGB")
+            ret = preprocess(
+                image,
+                self.raw_data[i]["conversations"],
+                self.tokenizer,
+                self.transform,
+                query_nums=self.query_nums,
+                slice_config=self.slice_config,
+                llm_type=self.llm_type,
+                patch_size=self.patch_size,
+                batch_vision=self.batch_vision,
+            )
+            ret = dict(
+                input_ids=ret["input_ids"],
+                position_ids=ret["position_ids"],
+                labels=ret["target"],
+                attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
+                pixel_values=ret["pixel_values"],
+                tgt_sizes=ret["tgt_sizes"],
+                image_bound=ret["image_bound"],
+            )
+
+            return ret
+        else:
+            # Processing plain text data
+            ret = conversation_to_ids(self.raw_data[i]["conversations"], self.tokenizer, self.llm_type)
+            ret = dict(
+                input_ids=ret["input_ids"],
+                position_ids=ret["position_ids"],
+                labels=ret["target"],
+                attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
+            )
+            return ret
 
 def data_collator(examples, padding_value=0, max_length=2048):
     def trim_and_pad(seq, batch_first, padding_value):
@@ -91,9 +102,9 @@ def trim_and_pad(seq, batch_first, padding_value):
         batch_first=True,
         padding_value=padding_value,
     )
-    pixel_values = [example["pixel_values"] for example in examples]
-    image_bound = [example["image_bound"] for example in examples]
-    tgt_sizes = [example["tgt_sizes"] for example in examples]
+    pixel_values = [example["pixel_values"] if "pixel_values" in example else torch.tensor([]) for example in examples]
+    image_bound = [example["image_bound"] if "image_bound" in example else torch.tensor([]) for example in examples]
+    tgt_sizes = [example["tgt_sizes"] if "tgt_sizes" in example else torch.tensor([]) for example in examples]
     return {
         "input_ids": input_ids,
         "position_ids": position_ids,
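
Note on the data this enables: the new branch in __getitem__ keys purely off the
presence of an "image" field in each raw sample. Below is a minimal sketch of what
a mixed training list might look like under that assumption; the "role"/"content"
layout inside "conversations" is illustrative only and is not taken from this diff.

# Hypothetical mixed dataset: one image-grounded sample and one plain-text sample.
# Only the presence or absence of the "image" key matters to the new branch;
# samples without it go through conversation_to_ids() and skip image preprocessing.
raw_data = [
    {
        "image": "images/0001.jpg",
        "conversations": [
            {"role": "user", "content": "<image>\nWhat is shown in the picture?"},
            {"role": "assistant", "content": "A cat sitting on a windowsill."},
        ],
    },
    {
        # No "image" key: handled by the plain-text path added in this patch.
        "conversations": [
            {"role": "user", "content": "Summarize Hamlet in one sentence."},
            {"role": "assistant", "content": "A prince avenges his father's murder at great cost."},
        ],
    },
]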
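
On the collator side, the patched list comprehensions let vision-less examples share
a batch by substituting empty tensors for the missing fields. The snippet below is a
self-contained sketch of that fallback pattern only; gather_vision_fields is a
hypothetical helper name and the real data_collator additionally trims, pads, and
builds the attention mask.

import torch

def gather_vision_fields(examples):
    # Same key-presence fallback as the patched data_collator: examples that came
    # through the plain-text path carry no vision keys, so an empty tensor stands
    # in for pixel_values / image_bound / tgt_sizes.
    pixel_values = [ex["pixel_values"] if "pixel_values" in ex else torch.tensor([]) for ex in examples]
    image_bound = [ex["image_bound"] if "image_bound" in ex else torch.tensor([]) for ex in examples]
    tgt_sizes = [ex["tgt_sizes"] if "tgt_sizes" in ex else torch.tensor([]) for ex in examples]
    return pixel_values, image_bound, tgt_sizes

# One vision sample and one text-only sample sharing a batch.
batch = [
    {
        "input_ids": torch.tensor([1, 2, 3]),
        "pixel_values": torch.rand(3, 14, 14),
        "image_bound": torch.tensor([[0, 3]]),
        "tgt_sizes": torch.tensor([[14, 14]]),
    },
    {"input_ids": torch.tensor([4, 5])},  # plain-text sample: no vision keys
]

pv, ib, ts = gather_vision_fields(batch)
print([t.shape for t in pv])  # second entry is an empty tensor: torch.Size([0])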